
Commit 15fbf0e

Fix jax sampling (#6922)

1 parent 54269df commit 15fbf0e

File tree

4 files changed (+74, -45 lines)

.github/workflows/tests.yml

Lines changed: 1 addition & 2 deletions

@@ -395,8 +395,7 @@ jobs:
       - name: Install external samplers
         run: |
           conda activate pymc-test
-          pip install "numpyro>=0.8.0"
-          pip install git+https://github.com/blackjax-devs/[email protected]
+          pip install "numpyro>=0.8.0" "blackjax>=1.0.0"
       - name: Run tests
         run: |
           python -m pytest -vv --cov=pymc --cov-report=xml --no-cov-on-fail --cov-report term --durations=50 $TEST_SUBSET
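Note: the workflow now installs a released blackjax (>= 1.0.0) rather than a pinned git revision, matching the 1.0 API this commit targets. As a purely hypothetical sketch (not part of the commit), code that must insist on the new API could guard on the installed version:

    # Hypothetical version guard; assumes blackjax and packaging are installed.
    from importlib.metadata import version

    from packaging.version import Version

    if Version(version("blackjax")) < Version("1.0.0"):
        raise ImportError("This code path assumes the blackjax >= 1.0 API")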

pymc/sampling/jax.py

Lines changed: 68 additions & 40 deletions
@@ -17,7 +17,7 @@
 
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union
 
 import arviz as az
 import jax
@@ -26,7 +26,7 @@
 import pytensor.tensor as pt
 
 from arviz.data.base import make_attrs
-from jax.experimental.maps import SerialLoop, xmap
+from jax.lax import scan
 from pytensor.compile import SharedVariable, Supervisor, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
@@ -175,25 +175,29 @@ def _sample_stats_to_xarray(posterior):
     return data
 
 
+def _device_put(input, device: str):
+    return jax.device_put(input, jax.devices(device)[0])
+
+
 def _postprocess_samples(
-    jax_fn: List[TensorVariable],
+    jax_fn: Callable,
     raw_mcmc_samples: List[TensorVariable],
-    postprocessing_backend: str,
-    num_chunks: Optional[int] = None,
+    postprocessing_backend: Literal["cpu", "gpu"] | None = None,
+    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
 ) -> List[TensorVariable]:
-    if num_chunks is not None:
-        loop = xmap(
-            jax_fn,
-            in_axes=["chain", "samples", ...],
-            out_axes=["chain", "samples", ...],
-            axis_resources={"samples": SerialLoop(num_chunks)},
+    if postprocessing_vectorize == "scan":
+        t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
+        jax_vfn = jax.vmap(jax_fn)
+        _, outs = scan(
+            lambda _, x: ((), jax_vfn(*x)),
+            (),
+            _device_put(t_raw_mcmc_samples, postprocessing_backend),
         )
-        f = xmap(loop, in_axes=[...], out_axes=[...])
-        return f(*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0]))
+        return [jnp.swapaxes(t, 0, 1) for t in outs]
+    elif postprocessing_vectorize == "vmap":
+        return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))
     else:
-        return jax.vmap(jax.vmap(jax_fn))(
-            *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
-        )
+        raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")
 
 
 def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
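Note: this is the core of the fix. Chunked postprocessing built on the experimental xmap/SerialLoop API is replaced by a jax.lax.scan over the draw axis, with jax.vmap across chains, so only one draw's worth of work is live at a time. A minimal standalone sketch of the two strategies, simplified to a single sample array; transform stands in for the jaxified model graph, and all names here are illustrative:

    import jax
    import jax.numpy as jnp
    from jax.lax import scan

    def transform(x):
        return jnp.exp(x)  # stand-in for the jaxified PyMC graph on one draw

    chains, draws, dim = 4, 1000, 3
    samples = jnp.zeros((chains, draws, dim))

    # "vmap": vectorize over chains and draws at once; fastest, but the whole
    # batch is materialized on the device in one go.
    out_vmap = jax.vmap(jax.vmap(transform))(samples)

    # "scan": move the draw axis to the front, vmap over chains, and loop over
    # draws sequentially, bounding peak memory.
    samples_t = jnp.swapaxes(samples, 0, 1)   # (draws, chains, dim)
    vfn = jax.vmap(transform)                 # vectorize over chains
    _, out_scan = scan(lambda _, x: ((), vfn(x)), (), samples_t)
    out_scan = jnp.swapaxes(out_scan, 0, 1)   # back to (chains, draws, dim)

    assert out_vmap.shape == out_scan.shape == (chains, draws, dim)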
@@ -231,12 +235,17 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
 
 
 def _get_log_likelihood(
-    model: Model, samples, backend=None, num_chunks: Optional[int] = None
+    model: Model,
+    samples,
+    backend: Literal["cpu", "gpu"] | None = None,
+    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
 ) -> Dict:
     """Compute log-likelihood for all observations"""
     elemwise_logp = model.logp(model.observed_RVs, sum=False)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
-    result = _postprocess_samples(jax_fn, samples, backend, num_chunks=num_chunks)
+    result = _postprocess_samples(
+        jax_fn, samples, backend, postprocessing_vectorize=postprocessing_vectorize
+    )
     return {v.name: r for v, r in zip(model.observed_RVs, result)}
 
 
@@ -297,11 +306,11 @@ def _blackjax_inference_loop(
 
     adapt = blackjax.window_adaptation(
         algorithm=algorithm,
-        logprob_fn=logprob_fn,
-        num_steps=tune,
+        logdensity_fn=logprob_fn,
         target_acceptance_rate=target_accept,
     )
-    last_state, kernel, _ = adapt.run(seed, init_position)
+    (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
+    kernel = algorithm(logprob_fn, **tuned_params).step
 
     def inference_loop(rng_key, initial_state):
         def one_step(state, rng_key):
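Note: this hunk migrates the adaptation call to the blackjax 1.0 API: the log-density argument is now logdensity_fn, num_steps moves to run(), and run() returns the tuned parameters from which the step kernel is rebuilt. A hedged standalone sketch against blackjax >= 1.0; the toy target, seed, and step count are illustrative, not from the commit:

    import jax
    import jax.numpy as jnp
    import blackjax

    def logdensity_fn(x):
        # toy standard-normal target
        return -0.5 * jnp.sum(x ** 2)

    rng_key = jax.random.PRNGKey(0)
    init_position = jnp.zeros(2)

    adapt = blackjax.window_adaptation(
        algorithm=blackjax.nuts,
        logdensity_fn=logdensity_fn,
        target_acceptance_rate=0.8,
    )
    # run() now takes num_steps and returns the adapted state plus the tuned
    # parameters; the step kernel is built explicitly from those parameters.
    (last_state, tuned_params), _ = adapt.run(rng_key, init_position, num_steps=500)
    kernel = blackjax.nuts(logdensity_fn, **tuned_params).step
    new_state, info = kernel(jax.random.PRNGKey(1), last_state)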
@@ -327,9 +336,10 @@ def sample_blackjax_nuts(
     var_names: Optional[Sequence[str]] = None,
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
-    postprocessing_backend: Optional[str] = None,
-    postprocessing_chunks: Optional[int] = None,
+    postprocessing_backend: Literal["cpu", "gpu"] | None = None,
+    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
     idata_kwargs: Optional[Dict[str, Any]] = None,
+    postprocessing_chunks=None,  # deprecated
 ) -> az.InferenceData:
     """
     Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
@@ -366,12 +376,10 @@
     chain_method : str, default "parallel"
         Specify how samples should be drawn. The choices include "parallel", and
         "vectorized".
-    postprocessing_backend : str, optional
+    postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
         Specify how postprocessing should be computed. gpu or cpu
-    postprocessing_chunks: Optional[int], default None
-        Specify the number of chunks the postprocessing should be computed in. More
-        chunks reduces memory usage at the cost of losing some vectorization, None
-        uses jax.vmap
+    postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
+        How to vectorize the postprocessing: vmap or sequential scan
     idata_kwargs : dict, optional
         Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
         value for the ``log_likelihood`` key to indicate that the pointwise log
@@ -387,6 +395,14 @@
     with their respective sample stats and pointwise log likeihood values (unless
     skipped with ``idata_kwargs``).
     """
+    if postprocessing_chunks is not None:
+        import warnings
+
+        warnings.warn(
+            "postprocessing_chunks is deprecated due to being unstable, "
+            "using postprocessing_vectorize='scan' instead",
+            DeprecationWarning,
+        )
    import blackjax
 
    model = modelcontext(model)
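Note: the caller-facing effect of the signature change is that postprocessing_chunks is kept only as a deprecated trailing keyword, and postprocessing_vectorize selects the strategy. A hedged usage sketch; the toy model and draw counts are illustrative, not from the commit:

    import pymc as pm
    from pymc.sampling.jax import sample_blackjax_nuts

    with pm.Model():
        pm.Normal("x", 0.0, 1.0)
        idata = sample_blackjax_nuts(
            draws=500,
            tune=500,
            chains=2,
            postprocessing_backend="cpu",
            postprocessing_vectorize="scan",  # new default; "vmap" is faster but uses more memory
        )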
@@ -441,14 +457,17 @@
 
     states, stats = map_fn(get_posterior_samples)(keys, init_params)
     raw_mcmc_samples = states.position
-    potential_energy = states.potential_energy
+    potential_energy = states.logdensity
     tic3 = datetime.now()
     print("Sampling time = ", tic3 - tic2, file=sys.stdout)
 
     print("Transforming variables...", file=sys.stdout)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
     result = _postprocess_samples(
-        jax_fn, raw_mcmc_samples, postprocessing_backend, num_chunks=postprocessing_chunks
+        jax_fn,
+        raw_mcmc_samples,
+        postprocessing_backend=postprocessing_backend,
+        postprocessing_vectorize=postprocessing_vectorize,
     )
     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
     mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
@@ -467,7 +486,7 @@
         model,
         raw_mcmc_samples,
         backend=postprocessing_backend,
-        num_chunks=postprocessing_chunks,
+        postprocessing_vectorize=postprocessing_vectorize,
     )
     tic6 = datetime.now()
     print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
@@ -527,10 +546,11 @@ def sample_numpyro_nuts(
     progressbar: bool = True,
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
-    postprocessing_backend: Optional[str] = None,
-    postprocessing_chunks: Optional[int] = None,
+    postprocessing_backend: Literal["cpu", "gpu"] | None = None,
+    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
     idata_kwargs: Optional[Dict] = None,
     nuts_kwargs: Optional[Dict] = None,
+    postprocessing_chunks=None,
 ) -> az.InferenceData:
     """
     Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
@@ -571,12 +591,10 @@
     chain_method : str, default "parallel"
         Specify how samples should be drawn. The choices include "sequential",
         "parallel", and "vectorized".
-    postprocessing_backend : Optional[str]
+    postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
         Specify how postprocessing should be computed. gpu or cpu
-    postprocessing_chunks: Optional[int], default None
-        Specify the number of chunks the postprocessing should be computed in. More
-        chunks reduces memory usage at the cost of losing some vectorization, None
-        uses jax.vmap
+    postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
+        How to vectorize the postprocessing: vmap or sequential scan
     idata_kwargs : dict, optional
         Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
         value for the ``log_likelihood`` key to indicate that the pointwise log
@@ -594,7 +612,14 @@
     with their respective sample stats and pointwise log likeihood values (unless
     skipped with ``idata_kwargs``).
     """
+    if postprocessing_chunks is not None:
+        import warnings
 
+        warnings.warn(
+            "postprocessing_chunks is deprecated due to being unstable, "
+            "using postprocessing_vectorize='scan' instead",
+            DeprecationWarning,
+        )
     import numpyro
 
     from numpyro.infer import MCMC, NUTS
@@ -667,7 +692,10 @@
     print("Transforming variables...", file=sys.stdout)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
     result = _postprocess_samples(
-        jax_fn, raw_mcmc_samples, postprocessing_backend, num_chunks=postprocessing_chunks
+        jax_fn,
+        raw_mcmc_samples,
+        postprocessing_backend=postprocessing_backend,
+        postprocessing_vectorize=postprocessing_vectorize,
     )
     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
 
686714
model,
687715
raw_mcmc_samples,
688716
backend=postprocessing_backend,
689-
num_chunks=postprocessing_chunks,
717+
postprocessing_vectorize=postprocessing_vectorize,
690718
)
691719
tic6 = datetime.now()
692720
print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)

pymc/sampling_jax.py

Lines changed: 2 additions & 0 deletions
@@ -16,5 +16,7 @@
 
 # pylint: disable=wildcard-import
 # pylint: disable=unused-wildcard-import
+import warnings
 
+warnings.warn("This module is deprecated, use pymc.sampling.jax", DeprecationWarning)
 from pymc.sampling.jax import *
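Note: the old module path becomes a thin re-export that warns on import. A hedged illustration, assuming a fresh interpreter (module imports are cached, so the warning fires only on the first import):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        import pymc.sampling_jax  # noqa: F401  (old path, now a thin re-export)

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)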

tests/sampling/test_jax.py

Lines changed: 3 additions & 3 deletions
@@ -87,8 +87,8 @@ def test_jax_PosDefMatrix():
         ),
     ],
 )
-@pytest.mark.parametrize("postprocessing_chunks", [None, 10])
-def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_chunks):
+@pytest.mark.parametrize("postprocessing_vectorize", ["scan", "vmap"])
+def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_vectorize):
     pytensor.config.on_opt_error = "raise"
     np.random.seed(13244)
 
@@ -104,7 +104,7 @@ def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_vectorize):
         random_seed=1322,
         keep_untransformed=True,
         postprocessing_backend=postprocessing_backend,
-        postprocessing_chunks=postprocessing_chunks,
+        postprocessing_vectorize=postprocessing_vectorize,
     )
 
     log_vals = trace.posterior["sigma_log__"].values
