Skip to content

Commit 1f80028

Browse files
committed
Use random number generator seeding in SMC
1 parent 4a8de57 commit 1f80028

File tree

4 files changed

+38
-20
lines changed

4 files changed

+38
-20
lines changed

pymc/smc/sample_smc.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def sample_smc(
4242
*,
4343
start=None,
4444
model=None,
45-
random_seed=-1,
45+
random_seed=None,
4646
chains=None,
4747
cores=None,
4848
compute_convergence_checks=True,
@@ -191,15 +191,19 @@ def sample_smc(
191191
cores = min(chains, cores)
192192

193193
if random_seed == -1:
194+
raise FutureWarning(
195+
f"random_seed should be a non-negative integer or None, got: {random_seed}"
196+
"This will raise a ValueError in the Future"
197+
)
194198
random_seed = None
195-
if chains == 1 and isinstance(random_seed, int):
196-
random_seed = [random_seed]
197-
if random_seed is None or isinstance(random_seed, int):
198-
if random_seed is not None:
199-
np.random.seed(random_seed)
200-
random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
201-
if not isinstance(random_seed, Iterable):
202-
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
199+
if isinstance(random_seed, int) or random_seed is None:
200+
rng = np.random.default_rng(seed=random_seed)
201+
random_seed = list(rng.integers(2 ** 30, size=chains))
202+
elif isinstance(random_seed, Iterable):
203+
if len(random_seed) != chains:
204+
raise ValueError(f"Length of seeds ({len(seeds)}) must match number of chains {chains}")
205+
else:
206+
raise TypeError("Invalid value for `random_seed`. Must be tuple, list, int or None")
203207

204208
model = modelcontext(model)
205209

pymc/smc/smc.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,7 @@ def __init__(
126126
draws=2000,
127127
start=None,
128128
model=None,
129-
random_seed=-1,
129+
random_seed=None,
130130
threshold=0.5,
131131
):
132132
"""
@@ -140,6 +140,8 @@ def __init__(
140140
Starting point in parameter space. It should be a list of dict with length `chains`.
141141
When None (default) the starting point is sampled from the prior distribution.
142142
model: Model (optional if in ``with`` context)).
143+
random_seed: int
144+
Value used to initialize the random number generator.
143145
threshold: float
144146
Determines the change of beta from stage to stage, i.e.indirectly the number of stages,
145147
the higher the value of `threshold` the higher the number of stages. Defaults to 0.5.
@@ -151,10 +153,7 @@ def __init__(
151153
self.start = start
152154
self.threshold = threshold
153155
self.model = model
154-
self.random_seed = random_seed
155-
156-
if self.random_seed != -1:
157-
np.random.seed(self.random_seed)
156+
self.rng = np.random.default_rng(seed=random_seed)
158157

159158
self.model = modelcontext(model)
160159
self.variables = inputvars(self.model.value_vars)
@@ -262,7 +261,7 @@ def update_beta_and_weights(self):
262261

263262
def resample(self):
264263
"""Resample particles based on importance weights"""
265-
self.resampling_indexes = np.random.choice(
264+
self.resampling_indexes = self.rng.choice(
266265
np.arange(self.draws), size=self.draws, p=self.weights
267266
)
268267

@@ -382,11 +381,11 @@ def mutate(self):
382381
ac_ = np.empty((self.n_steps, self.draws))
383382

384383
cov = self.proposal_dist.cov
385-
log_R = np.log(np.random.rand(self.n_steps, self.draws))
384+
log_R = np.log(self.rng.random((self.n_steps, self.draws)))
386385
for n_step in range(self.n_steps):
387386
# The proposal is independent from the current point.
388387
# We have to take that into account to compute the Metropolis-Hastings acceptance
389-
proposal = floatX(self.proposal_dist.rvs(size=self.draws))
388+
proposal = floatX(self.proposal_dist.rvs(size=self.draws, random_state=self.rng))
390389
proposal = proposal.reshape(len(proposal), -1)
391390
# To do that we compute the logp of moving to a new point
392391
forward = self.proposal_dist.logpdf(proposal)
@@ -497,11 +496,12 @@ def mutate(self):
497496
"""Metropolis-Hastings perturbation."""
498497
ac_ = np.empty((self.n_steps, self.draws))
499498

500-
log_R = np.log(np.random.rand(self.n_steps, self.draws))
499+
log_R = np.log(self.rng.random((self.n_steps, self.draws)))
501500
for n_step in range(self.n_steps):
502501
proposal = floatX(
503502
self.tempered_posterior
504-
+ self.proposal_dist(num_draws=self.draws) * self.proposal_scales[:, None]
503+
+ self.proposal_dist(num_draws=self.draws, rng=self.rng)
504+
* self.proposal_scales[:, None]
505505
)
506506
ll = np.array([self.likelihood_logp_func(prop) for prop in proposal])
507507
pl = np.array([self.prior_logp_func(prop) for prop in proposal])

pymc/tests/helpers.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
from logging.handlers import BufferingHandler
1818

1919
import aesara
20+
import numpy as np
2021
import numpy.random as nr
2122

2223
from aesara.gradient import verify_grad as at_verify_grad
@@ -123,3 +124,11 @@ def verify_grad(op, pt, n_tests=2, rng=None, *args, **kwargs):
123124
if rng is None:
124125
rng = nr.RandomState(411342)
125126
at_verify_grad(op, pt, n_tests, rng, *args, **kwargs)
127+
128+
129+
def assert_random_state_equal(state1, state2):
130+
for field1, field2 in zip(state1, state2):
131+
if isinstance(field1, np.ndarray):
132+
np.testing.assert_array_equal(field1, field2)
133+
else:
134+
assert field1 == field2

pymc/tests/test_smc.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@
3232
from pymc.aesaraf import floatX
3333
from pymc.backends.base import MultiTrace
3434
from pymc.smc.smc import IMH
35-
from pymc.tests.helpers import SeededTest
35+
from pymc.tests.helpers import SeededTest, assert_random_state_equal
3636

3737

3838
class TestSMC(SeededTest):
@@ -77,8 +77,10 @@ def two_gaussians(x):
7777
y = pm.Normal("y", x, 1, observed=0)
7878

7979
def test_sample(self):
80+
initial_rng_state = np.random.get_state()
8081
with self.SMC_test:
8182
mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False)
83+
assert_random_state_equal(initial_rng_state, np.random.get_state())
8284

8385
x = mtrace["X"]
8486
mu1d = np.abs(x).mean(axis=0)
@@ -531,11 +533,14 @@ def test_named_model(self):
531533
class TestMHKernel(SeededTest):
532534
def test_normal_model(self):
533535
data = st.norm(10, 0.5).rvs(1000, random_state=self.get_random_state())
536+
537+
initial_rng_state = np.random.get_state()
534538
with pm.Model() as m:
535539
mu = pm.Normal("mu", 0, 3)
536540
sigma = pm.HalfNormal("sigma", 1)
537541
y = pm.Normal("y", mu, sigma, observed=data)
538542
idata = pm.sample_smc(draws=2000, kernel=pm.smc.MH)
543+
assert_random_state_equal(initial_rng_state, np.random.get_state())
539544

540545
post = idata.posterior.stack(sample=("chain", "draw"))
541546
assert np.abs(post["mu"].mean() - 10) < 0.1

0 commit comments

Comments
 (0)