Skip to content

Commit 2a4694c

Browse files
ferrinericardoV94
andauthored
Replace point vmap with point map (#5664)
* replace point vmap with point map * improve test by adding a multivariate usecase, fix failing refactoring * add a kwarg for postprocessing method (by @zaxtax) * add postprocessing backend kwarg * Update pymc/sampling_jax.py Co-authored-by: Ricardo Vieira <[email protected]> Co-authored-by: Ricardo Vieira <[email protected]>
1 parent c22ea96 commit 2a4694c

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

pymc/sampling_jax.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,13 @@ def _sample_stats_to_xarray(posterior):
132132
return data
133133

134134

135-
def _get_log_likelihood(model: Model, samples) -> Dict:
135+
def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
136136
"""Compute log-likelihood for all observations"""
137137
data = {}
138138
for v in model.observed_RVs:
139139
v_elemwise_logpt = model.logpt(v, sum=False)
140140
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=v_elemwise_logpt)
141-
result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]
141+
result = jax.jit(jax.vmap(jax.vmap(jax_fn)), backend=backend)(*samples)[0]
142142
data[v.name] = result
143143
return data
144144

@@ -226,6 +226,7 @@ def sample_blackjax_nuts(
226226
var_names=None,
227227
keep_untransformed=False,
228228
chain_method="parallel",
229+
postprocessing_backend=None,
229230
idata_kwargs=None,
230231
):
231232
"""
@@ -255,6 +256,8 @@ def sample_blackjax_nuts(
255256
Include untransformed variables in the posterior samples. Defaults to False.
256257
chain_method : str, default "parallel"
257258
Specify how samples should be drawn. The choices include "parallel", and "vectorized".
259+
postprocessing_backend : str, optional
260+
Specify how postprocessing should be computed. gpu or cpu
258261
idata_kwargs : dict, optional
259262
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
260263
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
@@ -341,7 +344,9 @@ def sample_blackjax_nuts(
341344
mcmc_samples = {}
342345
for v in vars_to_sample:
343346
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
344-
result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
347+
result = jax.jit(jax.vmap(jax.vmap(jax_fn)), backend=postprocessing_backend)(
348+
*raw_mcmc_samples
349+
)[0]
345350
mcmc_samples[v.name] = result
346351

347352
tic4 = datetime.now()
@@ -353,7 +358,9 @@ def sample_blackjax_nuts(
353358
idata_kwargs = idata_kwargs.copy()
354359

355360
if idata_kwargs.pop("log_likelihood", True):
356-
log_likelihood = _get_log_likelihood(model, raw_mcmc_samples)
361+
log_likelihood = _get_log_likelihood(
362+
model, raw_mcmc_samples, backend=postprocessing_backend
363+
)
357364
else:
358365
log_likelihood = None
359366

@@ -387,6 +394,7 @@ def sample_numpyro_nuts(
387394
progress_bar: bool = True,
388395
keep_untransformed: bool = False,
389396
chain_method: str = "parallel",
397+
postprocessing_backend: str = None,
390398
idata_kwargs: Optional[Dict] = None,
391399
nuts_kwargs: Optional[Dict] = None,
392400
):
@@ -421,6 +429,8 @@ def sample_numpyro_nuts(
421429
Include untransformed variables in the posterior samples. Defaults to False.
422430
chain_method : str, default "parallel"
423431
Specify how samples should be drawn. The choices include "sequential", "parallel", and "vectorized".
432+
postprocessing_backend : Optional[str]
433+
Specify how postprocessing should be computed. gpu or cpu
424434
idata_kwargs : dict, optional
425435
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
426436
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
@@ -525,7 +535,9 @@ def sample_numpyro_nuts(
525535
mcmc_samples = {}
526536
for v in vars_to_sample:
527537
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
528-
result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
538+
result = jax.jit(jax.vmap(jax.vmap(jax_fn)), backend=postprocessing_backend)(
539+
*raw_mcmc_samples
540+
)[0]
529541
mcmc_samples[v.name] = result
530542

531543
tic4 = datetime.now()
@@ -537,7 +549,9 @@ def sample_numpyro_nuts(
537549
idata_kwargs = idata_kwargs.copy()
538550

539551
if idata_kwargs.pop("log_likelihood", True):
540-
log_likelihood = _get_log_likelihood(model, raw_mcmc_samples)
552+
log_likelihood = _get_log_likelihood(
553+
model, raw_mcmc_samples, backend=postprocessing_backend
554+
)
541555
else:
542556
log_likelihood = None
543557

pymc/tests/test_sampling_jax.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
sample_numpyro_nuts,
2727
],
2828
)
29-
def test_transform_samples(sampler):
29+
@pytest.mark.parametrize("postprocessing_backend", [None, "cpu"])
30+
def test_transform_samples(sampler, postprocessing_backend):
3031
aesara.config.on_opt_error = "raise"
3132
np.random.seed(13244)
3233

@@ -37,7 +38,12 @@ def test_transform_samples(sampler):
3738
sigma = pm.HalfNormal("sigma")
3839
b = pm.Normal("b", a, sigma=sigma, observed=obs_at)
3940

40-
trace = sampler(chains=1, random_seed=1322, keep_untransformed=True)
41+
trace = sampler(
42+
chains=1,
43+
random_seed=1322,
44+
keep_untransformed=True,
45+
postprocessing_backend=postprocessing_backend,
46+
)
4147

4248
log_vals = trace.posterior["sigma_log__"].values
4349

@@ -49,7 +55,12 @@ def test_transform_samples(sampler):
4955

5056
obs_at.set_value(-obs)
5157
with model:
52-
trace = sampler(chains=2, random_seed=1322, keep_untransformed=False)
58+
trace = sampler(
59+
chains=2,
60+
random_seed=1322,
61+
keep_untransformed=False,
62+
postprocessing_backend=postprocessing_backend,
63+
)
5364

5465
assert -11 < trace.posterior["a"].mean() < -8
5566
assert 1.5 < trace.posterior["sigma"].mean() < 2.5
@@ -145,15 +156,18 @@ def test_get_jaxified_logp():
145156
dict(log_likelihood=False),
146157
],
147158
)
148-
def test_idata_kwargs(sampler, idata_kwargs):
159+
@pytest.mark.parametrize("postprocessing_backend", [None, "cpu"])
160+
def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend):
149161
with pm.Model() as m:
150162
x = pm.Normal("x")
163+
z = pm.Normal("z")
151164
y = pm.Normal("y", x, observed=0)
152165
idata = sampler(
153166
tune=50,
154167
draws=50,
155168
chains=1,
156169
idata_kwargs=idata_kwargs,
170+
postprocessing_backend=postprocessing_backend,
157171
)
158172

159173
if idata_kwargs.get("log_likelihood", True):

0 commit comments

Comments
 (0)