diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index 5d79c60a8c..97d25dd5b8 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -11,9 +11,9 @@ dependencies: - cloudpickle - h5py>=2.7 # Jaxlib version must not be greater than jax version! -- blackjax==1.2.0 # Blackjax>=1.2.1 is incompatible with latest available version of jaxlib in conda-forge -- jaxlib==0.4.23 # Latest available version in conda-forge, update when new version is available -- jax==0.4.23 +- blackjax>=1.2.2 +- jax>=0.4.28 +- jaxlib>=0.4.28 - libblas=*=*mkl - mkl-service - numpy>=1.15.0 @@ -25,9 +25,8 @@ dependencies: - networkx - rich>=13.7.1 - threadpoolctl>=3.1.0 -# JAX is only compatible with Scipy 1.13.0 from >=0.4.26, but the respective version of -# JAXlib is still not on conda: https://github.com/conda-forge/jaxlib-feedstock/pull/243 -- scipy>=1.4.1,<1.13.0 +# JAX is only compatible with Scipy 1.13.0 from >=0.4.26 +- scipy>=1.13.0 - typing-extensions>=3.7.4 # Extra dependencies for testing - ipython>=7.16 diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 3439cfd470..c4d9099b90 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -243,6 +243,8 @@ def _blackjax_inference_loop( ): import blackjax + from blackjax.adaptation.base import get_filter_adapt_info_fn + algorithm_name = adaptation_kwargs.pop("algorithm", "nuts") if algorithm_name == "nuts": algorithm = blackjax.nuts @@ -255,6 +257,7 @@ def _blackjax_inference_loop( algorithm=algorithm, logdensity_fn=logprob_fn, target_acceptance_rate=target_accept, + adaptation_info_fn=get_filter_adapt_info_fn(), **adaptation_kwargs, ) (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune) diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index eab6b402a6..4c594a2b64 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -51,7 +51,7 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): warns = { (warn.category, warn.message.args[0]) for warn in recwarn - if warn.category not in (FutureWarning, DeprecationWarning) + if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning) } expected = set() if nuts_sampler == "nutpie":