Skip to content

Commit c34feff

Browse files
committed
fix1
1 parent 4167ddb commit c34feff

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

Diff for: machine_learning/mab.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,14 @@ def select_arm(self) -> int:
9595
Example:
9696
>>> strategy = EpsilonGreedy(epsilon=0.1, num_arms=3)
9797
>>> 0 <= strategy.select_arm() < 3
98-
np.True_
98+
True
9999
"""
100100
rng = np.random.default_rng()
101101

102102
if rng.random() < self.epsilon:
103103
return rng.integers(self.num_arms)
104104
else:
105-
return np.argmax(self.values)
105+
return int(np.argmax(self.values))
106106

107107
def update(self, arm_index: int, reward: int) -> None:
108108
"""
@@ -160,7 +160,7 @@ def select_arm(self) -> int:
160160
if self.total_counts < self.num_arms:
161161
return self.total_counts
162162
ucb_values = self.values + np.sqrt(2 * np.log(self.total_counts) / self.counts)
163-
return np.argmax(ucb_values)
163+
return int(np.argmax(ucb_values))
164164

165165
def update(self, arm_index: int, reward: int) -> None:
166166
"""
@@ -214,15 +214,15 @@ def select_arm(self) -> int:
214214
Example:
215215
>>> strategy = ThompsonSampling(num_arms=3)
216216
>>> 0 <= strategy.select_arm() < 3
217-
np.True_
217+
True
218218
"""
219219
rng = np.random.default_rng()
220220

221221
samples = [
222222
rng.beta(self.successes[i] + 1, self.failures[i] + 1)
223223
for i in range(self.num_arms)
224224
]
225-
return np.argmax(samples)
225+
return int(np.argmax(samples))
226226

227227
def update(self, arm_index: int, reward: int) -> None:
228228
"""
@@ -319,9 +319,9 @@ def select_arm(self) -> int:
319319
Example:
320320
>>> strategy = GreedyStrategy(num_arms=3)
321321
>>> 0 <= strategy.select_arm() < 3
322-
np.True_
322+
True
323323
"""
324-
return np.argmax(self.values)
324+
return int(np.argmax(self.values))
325325

326326
def update(self, arm_index: int, reward: int) -> None:
327327
"""

0 commit comments

Comments
 (0)