|
24 | 24 |
|
25 | 25 | """
|
26 | 26 |
|
| 27 | +from abc import ABC, abstractmethod |
| 28 | + |
27 | 29 | import matplotlib.pyplot as plt
|
28 | 30 | import numpy as np
|
29 | 31 |
|
@@ -65,7 +67,32 @@ def pull(self, arm_index: int) -> int:
|
65 | 67 | # Epsilon-Greedy strategy
|
66 | 68 |
|
67 | 69 |
|
68 |
| -class EpsilonGreedy: |
| 70 | +class Strategy(ABC): |
| 71 | + """ |
| 72 | + Abstract base class defining the interface (select_arm, update) shared by all strategies. |
| 73 | + """ |
| 74 | + |
| 75 | + @abstractmethod |
| 76 | + def select_arm(self) -> int: |
| 77 | + """ |
| 78 | + Select an arm to pull. |
| 79 | +
|
| 80 | + Returns: |
| 81 | + The index of the arm to pull. |
| 82 | + """ |
| 83 | + |
| 84 | + @abstractmethod |
| 85 | + def update(self, arm_index: int, reward: int) -> None: |
| 86 | + """ |
| 87 | + Update the strategy with the reward observed for the pulled arm. |
| 88 | +
|
| 89 | + Args: |
| 90 | + arm_index: The index of the arm that was pulled. |
| 91 | + reward: The reward received from pulling that arm. |
| 92 | + """ |
| 93 | + |
| 94 | + |
| 95 | +class EpsilonGreedy(Strategy): |
69 | 96 | """
|
70 | 97 | A class for a simple implementation of the Epsilon-Greedy strategy.
|
71 | 98 | Follow this link to learn more:
|
@@ -126,7 +153,7 @@ def update(self, arm_index: int, reward: int) -> None:
|
126 | 153 | # Upper Confidence Bound (UCB)
|
127 | 154 |
|
128 | 155 |
|
129 |
| -class UCB: |
| 156 | +class UCB(Strategy): |
130 | 157 | """
|
131 | 158 | A class for the Upper Confidence Bound (UCB) strategy.
|
132 | 159 | Follow this link to learn more:
|
@@ -185,7 +212,7 @@ def update(self, arm_index: int, reward: int) -> None:
|
185 | 212 | # Thompson Sampling
|
186 | 213 |
|
187 | 214 |
|
188 |
| -class ThompsonSampling: |
| 215 | +class ThompsonSampling(Strategy): |
189 | 216 | """
|
190 | 217 | A class for the Thompson Sampling strategy.
|
191 | 218 | Follow this link to learn more:
|
@@ -245,7 +272,7 @@ def update(self, arm_index: int, reward: int) -> None:
|
245 | 272 |
|
246 | 273 |
|
247 | 274 | # Random strategy (full exploration)
|
248 |
| -class RandomStrategy: |
| 275 | +class RandomStrategy(Strategy): |
249 | 276 | """
|
250 | 277 | A class for choosing totally random at each round to give
|
251 | 278 | a better comparison with the other optimised strategies.
|
@@ -292,7 +319,7 @@ def update(self, arm_index: int, reward: int) -> None:
|
292 | 319 | # Greedy strategy (full exploitation)
|
293 | 320 |
|
294 | 321 |
|
295 |
| -class GreedyStrategy: |
| 322 | +class GreedyStrategy(Strategy): |
296 | 323 | """
|
297 | 324 | A class for the Greedy strategy to show how full exploitation can be
|
298 | 325 | detrimental to the performance of the strategy.
|
@@ -351,7 +378,7 @@ def test_mab_strategies() -> None:
|
351 | 378 | arms_probabilities = [0.1, 0.3, 0.5, 0.8] # True probabilities
|
352 | 379 |
|
353 | 380 | bandit = Bandit(arms_probabilities)
|
354 |
| - strategies = { |
| 381 | + strategies: dict[str, Strategy] = { |
355 | 382 | "Epsilon-Greedy": EpsilonGreedy(epsilon=0.1, num_arms=num_arms),
|
356 | 383 | "UCB": UCB(num_arms=num_arms),
|
357 | 384 | "Thompson Sampling": ThompsonSampling(num_arms=num_arms),
|
|
0 commit comments