Skip to content

Commit 4836bc1

Browse files
committed
fix bug, rename, add test
1 parent da65a8e commit 4836bc1

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

pymc/smc/smc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def update_beta_and_weights(self):
274274

275275
def resample(self):
276276
"""Resample particles based on importance weights"""
277-
self.resampling_indexes = systematic(self.weights)
277+
self.resampling_indexes = systematic_resampling(self.weights, self.rng)
278278

279279
self.tempered_posterior = self.tempered_posterior[self.resampling_indexes]
280280
self.prior_logp = self.prior_logp[self.resampling_indexes]
@@ -544,7 +544,7 @@ def sample_settings(self):
544544
return stats
545545

546546

547-
def systematic(weights):
547+
def systematic_resampling(weights, rng):
548548
"""
549549
Systematic resampling.
550550
@@ -560,7 +560,7 @@ def systematic(weights):
560560
"""
561561
lnw = len(weights)
562562
arange = np.arange(lnw)
563-
uniform = (np.random.rand(1) + arange) / lnw
563+
uniform = (rng.random(1) + arange) / lnw
564564

565565
idx = 0
566566
weight_accu = weights[0]

pymc/tests/test_smc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,3 +286,11 @@ def test_proposal_dist_shape(self):
286286
kernel=pm.smc.MH,
287287
return_inferencedata=False,
288288
)
289+
290+
291+
def test_systematic():
292+
rng = np.random.default_rng(seed=34)
293+
weights = [0.33, 0.33, 0.33]
294+
np.testing.assert_array_equal(pm.smc.systematic_resampling(weights, rng), [0, 1, 2])
295+
weights = [0.99, 0.01]
296+
np.testing.assert_array_equal(pm.smc.systematic_resampling(weights, rng), [0, 0])

0 commit comments

Comments
 (0)