Skip to content

added multi armed bandit problem with three strategies to solve it #12668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Prev Previous commit
Next Next commit
corrected test cases
  • Loading branch information
sephml committed Apr 13, 2025
commit 9fdf39fe773b1ebf2c77f066d412a9a3f265adfc
16 changes: 8 additions & 8 deletions machine_learning/mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def select_arm(self):
Example:
>>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
>>> 0 <= strategy.select_arm() < 3
True
np.True_
"""
rng = np.random.default_rng()

Expand All @@ -116,7 +116,7 @@ def update(self, arm_index: int, reward: int):
>>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
>>> strategy.update(0, 1)
>>> strategy.counts[0] == 1
True
np.True_
"""
self.counts[arm_index] += 1
n = self.counts[arm_index]
Expand Down Expand Up @@ -175,7 +175,7 @@ def update(self, arm_index: int, reward: int):
>>> strategy = UCB(k=3)
>>> strategy.update(0, 1)
>>> strategy.counts[0] == 1
True
np.True_
"""
self.counts[arm_index] += 1
self.total_counts += 1
Expand Down Expand Up @@ -215,7 +215,7 @@ def select_arm(self):
Example:
>>> strategy = ThompsonSampling(k=3)
>>> 0 <= strategy.select_arm() < 3
True
np.True_
"""
rng = np.random.default_rng()

Expand All @@ -236,7 +236,7 @@ def update(self, arm_index: int, reward: int):
>>> strategy = ThompsonSampling(k=3)
>>> strategy.update(0, 1)
>>> strategy.successes[0] == 1
True
np.True_
"""
if reward == 1:
self.successes[arm_index] += 1
Expand Down Expand Up @@ -270,7 +270,7 @@ def select_arm(self):
Example:
>>> strategy = RandomStrategy(k=3)
>>> 0 <= strategy.select_arm() < 3
True
np.True_
"""
rng = np.random.default_rng()
return rng.integers(self.k)
Expand Down Expand Up @@ -319,7 +319,7 @@ def select_arm(self):
Example:
>>> strategy = GreedyStrategy(k=3)
>>> 0 <= strategy.select_arm() < 3
True
np.True_
"""
return np.argmax(self.values)

Expand All @@ -335,7 +335,7 @@ def update(self, arm_index: int, reward: int):
>>> strategy = GreedyStrategy(k=3)
>>> strategy.update(0, 1)
>>> strategy.counts[0] == 1
True
np.True_
"""
self.counts[arm_index] += 1
n = self.counts[arm_index]
Expand Down