Skip to content

added multi armed bandit problem with three strategies to solve it #12668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Prev Previous commit
Next Next commit
added doctest tests
  • Loading branch information
sephml committed Apr 11, 2025
commit ddbce9174f71c668ac8df6ec9074bac16d544f73
64 changes: 62 additions & 2 deletions machine_learning/mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,15 @@ def pull(self, arm_index: int) -> int:
Pull an arm of the bandit.

Args:
arm: The arm to pull.
arm_index: The index of the arm to pull.

Returns:
The reward for the arm.

Example:
>>> bandit = Bandit([0.1, 0.5, 0.9])
>>> isinstance(bandit.pull(0), int)
True
"""
rng = np.random.default_rng()
return 1 if rng.random() < self.probabilities[arm_index] else 0
Expand Down Expand Up @@ -86,6 +91,11 @@ def select_arm(self):

Returns:
The index of the arm to pull.

Example:
>>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
>>> 0 <= strategy.select_arm() < 3
True
"""
rng = np.random.default_rng()

Expand All @@ -101,6 +111,12 @@ def update(self, arm_index: int, reward: int):
Args:
arm_index: The index of the arm that was pulled.
reward: The reward for the arm.

Example:
>>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
>>> strategy.update(0, 1)
>>> strategy.counts[0] == 1
True
"""
self.counts[arm_index] += 1
n = self.counts[arm_index]
Expand Down Expand Up @@ -135,6 +151,11 @@ def select_arm(self):

Returns:
The index of the arm to pull.

Example:
>>> strategy = UCB(k=3)
>>> 0 <= strategy.select_arm() < 3
True
"""
if self.total_counts < self.k:
return self.total_counts
Expand All @@ -149,6 +170,12 @@ def update(self, arm_index: int, reward: int):
Args:
arm_index: The index of the arm that was pulled.
reward: The reward for the arm.

Example:
>>> strategy = UCB(k=3)
>>> strategy.update(0, 1)
>>> strategy.counts[0] == 1
True
"""
self.counts[arm_index] += 1
self.total_counts += 1
Expand Down Expand Up @@ -184,6 +211,11 @@ def select_arm(self):
Returns:
The index of the arm to pull based on the Thompson Sampling strategy
which relies on the Beta distribution.

Example:
>>> strategy = ThompsonSampling(k=3)
>>> 0 <= strategy.select_arm() < 3
True
"""
rng = np.random.default_rng()

Expand All @@ -199,6 +231,12 @@ def update(self, arm_index: int, reward: int):
Args:
arm_index: The index of the arm that was pulled.
reward: The reward for the arm.

Example:
>>> strategy = ThompsonSampling(k=3)
>>> strategy.update(0, 1)
>>> strategy.successes[0] == 1
True
"""
if reward == 1:
self.successes[arm_index] += 1
Expand All @@ -210,7 +248,7 @@ def update(self, arm_index: int, reward: int):
class RandomStrategy:
"""
A class for choosing an arm totally at random at each round, to give
a better comparison with the other optimisedstrategies.
a better comparison with the other optimised strategies.
"""

def __init__(self, k: int):
Expand All @@ -228,6 +266,11 @@ def select_arm(self):

Returns:
The index of the arm to pull.

Example:
>>> strategy = RandomStrategy(k=3)
>>> 0 <= strategy.select_arm() < 3
True
"""
rng = np.random.default_rng()
return rng.integers(self.k)
Expand All @@ -239,6 +282,10 @@ def update(self, arm_index: int, reward: int):
Args:
arm_index: The index of the arm that was pulled.
reward: The reward for the arm.

Example:
>>> strategy = RandomStrategy(k=3)
>>> strategy.update(0, 1)
"""


Expand Down Expand Up @@ -268,6 +315,11 @@ def select_arm(self):

Returns:
The index of the arm to pull.

Example:
>>> strategy = GreedyStrategy(k=3)
>>> 0 <= strategy.select_arm() < 3
True
"""
return np.argmax(self.values)

Expand All @@ -278,6 +330,12 @@ def update(self, arm_index: int, reward: int):
Args:
arm_index: The index of the arm that was pulled.
reward: The reward for the arm.

Example:
>>> strategy = GreedyStrategy(k=3)
>>> strategy.update(0, 1)
>>> strategy.counts[0] == 1
True
"""
self.counts[arm_index] += 1
n = self.counts[arm_index]
Expand Down Expand Up @@ -329,4 +387,6 @@ def test_mab_strategies():


if __name__ == "__main__":
    # Script entry point: first run the docstring examples embedded in this
    # module, then run the strategy comparison.
    import doctest

    # testmod() executes every ">>>" example found in this module's docstrings.
    doctest.testmod()
    test_mab_strategies()
Loading