Skip to content

Commit e58c0e6

Browse files
ferrine and ricardoV94 authored
refactor sampling_jax postprocessing to avoid jit (#5908)
Refactor sampling_jax postprocessing to avoid JIT and multiple function definitions. Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 0ba47b5 commit e58c0e6

File tree

2 files changed

+47
-27
lines changed

2 files changed

+47
-27
lines changed

pymc/sampling_jax.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,10 @@ def _sample_stats_to_xarray(posterior):
134134

135135
def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
136136
"""Compute log-likelihood for all observations"""
137-
data = {}
138-
for v in model.observed_RVs:
139-
v_elemwise_logp = model.logp(v, sum=False)
140-
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=v_elemwise_logp)
141-
result = jax.jit(jax.vmap(jax.vmap(jax_fn)), backend=backend)(*samples)[0]
142-
data[v.name] = result
143-
return data
137+
elemwise_logp = model.logp(model.observed_RVs, sum=False)
138+
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
139+
result = jax.vmap(jax.vmap(jax_fn))(*jax.device_put(samples, jax.devices(backend)[0]))
140+
return {v.name: r for v, r in zip(model.observed_RVs, result)}
144141

145142

146143
def _get_batched_jittered_initial_points(
@@ -339,13 +336,11 @@ def sample_blackjax_nuts(
339336
print("Sampling time = ", tic3 - tic2, file=sys.stdout)
340337

341338
print("Transforming variables...", file=sys.stdout)
342-
mcmc_samples = {}
343-
for v in vars_to_sample:
344-
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
345-
result = jax.jit(jax.vmap(jax.vmap(jax_fn)), backend=postprocessing_backend)(
346-
*raw_mcmc_samples
347-
)[0]
348-
mcmc_samples[v.name] = result
339+
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
340+
result = jax.vmap(jax.vmap(jax_fn))(
341+
*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
342+
)
343+
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
349344

350345
tic4 = datetime.now()
351346
print("Transformation time = ", tic4 - tic3, file=sys.stdout)
@@ -355,10 +350,14 @@ def sample_blackjax_nuts(
355350
else:
356351
idata_kwargs = idata_kwargs.copy()
357352

358-
if idata_kwargs.pop("log_likelihood", True):
353+
if idata_kwargs.pop("log_likelihood", bool(model.observed_RVs)):
354+
tic5 = datetime.now()
355+
print("Computing Log Likelihood...", file=sys.stdout)
359356
log_likelihood = _get_log_likelihood(
360357
model, raw_mcmc_samples, backend=postprocessing_backend
361358
)
359+
tic6 = datetime.now()
360+
print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
362361
else:
363362
log_likelihood = None
364363

@@ -531,13 +530,11 @@ def sample_numpyro_nuts(
531530
print("Sampling time = ", tic3 - tic2, file=sys.stdout)
532531

533532
print("Transforming variables...", file=sys.stdout)
534-
mcmc_samples = {}
535-
for v in vars_to_sample:
536-
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
537-
result = jax.jit(jax.vmap(jax.vmap(jax_fn)), backend=postprocessing_backend)(
538-
*raw_mcmc_samples
539-
)[0]
540-
mcmc_samples[v.name] = result
533+
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
534+
result = jax.vmap(jax.vmap(jax_fn))(
535+
*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
536+
)
537+
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
541538

542539
tic4 = datetime.now()
543540
print("Transformation time = ", tic4 - tic3, file=sys.stdout)
@@ -547,10 +544,14 @@ def sample_numpyro_nuts(
547544
else:
548545
idata_kwargs = idata_kwargs.copy()
549546

550-
if idata_kwargs.pop("log_likelihood", True):
547+
if idata_kwargs.pop("log_likelihood", bool(model.observed_RVs)):
548+
tic5 = datetime.now()
549+
print("Computing Log Likelihood...", file=sys.stdout)
551550
log_likelihood = _get_log_likelihood(
552551
model, raw_mcmc_samples, backend=postprocessing_backend
553552
)
553+
tic6 = datetime.now()
554+
print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
554555
else:
555556
log_likelihood = None
556557

pymc/tests/test_sampling_jax.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import aesara
22
import aesara.tensor as at
3+
import jax
34
import numpy as np
45
import pytest
56

@@ -27,7 +28,16 @@
2728
],
2829
)
2930
@pytest.mark.parametrize("postprocessing_backend", [None, "cpu"])
30-
def test_transform_samples(sampler, postprocessing_backend):
31+
@pytest.mark.parametrize(
32+
"chains",
33+
[
34+
pytest.param(1),
35+
pytest.param(
36+
2, marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices")
37+
),
38+
],
39+
)
40+
def test_transform_samples(sampler, postprocessing_backend, chains):
3141
aesara.config.on_opt_error = "raise"
3242
np.random.seed(13244)
3343

@@ -39,7 +49,7 @@ def test_transform_samples(sampler, postprocessing_backend):
3949
b = pm.Normal("b", a, sigma=sigma, observed=obs_at)
4050

4151
trace = sampler(
42-
chains=1,
52+
chains=chains,
4353
random_seed=1322,
4454
keep_untransformed=True,
4555
postprocessing_backend=postprocessing_backend,
@@ -56,7 +66,7 @@ def test_transform_samples(sampler, postprocessing_backend):
5666
obs_at.set_value(-obs)
5767
with model:
5868
trace = sampler(
59-
chains=2,
69+
chains=chains,
6070
random_seed=1322,
6171
keep_untransformed=False,
6272
postprocessing_backend=postprocessing_backend,
@@ -73,6 +83,7 @@ def test_transform_samples(sampler, postprocessing_backend):
7383
sample_numpyro_nuts,
7484
],
7585
)
86+
@pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices")
7687
def test_deterministic_samples(sampler):
7788
aesara.config.on_opt_error = "raise"
7889
np.random.seed(13244)
@@ -207,7 +218,15 @@ def test_get_batched_jittered_initial_points():
207218
],
208219
)
209220
@pytest.mark.parametrize("random_seed", (None, 123))
210-
@pytest.mark.parametrize("chains", (1, 2))
221+
@pytest.mark.parametrize(
222+
"chains",
223+
[
224+
pytest.param(1),
225+
pytest.param(
226+
2, marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices")
227+
),
228+
],
229+
)
211230
def test_seeding(chains, random_seed, sampler):
212231
sample_kwargs = dict(
213232
tune=100,

0 commit comments

Comments
 (0)