added multi armed bandit problem with three strategies to solve it #12668
Open
sephml wants to merge 15 commits into TheAlgorithms:master from sephml:master
Changes from 9 commits
Commits (15):
c1ed3c0  added multi arm bandit alg with three strategies to solve it (sephml)
ddbce91  added doctest tests (sephml)
46fdb1b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
9fdf39f  corrected test cases (sephml)
81d197d  Merge branch 'master' of https://github.com/sephml/Python (sephml)
f80b843  added return type hinting (sephml)
f2d9038  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
9d7a028  return typehint for test func updated (sephml)
a824511  Merge branch 'master' of https://github.com/sephml/Python (sephml)
7343268  fixed variable name k (sephml)
d0b6719  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
ef11ca4  fixed formatting (sephml)
4167ddb  Merge branch 'master' of https://github.com/sephml/Python (sephml)
c34feff  fix1 (sephml)
c243cd8  fixed issues with mypy, ruff (sephml)
@@ -0,0 +1,392 @@
"""
Multi-Armed Bandit (MAB) is a problem in reinforcement learning where an agent must
learn to choose the best action from a set of actions to maximize its reward.

Learn more here: https://en.wikipedia.org/wiki/Multi-armed_bandit


The MAB problem can be described as follows:
- There are N arms, each with a different probability of giving a reward.
- The agent must learn to choose the best arm to pull in order to maximize its reward.

Three optimising strategies are implemented here:
- Epsilon-Greedy
- Upper Confidence Bound (UCB)
- Thompson Sampling

Two baseline strategies are also implemented to show how the optimising
strategies perform against them:
- Random strategy (full exploration)
- Greedy strategy (full exploitation)

The performance of the strategies is evaluated by the cumulative reward
over a number of rounds.
"""

import matplotlib.pyplot as plt
import numpy as np

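
# All strategy classes below share the same interface: select_arm() picks an
# arm index and update(arm_index, reward) folds the observed reward back into
# the strategy's statistics.  The simulation loop in test_mab_strategies()
# drives each strategy through repeated select -> pull -> update rounds.
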
class Bandit:
    """
    A class to represent a multi-armed bandit.
    """

    def __init__(self, probabilities: list[float]) -> None:
        """
        Initialize the bandit with a list of probabilities for each arm.

        Args:
            probabilities: List of probabilities for each arm.
        """
        self.probabilities = probabilities
        self.k = len(probabilities)

    def pull(self, arm_index: int) -> int:
        """
        Pull an arm of the bandit.

        Args:
            arm_index: The arm to pull.

        Returns:
            The reward for the arm.

        Example:
            >>> bandit = Bandit([0.1, 0.5, 0.9])
            >>> isinstance(bandit.pull(0), int)
            True
        """
        rng = np.random.default_rng()
        return 1 if rng.random() < self.probabilities[arm_index] else 0


# Epsilon-Greedy strategy


class EpsilonGreedy:
    """
    A class for a simple implementation of the Epsilon-Greedy strategy.
    Follow this link to learn more:
    https://medium.com/analytics-vidhya/the-epsilon-greedy-algorithm-for-reinforcement-learning-5fe6f96dc870
    """

    def __init__(self, epsilon: float, k: int) -> None:
        """
        Initialize the Epsilon-Greedy strategy.

        Args:
            epsilon: The probability of exploring (picking a random arm).
            k: The number of arms.
        """
        self.epsilon = epsilon
        self.k = k
        self.counts = np.zeros(k)
        self.values = np.zeros(k)

    def select_arm(self) -> int:
        """
        Select an arm to pull.

        Returns:
            The index of the arm to pull.

        Example:
            >>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
            >>> 0 <= strategy.select_arm() < 3
            np.True_
        """
        rng = np.random.default_rng()

        if rng.random() < self.epsilon:
            return rng.integers(self.k)
        else:
            return np.argmax(self.values)

    def update(self, arm_index: int, reward: int) -> None:
        """
        Update the strategy.

        Args:
            arm_index: The index of the arm that was pulled.
            reward: The reward received from that arm.

        Example:
            >>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
            >>> strategy.update(0, 1)
            >>> strategy.counts[0] == 1
            np.True_
        """
        self.counts[arm_index] += 1
        n = self.counts[arm_index]
        self.values[arm_index] += (reward - self.values[arm_index]) / n
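
# EpsilonGreedy keeps a running (incremental) mean of the rewards observed on
# each arm: value_new = value_old + (reward - value_old) / n, which equals the
# average of the n rewards seen so far.  With probability epsilon a random arm
# is explored; otherwise the arm with the highest current estimate is exploited.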


# Upper Confidence Bound (UCB)


class UCB:
    """
    A class for the Upper Confidence Bound (UCB) strategy.
    Follow this link to learn more:
    https://people.maths.bris.ac.uk/~maajg/teaching/stochopt/ucb.pdf
    """

    def __init__(self, k: int) -> None:
        """
        Initialize the UCB strategy.

        Args:
            k: The number of arms.
        """
        self.k = k
        self.counts = np.zeros(k)
        self.values = np.zeros(k)
        self.total_counts = 0

    def select_arm(self) -> int:
        """
        Select an arm to pull.

        Returns:
            The index of the arm to pull.

        Example:
            >>> strategy = UCB(k=3)
            >>> 0 <= strategy.select_arm() < 3
            True
        """
        if self.total_counts < self.k:
            return self.total_counts
        ucb_values = self.values + np.sqrt(2 * np.log(self.total_counts) / self.counts)
        return np.argmax(ucb_values)

    def update(self, arm_index: int, reward: int) -> None:
        """
        Update the strategy.

        Args:
            arm_index: The index of the arm that was pulled.
            reward: The reward received from that arm.

        Example:
            >>> strategy = UCB(k=3)
            >>> strategy.update(0, 1)
            >>> strategy.counts[0] == 1
            np.True_
        """
        self.counts[arm_index] += 1
        self.total_counts += 1
        n = self.counts[arm_index]
        self.values[arm_index] += (reward - self.values[arm_index]) / n
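
# UCB plays every arm once, then selects the arm that maximises the UCB1 index
#     value_i + sqrt(2 * ln(t) / n_i)
# where t is the total number of pulls so far and n_i is the number of pulls of
# arm i.  The square-root bonus shrinks as an arm is sampled more often, so
# rarely tried arms keep getting revisited.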


# Thompson Sampling


class ThompsonSampling:
    """
    A class for the Thompson Sampling strategy.
    Follow this link to learn more:
    https://en.wikipedia.org/wiki/Thompson_sampling
    """

    def __init__(self, k: int) -> None:
        """
        Initialize the Thompson Sampling strategy.

        Args:
            k: The number of arms.
        """
        self.k = k
        self.successes = np.zeros(k)
        self.failures = np.zeros(k)

    def select_arm(self) -> int:
        """
        Select an arm to pull.

        Returns:
            The index of the arm to pull based on the Thompson Sampling strategy,
            which relies on the Beta distribution.

        Example:
            >>> strategy = ThompsonSampling(k=3)
            >>> 0 <= strategy.select_arm() < 3
            np.True_
        """
        rng = np.random.default_rng()

        samples = [
            rng.beta(self.successes[i] + 1, self.failures[i] + 1) for i in range(self.k)
        ]
        return np.argmax(samples)

    def update(self, arm_index: int, reward: int) -> None:
        """
        Update the strategy.

        Args:
            arm_index: The index of the arm that was pulled.
            reward: The reward received from that arm.

        Example:
            >>> strategy = ThompsonSampling(k=3)
            >>> strategy.update(0, 1)
            >>> strategy.successes[0] == 1
            np.True_
        """
        if reward == 1:
            self.successes[arm_index] += 1
        else:
            self.failures[arm_index] += 1
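
# ThompsonSampling maintains a Beta(successes + 1, failures + 1) posterior over
# each arm's reward probability, draws one sample from every posterior, and
# plays the arm with the largest sample.  Arms with wide (uncertain) posteriors
# still win occasionally, which is where the exploration comes from.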


# Random strategy (full exploration)
class RandomStrategy:
    """
    A class that chooses an arm completely at random at each round, to give
    a better comparison with the other optimised strategies.
    """

    def __init__(self, k: int) -> None:
        """
        Initialize the Random strategy.

        Args:
            k: The number of arms.
        """
        self.k = k

    def select_arm(self) -> int:
        """
        Select an arm to pull.

        Returns:
            The index of the arm to pull.

        Example:
            >>> strategy = RandomStrategy(k=3)
            >>> 0 <= strategy.select_arm() < 3
            np.True_
        """
        rng = np.random.default_rng()
        return rng.integers(self.k)

    def update(self, arm_index: int, reward: int) -> None:
        """
        Update the strategy.

        Args:
            arm_index: The index of the arm that was pulled.
            reward: The reward received from that arm.

        Example:
            >>> strategy = RandomStrategy(k=3)
            >>> strategy.update(0, 1)
        """


# Greedy strategy (full exploitation)


class GreedyStrategy:
    """
    A class for the Greedy strategy to show how full exploitation can be
    detrimental to performance.
    """

    def __init__(self, k: int) -> None:
        """
        Initialize the Greedy strategy.

        Args:
            k: The number of arms.
        """
        self.k = k
        self.counts = np.zeros(k)
        self.values = np.zeros(k)

    def select_arm(self) -> int:
        """
        Select an arm to pull.

        Returns:
            The index of the arm to pull.

        Example:
            >>> strategy = GreedyStrategy(k=3)
            >>> 0 <= strategy.select_arm() < 3
            np.True_
        """
        return np.argmax(self.values)

    def update(self, arm_index: int, reward: int) -> None:
        """
        Update the strategy.

        Args:
            arm_index: The index of the arm that was pulled.
            reward: The reward received from that arm.

        Example:
            >>> strategy = GreedyStrategy(k=3)
            >>> strategy.update(0, 1)
            >>> strategy.counts[0] == 1
            np.True_
        """
        self.counts[arm_index] += 1
        n = self.counts[arm_index]
        self.values[arm_index] += (reward - self.values[arm_index]) / n
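
# GreedyStrategy always plays the argmax of its current estimates.  Because the
# estimates start at zero and np.argmax breaks ties toward the lowest index, the
# greedy baseline tends to lock onto an early arm and never explores the rest,
# which is exactly the failure mode it is meant to illustrate.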


def test_mab_strategies() -> None:
    """
    Test the MAB strategies.
    """
    # Simulation
    k = 4
    arms_probabilities = [0.1, 0.3, 0.5, 0.8]  # True probabilities

    bandit = Bandit(arms_probabilities)
    strategies = {
        "Epsilon-Greedy": EpsilonGreedy(epsilon=0.1, k=k),
        "UCB": UCB(k=k),
        "Thompson Sampling": ThompsonSampling(k=k),
        "Full Exploration(Random)": RandomStrategy(k=k),
        "Full Exploitation(Greedy)": GreedyStrategy(k=k),
    }

    num_rounds = 1000
    results = {}

    for name, strategy in strategies.items():
        rewards = []
        total_reward = 0
        for _ in range(num_rounds):
            arm = strategy.select_arm()
            current_reward = bandit.pull(arm)
            strategy.update(arm, current_reward)
            total_reward += current_reward
            rewards.append(total_reward)
        results[name] = rewards

    # Plotting results
    plt.figure(figsize=(12, 6))
    for name, rewards in results.items():
        plt.plot(rewards, label=name)

    plt.title("Cumulative Reward of Multi-Armed Bandit Strategies")
    plt.xlabel("Round")
    plt.ylabel("Cumulative Reward")
    plt.legend()
    plt.grid()
    plt.show()
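
# A quick way to read the resulting plot: the best arm has probability 0.8, so a
# strategy whose cumulative reward stays close to 0.8 * round is doing well.  A
# (hypothetical, not computed above) regret value after round t would be
#     regret_t = t * max(arms_probabilities) - cumulative_reward_t
# e.g. 1000 * 0.8 - 700 = 100 for a strategy that collected 700 reward over
# 1000 rounds.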


if __name__ == "__main__":
    import doctest

    doctest.testmod()
    test_mab_strategies()