Skip to content

Commit ddbce91

Browse files
committed
added doctest tests
1 parent c1ed3c0 commit ddbce91

File tree

1 file changed

+62
-2
lines changed

1 file changed

+62
-2
lines changed

machine_learning/mab.py

+62-2
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,15 @@ def pull(self, arm_index: int) -> int:
4848
Pull an arm of the bandit.
4949
5050
Args:
51-
arm: The arm to pull.
51+
arm_index: The arm to pull.
5252
5353
Returns:
5454
The reward for the arm.
55+
56+
Example:
57+
>>> bandit = Bandit([0.1, 0.5, 0.9])
58+
>>> isinstance(bandit.pull(0), int)
59+
True
5560
"""
5661
rng = np.random.default_rng()
5762
return 1 if rng.random() < self.probabilities[arm_index] else 0
@@ -86,6 +91,11 @@ def select_arm(self):
8691
8792
Returns:
8893
The index of the arm to pull.
94+
95+
Example:
96+
>>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
97+
>>> 0 <= strategy.select_arm() < 3
98+
True
8999
"""
90100
rng = np.random.default_rng()
91101

@@ -101,6 +111,12 @@ def update(self, arm_index: int, reward: int):
101111
Args:
102112
arm_index: The index of the arm that was pulled.
103113
reward: The reward for the arm.
114+
115+
Example:
116+
>>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
117+
>>> strategy.update(0, 1)
118+
>>> strategy.counts[0] == 1
119+
True
104120
"""
105121
self.counts[arm_index] += 1
106122
n = self.counts[arm_index]
@@ -135,6 +151,11 @@ def select_arm(self):
135151
136152
Returns:
137153
The index of the arm to pull.
154+
155+
Example:
156+
>>> strategy = UCB(k=3)
157+
>>> 0 <= strategy.select_arm() < 3
158+
True
138159
"""
139160
if self.total_counts < self.k:
140161
return self.total_counts
@@ -149,6 +170,12 @@ def update(self, arm_index: int, reward: int):
149170
Args:
150171
arm_index: The index of the arm that was pulled.
151172
reward: The reward for the arm.
173+
174+
Example:
175+
>>> strategy = UCB(k=3)
176+
>>> strategy.update(0, 1)
177+
>>> strategy.counts[0] == 1
178+
True
152179
"""
153180
self.counts[arm_index] += 1
154181
self.total_counts += 1
@@ -184,6 +211,11 @@ def select_arm(self):
184211
Returns:
185212
The index of the arm to pull based on the Thompson Sampling strategy
186213
which relies on the Beta distribution.
214+
215+
Example:
216+
>>> strategy = ThompsonSampling(k=3)
217+
>>> 0 <= strategy.select_arm() < 3
218+
True
187219
"""
188220
rng = np.random.default_rng()
189221

@@ -199,6 +231,12 @@ def update(self, arm_index: int, reward: int):
199231
Args:
200232
arm_index: The index of the arm that was pulled.
201233
reward: The reward for the arm.
234+
235+
Example:
236+
>>> strategy = ThompsonSampling(k=3)
237+
>>> strategy.update(0, 1)
238+
>>> strategy.successes[0] == 1
239+
True
202240
"""
203241
if reward == 1:
204242
self.successes[arm_index] += 1
@@ -210,7 +248,7 @@ def update(self, arm_index: int, reward: int):
210248
class RandomStrategy:
211249
"""
212250
A class for choosing an arm totally at random at each round to give
213-
a better comparison with the other optimisedstrategies.
251+
a better comparison with the other optimised strategies.
214252
"""
215253

216254
def __init__(self, k: int):
@@ -228,6 +266,11 @@ def select_arm(self):
228266
229267
Returns:
230268
The index of the arm to pull.
269+
270+
Example:
271+
>>> strategy = RandomStrategy(k=3)
272+
>>> 0 <= strategy.select_arm() < 3
273+
True
231274
"""
232275
rng = np.random.default_rng()
233276
return rng.integers(self.k)
@@ -239,6 +282,10 @@ def update(self, arm_index: int, reward: int):
239282
Args:
240283
arm_index: The index of the arm that was pulled.
241284
reward: The reward for the arm.
285+
286+
Example:
287+
>>> strategy = RandomStrategy(k=3)
288+
>>> strategy.update(0, 1)
242289
"""
243290

244291

@@ -268,6 +315,11 @@ def select_arm(self):
268315
269316
Returns:
270317
The index of the arm to pull.
318+
319+
Example:
320+
>>> strategy = GreedyStrategy(k=3)
321+
>>> 0 <= strategy.select_arm() < 3
322+
True
271323
"""
272324
return np.argmax(self.values)
273325

@@ -278,6 +330,12 @@ def update(self, arm_index: int, reward: int):
278330
Args:
279331
arm_index: The index of the arm that was pulled.
280332
reward: The reward for the arm.
333+
334+
Example:
335+
>>> strategy = GreedyStrategy(k=3)
336+
>>> strategy.update(0, 1)
337+
>>> strategy.counts[0] == 1
338+
True
281339
"""
282340
self.counts[arm_index] += 1
283341
n = self.counts[arm_index]
@@ -329,4 +387,6 @@ def test_mab_strategies():
329387

330388

331389
if __name__ == "__main__":
390+
import doctest
391+
doctest.testmod()
332392
test_mab_strategies()

0 commit comments

Comments
 (0)