Skip to content

Commit dd4b940

Browse files
ricardoV94twiecki
authored andcommitted
Fix seeding bug in sample_numpyro_nuts when more than one chain was sampled
1 parent edf49e1 commit dd4b940

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

pymc/sampling_jax.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,8 @@ def _get_batched_jittered_initial_points(
152152
list with one item per variable and number of chains as batch dimension.
153153
Each item has shape `(chains, *var.shape)`
154154
"""
155-
if isinstance(random_seed, (int, np.integer)):
156-
random_seed = np.random.default_rng(random_seed).integers(2**30, size=chains)
157-
elif not isinstance(random_seed, (list, tuple, np.ndarray)):
158-
raise ValueError(f"The `seeds` must be int or array-like. Got {type(random_seed)} instead.")
155+
156+
random_seed = np.random.default_rng(random_seed).integers(2**30, size=chains)
159157

160158
assert len(random_seed) == chains
161159

@@ -213,9 +211,7 @@ def sample_numpyro_nuts(
213211
dims = {}
214212

215213
if random_seed is None:
216-
random_seed = model.rng_seeder.randint(
217-
2**30, dtype=np.int64, size=chains if chains > 1 else None
218-
)
214+
random_seed = model.rng_seeder.randint(2**30, dtype=np.int64)
219215

220216
tic1 = datetime.now()
221217
print("Compiling...", file=sys.stdout)

pymc/tests/test_sampling_jax.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,32 @@ def test_get_batched_jittered_initial_points():
161161

162162
assert ips[0].shape == (2, 2, 3)
163163
assert np.all(ips[0][0] != ips[0][1])
164+
165+
166+
@pytest.mark.parametrize("random_seed", (None, 123))
167+
@pytest.mark.parametrize("chains", (1, 2))
168+
def test_seeding(chains, random_seed):
169+
sample_kwargs = dict(
170+
tune=100,
171+
draws=5,
172+
chains=chains,
173+
random_seed=random_seed,
174+
)
175+
176+
with pm.Model(rng_seeder=456) as m:
177+
pm.Normal("x", mu=0, sigma=1)
178+
result1 = sample_numpyro_nuts(**sample_kwargs)
179+
180+
with pm.Model(rng_seeder=456) as m:
181+
pm.Normal("x", mu=0, sigma=1)
182+
result2 = sample_numpyro_nuts(**sample_kwargs)
183+
result3 = sample_numpyro_nuts(**sample_kwargs)
184+
185+
assert np.all(result1.posterior["x"] == result2.posterior["x"])
186+
expected_equal_result3 = random_seed is not None
187+
assert np.all(result2.posterior["x"] == result3.posterior["x"]) == expected_equal_result3
188+
189+
if chains > 1:
190+
assert np.all(result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1))
191+
assert np.all(result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1))
192+
assert np.all(result3.posterior["x"].sel(chain=0) != result3.posterior["x"].sel(chain=1))

0 commit comments

Comments
 (0)