From 5baedaf3f05a5c2442e8f8996cf60bb38ad0dc7d Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Tue, 9 Jul 2024 13:41:09 +0200 Subject: [PATCH 1/5] Reduce blackjax sampling memory usage ... by not outputing the warmup diagnositics --- pymc/sampling/jax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 390661fdc2..df6bd302b5 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -232,6 +232,7 @@ def _blackjax_inference_loop( seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs ): import blackjax + from blackjax.adaptation.base import get_filter_adapt_info_fn algorithm_name = adaptation_kwargs.pop("algorithm", "nuts") if algorithm_name == "nuts": @@ -245,6 +246,7 @@ def _blackjax_inference_loop( algorithm=algorithm, logdensity_fn=logprob_fn, target_acceptance_rate=target_accept, + adaptation_info_fn=get_filter_adapt_info_fn(), **adaptation_kwargs, ) (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune) From de3cd0e0b7ae12429e2165a7d391b9d93821ec28 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Fri, 12 Jul 2024 11:49:54 +0200 Subject: [PATCH 2/5] Update jax env --- conda-envs/environment-jax.yml | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index cd2f63de23..f7e488ceb1 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -11,23 +11,21 @@ dependencies: - cloudpickle - h5py>=2.7 # Jaxlib version must not be greater than jax version! -- blackjax==1.2.0 # Blackjax>=1.2.1 is incompatible with latest available version of jaxlib in conda-forge -- jaxlib==0.4.23 # Latest available version in conda-forge, update when new version is available -- jax==0.4.23 +- blackjax +- jaxlib +- jax - libblas=*=*mkl - mkl-service - numpy>=1.15.0 - numpyro>=0.8.0 - pandas>=0.24.0 - pip -- pytensor>=2.23,<2.24 +- pytensor - python-graphviz - networkx - rich>=13.7.1 - threadpoolctl>=3.1.0 -# JAX is only compatible with Scipy 1.13.0 from >=0.4.26, but the respective version of -# JAXlib is still not on conda: https://github.com/conda-forge/jaxlib-feedstock/pull/243 -- scipy>=1.4.1,<1.13.0 +- scipy - typing-extensions>=3.7.4 # Extra dependencies for testing - ipython>=7.16 From 7ca4d30e8010e7fad41cc0121695784be32b2608 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Fri, 12 Jul 2024 11:54:46 +0200 Subject: [PATCH 3/5] fix pre-commit --- pymc/sampling/jax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 2d20beb922..c4d9099b90 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -242,6 +242,7 @@ def _blackjax_inference_loop( seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs ): import blackjax + from blackjax.adaptation.base import get_filter_adapt_info_fn algorithm_name = adaptation_kwargs.pop("algorithm", "nuts") From 2434e37a2e93aeedb8573fb859718e98099bbdd3 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Fri, 12 Jul 2024 11:54:46 +0200 Subject: [PATCH 4/5] skip also RuntimeWarning --- tests/sampling/test_mcmc_external.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index eab6b402a6..4c594a2b64 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -51,7 +51,7 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): warns = { (warn.category, warn.message.args[0]) for warn in recwarn - if warn.category not in (FutureWarning, DeprecationWarning) + if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning) } expected = set() if nuts_sampler == "nutpie": From e81d8289ddbaa093874d8a6c8d0bdda6dfdc206e Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Sat, 13 Jul 2024 08:16:00 +0200 Subject: [PATCH 5/5] ping jax versions --- conda-envs/environment-jax.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index cb69daad40..97d25dd5b8 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -11,9 +11,9 @@ dependencies: - cloudpickle - h5py>=2.7 # Jaxlib version must not be greater than jax version! -- blackjax -- jaxlib -- jax +- blackjax>=1.2.2 +- jax>=0.4.28 +- jaxlib>=0.4.28 - libblas=*=*mkl - mkl-service - numpy>=1.15.0 @@ -25,7 +25,8 @@ dependencies: - networkx - rich>=13.7.1 - threadpoolctl>=3.1.0 -- scipy +# JAX is only compatible with Scipy 1.13.0 from >=0.4.26 +- scipy>=1.13.0 - typing-extensions>=3.7.4 # Extra dependencies for testing - ipython>=7.16