
Commit 7343268

fixed variable name k
1 parent a824511 commit 7343268


Diff for: machine_learning/mab.py (1 file changed, +47 −45)
@@ -41,7 +41,7 @@ def __init__(self, probabilities: list[float]) -> None:
             probabilities: List of probabilities for each arm.
         """
         self.probabilities = probabilities
-        self.k = len(probabilities)
+        self.num_arms = len(probabilities)

     def pull(self, arm_index: int) -> int:
         """
@@ -72,18 +72,18 @@ class EpsilonGreedy:
     https://medium.com/analytics-vidhya/the-epsilon-greedy-algorithm-for-reinforcement-learning-5fe6f96dc870
     """

-    def __init__(self, epsilon: float, k: int) -> None:
+    def __init__(self, epsilon: float, num_arms: int) -> None:
         """
         Initialize the Epsilon-Greedy strategy.

         Args:
             epsilon: The probability of exploring new arms.
-            k: The number of arms.
+            num_arms: The number of arms.
         """
         self.epsilon = epsilon
-        self.k = k
-        self.counts = np.zeros(k)
-        self.values = np.zeros(k)
+        self.num_arms = num_arms
+        self.counts = np.zeros(num_arms)
+        self.values = np.zeros(num_arms)

     def select_arm(self) -> int:
         """
@@ -93,14 +93,14 @@ def select_arm(self) -> int:
             The index of the arm to pull.

         Example:
-            >>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
+            >>> strategy = EpsilonGreedy(epsilon=0.1, num_arms=3)
             >>> 0 <= strategy.select_arm() < 3
             np.True_
         """
         rng = np.random.default_rng()

         if rng.random() < self.epsilon:
-            return rng.integers(self.k)
+            return rng.integers(self.num_arms)
         else:
             return np.argmax(self.values)
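For reference, the rule touched by this hunk explores a uniformly random arm with probability epsilon and otherwise exploits the arm with the highest estimated value. A minimal self-contained sketch of that rule (the names values, epsilon, and num_arms mirror the class attributes; this is an illustration, not the module's code):

    import numpy as np

    def epsilon_greedy_select(values: np.ndarray, epsilon: float, num_arms: int) -> int:
        # With probability epsilon, explore a uniformly random arm;
        # otherwise exploit the arm with the best running value estimate.
        rng = np.random.default_rng()
        if rng.random() < epsilon:
            return int(rng.integers(num_arms))
        return int(np.argmax(values))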

@@ -113,7 +113,7 @@ def update(self, arm_index: int, reward: int) -> None:
             reward: The reward for the arm.

         Example:
-            >>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
+            >>> strategy = EpsilonGreedy(epsilon=0.1, num_arms=3)
             >>> strategy.update(0, 1)
             >>> strategy.counts[0] == 1
             np.True_
@@ -133,16 +133,16 @@ class UCB:
     https://people.maths.bris.ac.uk/~maajg/teaching/stochopt/ucb.pdf
     """

-    def __init__(self, k: int) -> None:
+    def __init__(self, num_arms: int) -> None:
         """
         Initialize the UCB strategy.

         Args:
-            k: The number of arms.
+            num_arms: The number of arms.
         """
-        self.k = k
-        self.counts = np.zeros(k)
-        self.values = np.zeros(k)
+        self.num_arms = num_arms
+        self.counts = np.zeros(num_arms)
+        self.values = np.zeros(num_arms)
         self.total_counts = 0

     def select_arm(self) -> int:
@@ -153,13 +153,14 @@ def select_arm(self) -> int:
             The index of the arm to pull.

         Example:
-            >>> strategy = UCB(k=3)
+            >>> strategy = UCB(num_arms=3)
             >>> 0 <= strategy.select_arm() < 3
             True
         """
-        if self.total_counts < self.k:
+        if self.total_counts < self.num_arms:
             return self.total_counts
-        ucb_values = self.values + np.sqrt(2 * np.log(self.total_counts) / self.counts)
+        ucb_values = self.values + \
+            np.sqrt(2 * np.log(self.total_counts) / self.counts)
         return np.argmax(ucb_values)

     def update(self, arm_index: int, reward: int) -> None:
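The wrapped expression in this hunk is the UCB1 score: each arm's estimated value plus an exploration bonus of sqrt(2 ln(t) / n_i), where t is the total pull count and n_i is that arm's pull count. A standalone sketch of the full selection step, under the same attribute names as the class (illustrative only):

    import numpy as np

    def ucb1_select(values: np.ndarray, counts: np.ndarray,
                    total_counts: int, num_arms: int) -> int:
        # Cold start: pull each arm once so counts are nonzero
        # before the bonus term is evaluated.
        if total_counts < num_arms:
            return total_counts
        # UCB1: estimated value plus a bonus that shrinks
        # as an arm accumulates pulls.
        ucb_values = values + np.sqrt(2 * np.log(total_counts) / counts)
        return int(np.argmax(ucb_values))
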
@@ -171,7 +172,7 @@ def update(self, arm_index: int, reward: int) -> None:
             reward: The reward for the arm.

         Example:
-            >>> strategy = UCB(k=3)
+            >>> strategy = UCB(num_arms=3)
             >>> strategy.update(0, 1)
             >>> strategy.counts[0] == 1
             np.True_
@@ -192,16 +193,16 @@ class ThompsonSampling:
     https://en.wikipedia.org/wiki/Thompson_sampling
     """

-    def __init__(self, k: int) -> None:
+    def __init__(self, num_arms: int) -> None:
         """
         Initialize the Thompson Sampling strategy.

         Args:
-            k: The number of arms.
+            num_arms: The number of arms.
         """
-        self.k = k
-        self.successes = np.zeros(k)
-        self.failures = np.zeros(k)
+        self.num_arms = num_arms
+        self.successes = np.zeros(num_arms)
+        self.failures = np.zeros(num_arms)

     def select_arm(self) -> int:
         """
@@ -212,14 +213,15 @@ def select_arm(self) -> int:
         which relies on the Beta distribution.

         Example:
-            >>> strategy = ThompsonSampling(k=3)
+            >>> strategy = ThompsonSampling(num_arms=3)
             >>> 0 <= strategy.select_arm() < 3
             np.True_
         """
         rng = np.random.default_rng()

         samples = [
-            rng.beta(self.successes[i] + 1, self.failures[i] + 1) for i in range(self.k)
+            rng.beta(self.successes[i] + 1, self.failures[i] + 1)
+            for i in range(self.num_arms)
         ]
         return np.argmax(samples)
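The list comprehension reflowed in this hunk draws one sample from each arm's Beta(successes + 1, failures + 1) posterior and plays the best sample. Since numpy's Generator.beta broadcasts over arrays, the same step can be written without the loop; a sketch under the same attribute names (illustrative, not the module's code):

    import numpy as np

    def thompson_select(successes: np.ndarray, failures: np.ndarray) -> int:
        # Draw one win-rate sample per arm from its Beta posterior
        # (the +1 terms encode a uniform Beta(1, 1) prior),
        # then play the arm with the best sample.
        rng = np.random.default_rng()
        samples = rng.beta(successes + 1, failures + 1)
        return int(np.argmax(samples))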

@@ -232,7 +234,7 @@ def update(self, arm_index: int, reward: int) -> None:
             reward: The reward for the arm.

         Example:
-            >>> strategy = ThompsonSampling(k=3)
+            >>> strategy = ThompsonSampling(num_arms=3)
             >>> strategy.update(0, 1)
             >>> strategy.successes[0] == 1
             np.True_
@@ -250,14 +252,14 @@ class RandomStrategy:
     a better comparison with the other optimised strategies.
     """

-    def __init__(self, k: int):
+    def __init__(self, num_arms: int) -> None:
         """
         Initialize the Random strategy.

         Args:
-            k: The number of arms.
+            num_arms: The number of arms.
         """
-        self.k = k
+        self.num_arms = num_arms

     def select_arm(self) -> int:
         """
@@ -267,12 +269,12 @@ def select_arm(self) -> int:
             The index of the arm to pull.

         Example:
-            >>> strategy = RandomStrategy(k=3)
+            >>> strategy = RandomStrategy(num_arms=3)
             >>> 0 <= strategy.select_arm() < 3
             np.True_
         """
         rng = np.random.default_rng()
-        return rng.integers(self.k)
+        return rng.integers(self.num_arms)

     def update(self, arm_index: int, reward: int) -> None:
         """
@@ -283,7 +285,7 @@ def update(self, arm_index: int, reward: int) -> None:
             reward: The reward for the arm.

         Example:
-            >>> strategy = RandomStrategy(k=3)
+            >>> strategy = RandomStrategy(num_arms=3)
             >>> strategy.update(0, 1)
         """

@@ -297,16 +299,16 @@ class GreedyStrategy:
     detrimental to the performance of the strategy.
     """

-    def __init__(self, k: int):
+    def __init__(self, num_arms: int) -> None:
         """
         Initialize the Greedy strategy.

         Args:
-            k: The number of arms.
+            num_arms: The number of arms.
         """
-        self.k = k
-        self.counts = np.zeros(k)
-        self.values = np.zeros(k)
+        self.num_arms = num_arms
+        self.counts = np.zeros(num_arms)
+        self.values = np.zeros(num_arms)

     def select_arm(self) -> int:
         """
@@ -316,7 +318,7 @@ def select_arm(self) -> int:
             The index of the arm to pull.

         Example:
-            >>> strategy = GreedyStrategy(k=3)
+            >>> strategy = GreedyStrategy(num_arms=3)
             >>> 0 <= strategy.select_arm() < 3
             np.True_
         """
@@ -331,7 +333,7 @@ def update(self, arm_index: int, reward: int) -> None:
             reward: The reward for the arm.

         Example:
-            >>> strategy = GreedyStrategy(k=3)
+            >>> strategy = GreedyStrategy(num_arms=3)
             >>> strategy.update(0, 1)
             >>> strategy.counts[0] == 1
             np.True_
@@ -346,16 +348,16 @@ def test_mab_strategies() -> None:
     Test the MAB strategies.
     """
     # Simulation
-    k = 4
+    num_arms = 4
     arms_probabilities = [0.1, 0.3, 0.5, 0.8]  # True probabilities

     bandit = Bandit(arms_probabilities)
     strategies = {
-        "Epsilon-Greedy": EpsilonGreedy(epsilon=0.1, k=k),
-        "UCB": UCB(k=k),
-        "Thompson Sampling": ThompsonSampling(k=k),
-        "Full Exploration(Random)": RandomStrategy(k=k),
-        "Full Exploitation(Greedy)": GreedyStrategy(k=k),
+        "Epsilon-Greedy": EpsilonGreedy(epsilon=0.1, num_arms=num_arms),
+        "UCB": UCB(num_arms=num_arms),
+        "Thompson Sampling": ThompsonSampling(num_arms=num_arms),
+        "Full Exploration(Random)": RandomStrategy(num_arms=num_arms),
+        "Full Exploitation(Greedy)": GreedyStrategy(num_arms=num_arms),
     }

     num_rounds = 1000
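
After the rename, every strategy is constructed with the keyword num_arms, as the updated test shows. A minimal driver loop in the same style (this assumes Bandit and EpsilonGreedy can be imported from machine_learning.mab; the loop body is a sketch, not the test's exact code):

    from machine_learning.mab import Bandit, EpsilonGreedy

    arms_probabilities = [0.1, 0.3, 0.5, 0.8]  # true win-rate of each arm
    strategy = EpsilonGreedy(epsilon=0.1, num_arms=len(arms_probabilities))
    bandit = Bandit(arms_probabilities)
    for _ in range(1000):
        arm = strategy.select_arm()    # choose an arm
        reward = bandit.pull(arm)      # observe a 0/1 reward
        strategy.update(arm, reward)   # update the running estimates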
