@@ -48,10 +48,15 @@ def pull(self, arm_index: int) -> int:
48
48
Pull an arm of the bandit.
49
49
50
50
Args:
51
- arm : The arm to pull.
51
+ arm_index : The arm to pull.
52
52
53
53
Returns:
54
54
The reward for the arm.
55
+
56
+ Example:
57
+ >>> bandit = Bandit([0.1, 0.5, 0.9])
58
+ >>> isinstance(bandit.pull(0), int)
59
+ True
55
60
"""
56
61
rng = np .random .default_rng ()
57
62
return 1 if rng .random () < self .probabilities [arm_index ] else 0
@@ -86,6 +91,11 @@ def select_arm(self):
86
91
87
92
Returns:
88
93
The index of the arm to pull.
94
+
95
+ Example:
96
+ >>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
97
+ >>> 0 <= strategy.select_arm() < 3
98
+ True
89
99
"""
90
100
rng = np .random .default_rng ()
91
101
@@ -101,6 +111,12 @@ def update(self, arm_index: int, reward: int):
101
111
Args:
102
112
arm_index: The index of the arm to pull.
103
113
reward: The reward for the arm.
114
+
115
+ Example:
116
+ >>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
117
+ >>> strategy.update(0, 1)
118
+ >>> strategy.counts[0] == 1
119
+ True
104
120
"""
105
121
self .counts [arm_index ] += 1
106
122
n = self .counts [arm_index ]
@@ -135,6 +151,11 @@ def select_arm(self):
135
151
136
152
Returns:
137
153
The index of the arm to pull.
154
+
155
+ Example:
156
+ >>> strategy = UCB(k=3)
157
+ >>> 0 <= strategy.select_arm() < 3
158
+ True
138
159
"""
139
160
if self .total_counts < self .k :
140
161
return self .total_counts
@@ -149,6 +170,12 @@ def update(self, arm_index: int, reward: int):
149
170
Args:
150
171
arm_index: The index of the arm to pull.
151
172
reward: The reward for the arm.
173
+
174
+ Example:
175
+ >>> strategy = UCB(k=3)
176
+ >>> strategy.update(0, 1)
177
+ >>> strategy.counts[0] == 1
178
+ True
152
179
"""
153
180
self .counts [arm_index ] += 1
154
181
self .total_counts += 1
@@ -184,6 +211,11 @@ def select_arm(self):
184
211
Returns:
185
212
The index of the arm to pull based on the Thompson Sampling strategy
186
213
which relies on the Beta distribution.
214
+
215
+ Example:
216
+ >>> strategy = ThompsonSampling(k=3)
217
+ >>> 0 <= strategy.select_arm() < 3
218
+ True
187
219
"""
188
220
rng = np .random .default_rng ()
189
221
@@ -199,6 +231,12 @@ def update(self, arm_index: int, reward: int):
199
231
Args:
200
232
arm_index: The index of the arm to pull.
201
233
reward: The reward for the arm.
234
+
235
+ Example:
236
+ >>> strategy = ThompsonSampling(k=3)
237
+ >>> strategy.update(0, 1)
238
+ >>> strategy.successes[0] == 1
239
+ True
202
240
"""
203
241
if reward == 1 :
204
242
self .successes [arm_index ] += 1
@@ -210,7 +248,7 @@ def update(self, arm_index: int, reward: int):
210
248
class RandomStrategy :
211
249
"""
212
250
A class for choosing totally random at each round to give
213
- a better comparison with the other optimisedstrategies .
251
+ a better comparison with the other optimised strategies .
214
252
"""
215
253
216
254
def __init__ (self , k : int ):
@@ -228,6 +266,11 @@ def select_arm(self):
228
266
229
267
Returns:
230
268
The index of the arm to pull.
269
+
270
+ Example:
271
+ >>> strategy = RandomStrategy(k=3)
272
+ >>> 0 <= strategy.select_arm() < 3
273
+ True
231
274
"""
232
275
rng = np .random .default_rng ()
233
276
return rng .integers (self .k )
@@ -239,6 +282,10 @@ def update(self, arm_index: int, reward: int):
239
282
Args:
240
283
arm_index: The index of the arm to pull.
241
284
reward: The reward for the arm.
285
+
286
+ Example:
287
+ >>> strategy = RandomStrategy(k=3)
288
+ >>> strategy.update(0, 1)
242
289
"""
243
290
244
291
@@ -268,6 +315,11 @@ def select_arm(self):
268
315
269
316
Returns:
270
317
The index of the arm to pull.
318
+
319
+ Example:
320
+ >>> strategy = GreedyStrategy(k=3)
321
+ >>> 0 <= strategy.select_arm() < 3
322
+ True
271
323
"""
272
324
return np .argmax (self .values )
273
325
@@ -278,6 +330,12 @@ def update(self, arm_index: int, reward: int):
278
330
Args:
279
331
arm_index: The index of the arm to pull.
280
332
reward: The reward for the arm.
333
+
334
+ Example:
335
+ >>> strategy = GreedyStrategy(k=3)
336
+ >>> strategy.update(0, 1)
337
+ >>> strategy.counts[0] == 1
338
+ True
281
339
"""
282
340
self .counts [arm_index ] += 1
283
341
n = self .counts [arm_index ]
@@ -329,4 +387,6 @@ def test_mab_strategies():
329
387
330
388
331
389
if __name__ == "__main__" :
390
+ import doctest
391
+ doctest .testmod ()
332
392
test_mab_strategies ()
0 commit comments