Skip to content

Commit 1df15ab

Browse files
committed
Jitter initial points in sample_numpyro_nuts
1 parent e987950 commit 1df15ab

File tree

2 files changed

+75
-13
lines changed

2 files changed

+75
-13
lines changed

pymc/sampling_jax.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from typing import Callable, List, Optional
88

9+
from pymc.sampling import _init_jitter
10+
911
xla_flags = os.getenv("XLA_FLAGS", "")
1012
xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
1113
os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
@@ -139,12 +141,46 @@ def _get_log_likelihood(model, samples):
139141
return data
140142

141143

144+
def _get_batched_jittered_initial_points(
145+
model, chains, initvals, random_seed, jitter=True, jitter_max_retries=10
146+
):
147+
"""Get jittered initial point in format expected by NumPyro MCMC kernel
148+
149+
Returns
150+
-------
151+
out: list of ndarrays
152+
list with one item per variable and number of chains as batch dimension.
153+
Each item has shape `(chains, *var.shape)`
154+
"""
155+
if isinstance(random_seed, (int, np.integer)):
156+
random_seed = np.random.default_rng(random_seed).integers(2**30, size=chains)
157+
elif not isinstance(random_seed, (list, tuple, np.ndarray)):
158+
raise ValueError(f"The `seeds` must be int or array-like. Got {type(random_seed)} instead.")
159+
160+
assert len(random_seed) == chains
161+
162+
initial_points = _init_jitter(
163+
model,
164+
initvals,
165+
seeds=random_seed,
166+
jitter=jitter,
167+
jitter_max_retries=jitter_max_retries,
168+
)
169+
initial_points = [list(initial_point.values()) for initial_point in initial_points]
170+
if chains == 1:
171+
initial_points = initial_points[0]
172+
else:
173+
initial_points = [np.stack(init_state) for init_state in zip(*initial_points)]
174+
return initial_points
175+
176+
142177
def sample_numpyro_nuts(
143178
draws=1000,
144179
tune=1000,
145180
chains=4,
146181
target_accept=0.8,
147-
random_seed=10,
182+
random_seed=None,
183+
initvals=None,
148184
model=None,
149185
var_names=None,
150186
progress_bar=True,
@@ -176,13 +212,20 @@ def sample_numpyro_nuts(
176212
else:
177213
dims = {}
178214

215+
if random_seed is None:
216+
random_seed = model.rng_seeder.randint(
217+
2**30, dtype=np.int64, size=chains if chains > 1 else None
218+
)
219+
179220
tic1 = datetime.now()
180221
print("Compiling...", file=sys.stdout)
181222

182-
rv_names = [rv.name for rv in model.value_vars]
183-
initial_point = model.compute_initial_point()
184-
init_state = [initial_point[rv_name] for rv_name in rv_names]
185-
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
223+
init_params = _get_batched_jittered_initial_points(
224+
model=model,
225+
chains=chains,
226+
initvals=initvals,
227+
random_seed=random_seed,
228+
)
186229

187230
logp_fn = get_jaxified_logp(model)
188231

@@ -212,14 +255,9 @@ def sample_numpyro_nuts(
212255

213256
print("Sampling...", file=sys.stdout)
214257

215-
seed = jax.random.PRNGKey(random_seed)
216-
map_seed = jax.random.split(seed, chains)
217-
218-
if chains == 1:
219-
init_params = init_state
220-
map_seed = seed
221-
else:
222-
init_params = init_state_batched
258+
map_seed = jax.random.PRNGKey(random_seed)
259+
if chains > 1:
260+
map_seed = jax.random.split(map_seed, chains)
223261

224262
pmap_numpyro.run(
225263
map_seed,

pymc/tests/test_sampling_jax.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pymc as pm
1010

1111
from pymc.sampling_jax import (
12+
_get_batched_jittered_initial_points,
1213
_get_log_likelihood,
1314
_replace_shared_variables,
1415
get_jaxified_graph,
@@ -137,3 +138,26 @@ def test_idata_kwargs(idata_kwargs):
137138
assert "log_likelihood" in idata
138139
else:
139140
assert "log_likelihood" not in idata
141+
142+
143+
def test_get_batched_jittered_initial_points():
144+
with pm.Model() as model:
145+
x = pm.MvNormal("x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3)))
146+
147+
# No jitter
148+
ips = _get_batched_jittered_initial_points(
149+
model=model, chains=1, random_seed=1, initvals=None, jitter=False
150+
)
151+
assert np.all(ips[0] == 0)
152+
153+
# Single chain
154+
ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)
155+
156+
assert ips[0].shape == (2, 3)
157+
assert np.all(ips[0] != 0)
158+
159+
# Multiple chains
160+
ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)
161+
162+
assert ips[0].shape == (2, 2, 3)
163+
assert np.all(ips[0][0] != ips[0][1])

0 commit comments

Comments
 (0)