Skip to content

Commit c243cd8

Browse files
committed
fixed issues with mypy, ruff
1 parent c34feff commit c243cd8

File tree

1 file changed

+33
-6
lines changed

1 file changed

+33
-6
lines changed

Diff for: machine_learning/mab.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
2525
"""
2626

27+
from abc import ABC, abstractmethod
28+
2729
import matplotlib.pyplot as plt
2830
import numpy as np
2931

@@ -65,7 +67,32 @@ def pull(self, arm_index: int) -> int:
6567
# Epsilon-Greedy strategy
6668

6769

68-
class EpsilonGreedy:
70+
class Strategy(ABC):
71+
"""
72+
Base class for all strategies.
73+
"""
74+
75+
@abstractmethod
76+
def select_arm(self) -> int:
77+
"""
78+
Select an arm to pull.
79+
80+
Returns:
81+
The index of the arm to pull.
82+
"""
83+
84+
@abstractmethod
85+
def update(self, arm_index: int, reward: int) -> None:
86+
"""
87+
Update the strategy.
88+
89+
Args:
90+
arm_index: The index of the arm to pull.
91+
reward: The reward for the arm.
92+
"""
93+
94+
95+
class EpsilonGreedy(Strategy):
6996
"""
7097
A class for a simple implementation of the Epsilon-Greedy strategy.
7198
Follow this link to learn more:
@@ -126,7 +153,7 @@ def update(self, arm_index: int, reward: int) -> None:
126153
# Upper Confidence Bound (UCB)
127154

128155

129-
class UCB:
156+
class UCB(Strategy):
130157
"""
131158
A class for the Upper Confidence Bound (UCB) strategy.
132159
Follow this link to learn more:
@@ -185,7 +212,7 @@ def update(self, arm_index: int, reward: int) -> None:
185212
# Thompson Sampling
186213

187214

188-
class ThompsonSampling:
215+
class ThompsonSampling(Strategy):
189216
"""
190217
A class for the Thompson Sampling strategy.
191218
Follow this link to learn more:
@@ -245,7 +272,7 @@ def update(self, arm_index: int, reward: int) -> None:
245272

246273

247274
# Random strategy (full exploration)
248-
class RandomStrategy:
275+
class RandomStrategy(Strategy):
249276
"""
250277
A class for choosing totally random at each round to give
251278
a better comparison with the other optimised strategies.
@@ -292,7 +319,7 @@ def update(self, arm_index: int, reward: int) -> None:
292319
# Greedy strategy (full exploitation)
293320

294321

295-
class GreedyStrategy:
322+
class GreedyStrategy(Strategy):
296323
"""
297324
A class for the Greedy strategy to show how full exploitation can be
298325
detrimental to the performance of the strategy.
@@ -351,7 +378,7 @@ def test_mab_strategies() -> None:
351378
arms_probabilities = [0.1, 0.3, 0.5, 0.8] # True probabilities
352379

353380
bandit = Bandit(arms_probabilities)
354-
strategies = {
381+
strategies: dict[str, Strategy] = {
355382
"Epsilon-Greedy": EpsilonGreedy(epsilon=0.1, num_arms=num_arms),
356383
"UCB": UCB(num_arms=num_arms),
357384
"Thompson Sampling": ThompsonSampling(num_arms=num_arms),

0 commit comments

Comments
 (0)