
Commit c1ed3c0

committed
added multi-armed bandit algorithm with three strategies to solve it
1 parent a4576dc commit c1ed3c0

File tree

1 file changed

+332
-0
lines changed


machine_learning/mab.py

@@ -0,0 +1,332 @@
"""
Multi-Armed Bandit (MAB) is a problem in reinforcement learning where an agent must
learn to choose the best action from a set of actions to maximize its reward.

Learn more here: https://en.wikipedia.org/wiki/Multi-armed_bandit


The MAB problem can be described as follows:
- There are N arms, each with a different probability of giving a reward.
- The agent must learn to choose the best arm to pull in order to maximize its reward.

Three optimising strategies are implemented here:
- Epsilon-Greedy
- Upper Confidence Bound (UCB)
- Thompson Sampling

Two further baseline strategies are implemented to show how the optimising
strategies perform against them:
- Random strategy (full exploration)
- Greedy strategy (full exploitation)

The performance of the strategies is evaluated by the cumulative reward
over a number of rounds.
"""

import matplotlib.pyplot as plt
import numpy as np


class Bandit:
    """
    A class to represent a multi-armed bandit.
    """

    def __init__(self, probabilities: list[float]):
        """
        Initialize the bandit with a list of probabilities for each arm.

        Args:
            probabilities: List of probabilities for each arm.
        """
        self.probabilities = probabilities
        self.k = len(probabilities)

    def pull(self, arm_index: int) -> int:
        """
        Pull an arm of the bandit.

        Args:
            arm_index: The index of the arm to pull.

        Returns:
            The reward for the arm: 1 with the arm's probability, otherwise 0.
        """
        rng = np.random.default_rng()
        return 1 if rng.random() < self.probabilities[arm_index] else 0
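

# Illustrative sketch (not executed anywhere in this module, arm probabilities
# are made up): a Bandit can be exercised on its own as
#
#     bandit = Bandit([0.2, 0.8])
#     rewards = [bandit.pull(1) for _ in range(1000)]
#     print(sum(rewards))  # roughly 800, since arm 1 pays out with probability 0.8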


# Epsilon-Greedy strategy


class EpsilonGreedy:
    """
    A class for a simple implementation of the Epsilon-Greedy strategy.
    Follow this link to learn more:
    https://medium.com/analytics-vidhya/the-epsilon-greedy-algorithm-for-reinforcement-learning-5fe6f96dc870
    """

    def __init__(self, epsilon: float, k: int):
        """
        Initialize the Epsilon-Greedy strategy.

        Args:
            epsilon: The probability of exploring new arms.
            k: The number of arms.
        """
        self.epsilon = epsilon
        self.k = k
        self.counts = np.zeros(k)
        self.values = np.zeros(k)

    def select_arm(self):
        """
        Select an arm to pull.

        Returns:
            The index of the arm to pull: a random arm with probability
            epsilon, otherwise the arm with the highest estimated value.
        """
        rng = np.random.default_rng()

        if rng.random() < self.epsilon:
            return rng.integers(self.k)
        else:
            return np.argmax(self.values)

    def update(self, arm_index: int, reward: int):
        """
        Update the strategy.

        Args:
            arm_index: The index of the arm that was pulled.
            reward: The reward for the arm.
        """
        self.counts[arm_index] += 1
        n = self.counts[arm_index]
        # Incremental mean of the rewards observed for this arm.
        self.values[arm_index] += (reward - self.values[arm_index]) / n
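

# Illustrative sketch (assumed values, not part of the test below): with
# epsilon = 0.1 the strategy explores a random arm in roughly 10% of rounds and
# otherwise exploits the current best estimate. The estimates follow the
# incremental-mean update, e.g. after rewards 1, 0, 1 on a single arm:
#
#     strategy = EpsilonGreedy(epsilon=0.1, k=2)
#     for r in (1, 0, 1):
#         strategy.update(0, r)
#     print(strategy.values[0])  # 0.666..., the running mean of the three rewards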


# Upper Confidence Bound (UCB)


class UCB:
    """
    A class for the Upper Confidence Bound (UCB) strategy.
    Follow this link to learn more:
    https://people.maths.bris.ac.uk/~maajg/teaching/stochopt/ucb.pdf
    """

    def __init__(self, k: int):
        """
        Initialize the UCB strategy.

        Args:
            k: The number of arms.
        """
        self.k = k
        self.counts = np.zeros(k)
        self.values = np.zeros(k)
        self.total_counts = 0

    def select_arm(self):
        """
        Select an arm to pull.

        Returns:
            The index of the arm to pull.
        """
        # Play each arm once first so that every count is non-zero
        # before the confidence bound is computed.
        if self.total_counts < self.k:
            return self.total_counts
        ucb_values = self.values + np.sqrt(
            2 * np.log(self.total_counts) / self.counts
        )
        return np.argmax(ucb_values)

    def update(self, arm_index: int, reward: int):
        """
        Update the strategy.

        Args:
            arm_index: The index of the arm that was pulled.
            reward: The reward for the arm.
        """
        self.counts[arm_index] += 1
        self.total_counts += 1
        n = self.counts[arm_index]
        self.values[arm_index] += (reward - self.values[arm_index]) / n
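

# Illustrative worked example (made-up counts, not part of the test below):
# after total_counts = 100 pulls, an arm with estimated value 0.5 that has been
# played 10 times gets the bound
#
#     0.5 + sqrt(2 * ln(100) / 10) ≈ 0.5 + 0.96 = 1.46
#
# while the same estimate after 90 plays gets only about 0.5 + 0.32, so the
# less-explored arm is preferred even though the value estimates are equal.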


# Thompson Sampling


class ThompsonSampling:
    """
    A class for the Thompson Sampling strategy.
    Follow this link to learn more:
    https://en.wikipedia.org/wiki/Thompson_sampling
    """

    def __init__(self, k: int):
        """
        Initialize the Thompson Sampling strategy.

        Args:
            k: The number of arms.
        """
        self.k = k
        self.successes = np.zeros(k)
        self.failures = np.zeros(k)

    def select_arm(self):
        """
        Select an arm to pull.

        Returns:
            The index of the arm to pull, based on the Thompson Sampling strategy,
            which relies on the Beta distribution.
        """
        rng = np.random.default_rng()

        # Sample each arm's reward probability from its Beta posterior and
        # play the arm with the highest sample.
        samples = [
            rng.beta(self.successes[i] + 1, self.failures[i] + 1) for i in range(self.k)
        ]
        return np.argmax(samples)

    def update(self, arm_index: int, reward: int):
        """
        Update the strategy.

        Args:
            arm_index: The index of the arm that was pulled.
            reward: The reward for the arm.
        """
        if reward == 1:
            self.successes[arm_index] += 1
        else:
            self.failures[arm_index] += 1
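

# Illustrative note (assumed counts, not part of the test below): each arm's
# reward probability is modelled with a Beta posterior that starts at Beta(1, 1)
# (uniform). An arm with 8 successes and 2 failures is sampled from Beta(9, 3),
# which concentrates around 0.75, while an unplayed arm is still sampled from
# Beta(1, 1) and can occasionally win the argmax, keeping exploration alive.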


# Random strategy (full exploration)
class RandomStrategy:
    """
    A class that chooses an arm completely at random in each round, to give
    a baseline comparison for the optimised strategies.
    """

    def __init__(self, k: int):
        """
        Initialize the Random strategy.

        Args:
            k: The number of arms.
        """
        self.k = k

    def select_arm(self):
        """
        Select an arm to pull.

        Returns:
            The index of the arm to pull.
        """
        rng = np.random.default_rng()
        return rng.integers(self.k)

    def update(self, arm_index: int, reward: int):
        """
        Update the strategy. The random strategy does not learn from rewards,
        so this is intentionally a no-op.

        Args:
            arm_index: The index of the arm that was pulled.
            reward: The reward for the arm.
        """


# Greedy strategy (full exploitation)


class GreedyStrategy:
    """
    A class for the Greedy strategy, used to show how full exploitation can be
    detrimental to the performance of the strategy.
    """

    def __init__(self, k: int):
        """
        Initialize the Greedy strategy.

        Args:
            k: The number of arms.
        """
        self.k = k
        self.counts = np.zeros(k)
        self.values = np.zeros(k)

    def select_arm(self):
        """
        Select an arm to pull.

        Returns:
            The index of the arm with the highest estimated value so far.
        """
        return np.argmax(self.values)

    def update(self, arm_index: int, reward: int):
        """
        Update the strategy.

        Args:
            arm_index: The index of the arm that was pulled.
            reward: The reward for the arm.
        """
        self.counts[arm_index] += 1
        n = self.counts[arm_index]
        self.values[arm_index] += (reward - self.values[arm_index]) / n
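

# Illustrative note (behaviour follows from the implementation above): all value
# estimates start at 0, so np.argmax breaks the tie in favour of arm 0; since the
# other arms are never tried, their estimates never change and the strategy locks
# onto arm 0 regardless of its true probability. This is the failure mode of full
# exploitation that this baseline is meant to show.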


def test_mab_strategies():
    """
    Test the MAB strategies.
    """
    # Simulation
    k = 4
    arms_probabilities = [0.1, 0.3, 0.5, 0.8]  # True probabilities

    bandit = Bandit(arms_probabilities)
    strategies = {
        "Epsilon-Greedy": EpsilonGreedy(epsilon=0.1, k=k),
        "UCB": UCB(k=k),
        "Thompson Sampling": ThompsonSampling(k=k),
        "Full Exploration (Random)": RandomStrategy(k=k),
        "Full Exploitation (Greedy)": GreedyStrategy(k=k),
    }

    num_rounds = 1000
    results = {}

    for name, strategy in strategies.items():
        rewards = []
        total_reward = 0
        for _ in range(num_rounds):
            arm = strategy.select_arm()
            current_reward = bandit.pull(arm)
            strategy.update(arm, current_reward)
            total_reward += current_reward
            rewards.append(total_reward)
        results[name] = rewards

    # Plotting results
    plt.figure(figsize=(12, 6))
    for name, rewards in results.items():
        plt.plot(rewards, label=name)

    plt.title("Cumulative Reward of Multi-Armed Bandit Strategies")
    plt.xlabel("Round")
    plt.ylabel("Cumulative Reward")
    plt.legend()
    plt.grid()
    plt.show()


if __name__ == "__main__":
    test_mab_strategies()
