Skip to content

Commit da65a8e

Browse files
committed
use systematic sampling
1 parent 838c0d7 commit da65a8e

File tree

1 file changed

+31
-3
lines changed

1 file changed

+31
-3
lines changed

pymc/smc/smc.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,7 @@ def update_beta_and_weights(self):
274274

275275
def resample(self):
276276
"""Resample particles based on importance weights"""
277-
self.resampling_indexes = self.rng.choice(
278-
np.arange(self.draws), size=self.draws, p=self.weights
279-
)
277+
self.resampling_indexes = systematic(self.weights)
280278

281279
self.tempered_posterior = self.tempered_posterior[self.resampling_indexes]
282280
self.prior_logp = self.prior_logp[self.resampling_indexes]
@@ -546,6 +544,36 @@ def sample_settings(self):
546544
return stats
547545

548546

547+
def systematic(weights):
548+
"""
549+
Systematic resampling.
550+
551+
Parameters
552+
----------
553+
weights :
554+
The weights should be probabilities and the total sum should be 1.
555+
556+
Returns
557+
-------
558+
new_indices: array
559+
A vector of indices in the interval 0, ..., len(normalized_weights)
560+
"""
561+
lnw = len(weights)
562+
arange = np.arange(lnw)
563+
uniform = (np.random.rand(1) + arange) / lnw
564+
565+
idx = 0
566+
weight_accu = weights[0]
567+
new_indices = np.empty(lnw, dtype=int)
568+
for i in arange:
569+
while uniform[i] > weight_accu:
570+
idx += 1
571+
weight_accu += weights[idx]
572+
new_indices[i] = idx
573+
574+
return new_indices
575+
576+
549577
def _logp_forw(point, out_vars, in_vars, shared):
550578
"""Compile Aesara function of the model and the input and output variables.
551579

0 commit comments

Comments
 (0)