diff --git a/.github/workflows/jaxtests.yml b/.github/workflows/jaxtests.yml
index cf1f5b64df..0a470acd9a 100644
--- a/.github/workflows/jaxtests.yml
+++ b/.github/workflows/jaxtests.yml
@@ -71,6 +71,7 @@ jobs:
         run: |
           conda activate pymc-test-py39
           pip install "numpyro>=0.8.0"
+          pip install git+https://github.com/blackjax-devs/blackjax.git@main
       - name: Run tests
         run: |
           python -m pytest -vv --cov=pymc --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET
diff --git a/conda-envs/environment-dev-py37.yml b/conda-envs/environment-dev-py37.yml
index 1d1e936ab9..db99b252b4 100644
--- a/conda-envs/environment-dev-py37.yml
+++ b/conda-envs/environment-dev-py37.yml
@@ -13,6 +13,7 @@ dependencies:
 - fastprogress>=0.2.0
 - h5py>=2.7
 - ipython>=7.16
+- jax
 - myst-nb
 - numpy>=1.15.0
 - numpydoc<1.2
diff --git a/conda-envs/environment-dev-py38.yml b/conda-envs/environment-dev-py38.yml
index d4426963e1..c5b900b93c 100644
--- a/conda-envs/environment-dev-py38.yml
+++ b/conda-envs/environment-dev-py38.yml
@@ -13,6 +13,7 @@ dependencies:
 - fastprogress>=0.2.0
 - h5py>=2.7
 - ipython>=7.16
+- jax
 - myst-nb
 - numpy>=1.15.0
 - numpydoc<1.2
diff --git a/conda-envs/environment-dev-py39.yml b/conda-envs/environment-dev-py39.yml
index 6e50880151..6d2c86ef47 100644
--- a/conda-envs/environment-dev-py39.yml
+++ b/conda-envs/environment-dev-py39.yml
@@ -13,6 +13,7 @@ dependencies:
 - fastprogress>=0.2.0
 - h5py>=2.7
 - ipython>=7.16
+- jax
 - myst-nb
 - numpy>=1.15.0
 - numpydoc<1.2
diff --git a/conda-envs/environment-test-py37.yml b/conda-envs/environment-test-py37.yml
index ab8bdbe386..5104750b52 100644
--- a/conda-envs/environment-test-py37.yml
+++ b/conda-envs/environment-test-py37.yml
@@ -13,6 +13,7 @@ dependencies:
 - fastprogress>=0.2.0
 - h5py>=2.7
 - ipython>=7.16
+- jax
 - libblas=*=*mkl
 - mkl-service
 - numpy>=1.15.0
diff --git a/conda-envs/environment-test-py38.yml b/conda-envs/environment-test-py38.yml
index 9c51bd0d67..9ecf3bce2c 100644
--- a/conda-envs/environment-test-py38.yml
+++ b/conda-envs/environment-test-py38.yml
@@ -13,6 +13,7 @@ dependencies:
 - fastprogress>=0.2.0
 - h5py>=2.7
 - ipython>=7.16
+- jax
 - libblas=*=*mkl
 - mkl-service
 - numpy>=1.15.0
diff --git a/conda-envs/environment-test-py39.yml b/conda-envs/environment-test-py39.yml
index cc869f9fe6..a1eec94818 100644
--- a/conda-envs/environment-test-py39.yml
+++ b/conda-envs/environment-test-py39.yml
@@ -13,6 +13,7 @@ dependencies:
 - fastprogress>=0.2.0
 - h5py>=2.7
 - ipython>=7.16
+- jax
 - libblas=*=*mkl
 - mkl-service
 - numpy>=1.15.0
diff --git a/docs/source/api/samplers.rst b/docs/source/api/samplers.rst
index 2e4a4cad13..614ac47d6e 100644
--- a/docs/source/api/samplers.rst
+++ b/docs/source/api/samplers.rst
@@ -13,6 +13,8 @@ This submodule contains functions for MCMC and forward sampling.
    sample_prior_predictive
    sample_posterior_predictive
    sample_posterior_predictive_w
+   sampling_jax.sample_blackjax_nuts
+   sampling_jax.sample_numpyro_nuts
    iter_sample
    init_nuts
    draw
diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py
index e8537b5ef8..967d000858 100644
--- a/pymc/sampling_jax.py
+++ b/pymc/sampling_jax.py
@@ -3,6 +3,7 @@
 import sys
 import warnings
 
+from functools import partial
 from typing import Callable, Dict, List, Optional, Sequence, Union
 
 from pymc.initial_point import StartDict
@@ -26,6 +27,7 @@
 from aesara.link.jax.dispatch import jax_funcify
 from aesara.raise_op import Assert
 from aesara.tensor import TensorVariable
+from arviz.data.base import make_attrs
 
 from pymc import Model, modelcontext
 from pymc.backends.arviz import find_observations
@@ -97,14 +99,14 @@ def get_jaxified_graph(
     return jax_funcify(fgraph)
 
 
-def get_jaxified_logp(model: Model) -> Callable:
-
-    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model.logpt()])
+def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
+    model_logpt = model.logpt()
+    if not negative_logp:
+        model_logpt = -model_logpt
+    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])
 
     def logp_fn_wrap(x):
-        # NumPyro expects a scalar potential with the opposite sign of model.logpt
-        res = logp_fn(*x)[0]
-        return -res
+        return logp_fn(*x)[0]
 
     return logp_fn_wrap
 
@@ -177,6 +179,202 @@ def _get_batched_jittered_initial_points(
     return initial_points
 
 
+@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
+def _blackjax_inference_loop(
+    seed,
+    init_position,
+    logprob_fn,
+    draws,
+    tune,
+    target_accept,
+    algorithm=None,
+):
+    import blackjax
+
+    if algorithm is None:
+        algorithm = blackjax.nuts
+
+    adapt = blackjax.window_adaptation(
+        algorithm=algorithm,
+        logprob_fn=logprob_fn,
+        num_steps=tune,
+        target_acceptance_rate=target_accept,
+    )
+    last_state, kernel, _ = adapt.run(seed, init_position)
+
+    def inference_loop(rng_key, initial_state):
+        def one_step(state, rng_key):
+            state, info = kernel(rng_key, state)
+            return state, (state, info)
+
+        keys = jax.random.split(rng_key, draws)
+        _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
+
+        return states, infos
+
+    return inference_loop(seed, last_state)
+
+
+def sample_blackjax_nuts(
+    draws=1000,
+    tune=1000,
+    chains=4,
+    target_accept=0.8,
+    random_seed=10,
+    initvals=None,
+    model=None,
+    var_names=None,
+    keep_untransformed=False,
+    chain_method="parallel",
+    idata_kwargs=None,
+):
+    """
+    Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
+
+    Parameters
+    ----------
+    draws : int, default 1000
+        The number of samples to draw. Tuning samples are discarded by default.
+    tune : int, default 1000
+        Number of iterations to tune. Samplers adjust the step sizes, scalings or
+        similar during tuning. Tuning samples will be drawn in addition to the number specified in
+        the ``draws`` argument.
+    chains : int, default 4
+        The number of chains to sample.
+    target_accept : float in [0, 1].
+        The step size is tuned such that we approximate this acceptance rate. Higher values like
+        0.9 or 0.95 often work better for problematic posteriors.
+    random_seed : int, default 10
+        Random seed used by the sampling steps.
+    model : Model, optional
+        Model to sample from. The model needs to have free random variables. When inside a ``with`` model
+        context, it defaults to that model, otherwise the model must be passed explicitly.
+    var_names : iterable of str, optional
+        Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior.
+    keep_untransformed : bool, default False
+        Include untransformed variables in the posterior samples. Defaults to False.
+    chain_method : str, default "parallel"
+        Specify how samples should be drawn. The choices include "parallel" and "vectorized".
+    idata_kwargs : dict, optional
+        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
+        for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
+        not be included in the returned object.
+
+    Returns
+    -------
+    InferenceData
+        ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
+        pointwise log likelihood values (unless skipped with ``idata_kwargs``).
+    """
+    import blackjax
+
+    model = modelcontext(model)
+
+    if var_names is None:
+        var_names = model.unobserved_value_vars
+
+    vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
+
+    coords = {
+        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
+        for cname, cvals in model.coords.items()
+        if cvals is not None
+    }
+
+    if hasattr(model, "RV_dims"):
+        dims = {
+            var_name: [dim for dim in dims if dim is not None]
+            for var_name, dims in model.RV_dims.items()
+        }
+    else:
+        dims = {}
+
+    tic1 = datetime.now()
+    print("Compiling...", file=sys.stdout)
+
+    init_params = _get_batched_jittered_initial_points(
+        model=model,
+        chains=chains,
+        initvals=initvals,
+        random_seed=random_seed,
+    )
+
+    if chains == 1:
+        init_params = [np.stack(init_params)]
+        init_params = [np.stack(init_state) for init_state in zip(*init_params)]
+
+    logprob_fn = get_jaxified_logp(model)
+
+    seed = jax.random.PRNGKey(random_seed)
+    keys = jax.random.split(seed, chains)
+
+    get_posterior_samples = partial(
+        _blackjax_inference_loop,
+        logprob_fn=logprob_fn,
+        tune=tune,
+        draws=draws,
+        target_accept=target_accept,
+    )
+
+    tic2 = datetime.now()
+    print("Compilation time = ", tic2 - tic1, file=sys.stdout)
+
+    print("Sampling...", file=sys.stdout)
+
+    # Adapted from numpyro
+    if chain_method == "parallel":
+        map_fn = jax.pmap
+    elif chain_method == "vectorized":
+        map_fn = jax.vmap
+    else:
+        raise ValueError(
+            "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
+        )
+
+    states, _ = map_fn(get_posterior_samples)(keys, init_params)
+    raw_mcmc_samples = states.position
+
+    tic3 = datetime.now()
+    print("Sampling time = ", tic3 - tic2, file=sys.stdout)
+
+    print("Transforming variables...", file=sys.stdout)
+    mcmc_samples = {}
+    for v in vars_to_sample:
+        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
+        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
+        mcmc_samples[v.name] = result
+
+    tic4 = datetime.now()
+    print("Transformation time = ", tic4 - tic3, file=sys.stdout)
+
+    if idata_kwargs is None:
+        idata_kwargs = {}
+    else:
+        idata_kwargs = idata_kwargs.copy()
+
+    if idata_kwargs.pop("log_likelihood", True):
+        log_likelihood = _get_log_likelihood(model, raw_mcmc_samples)
+    else:
+        log_likelihood = None
+
+    attrs = {
+        "sampling_time": (tic3 - tic2).total_seconds(),
+    }
+
+    posterior = mcmc_samples
+    az_trace = az.from_dict(
+        posterior=posterior,
+        log_likelihood=log_likelihood,
+        observed_data=find_observations(model),
+        coords=coords,
+        dims=dims,
+        attrs=make_attrs(attrs, library=blackjax),
+        **idata_kwargs,
+    )
+
+    return az_trace
+
+
 def sample_numpyro_nuts(
     draws: int = 1000,
     tune: int = 1000,
@@ -192,6 +390,51 @@ def sample_numpyro_nuts(
     idata_kwargs: Optional[Dict] = None,
     nuts_kwargs: Optional[Dict] = None,
 ):
+    """
+    Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
+
+    Parameters
+    ----------
+    draws : int, default 1000
+        The number of samples to draw. Tuning samples are discarded by default.
+    tune : int, default 1000
+        Number of iterations to tune. Samplers adjust the step sizes, scalings or
+        similar during tuning. Tuning samples will be drawn in addition to the number specified in
+        the ``draws`` argument.
+    chains : int, default 4
+        The number of chains to sample.
+    target_accept : float in [0, 1].
+        The step size is tuned such that we approximate this acceptance rate. Higher values like
+        0.9 or 0.95 often work better for problematic posteriors.
+    random_seed : int, default 10
+        Random seed used by the sampling steps.
+    model : Model, optional
+        Model to sample from. The model needs to have free random variables. When inside a ``with`` model
+        context, it defaults to that model, otherwise the model must be passed explicitly.
+    var_names : iterable of str, optional
+        Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior.
+    progress_bar : bool, default True
+        Whether or not to display a progress bar in the command line. The bar shows the percentage
+        of completion, the sampling speed in samples per second (SPS), and the estimated remaining
+        time until completion ("expected time of arrival"; ETA).
+    keep_untransformed : bool, default False
+        Include untransformed variables in the posterior samples. Defaults to False.
+    chain_method : str, default "parallel"
+        Specify how samples should be drawn. The choices include "sequential", "parallel", and "vectorized".
+    idata_kwargs : dict, optional
+        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
+        for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
+        not be included in the returned object.
+
+    Returns
+    -------
+    InferenceData
+        ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
+        pointwise log likelihood values (unless skipped with ``idata_kwargs``).
+ """ + + import numpyro + from numpyro.infer import MCMC, NUTS model = modelcontext(model) @@ -228,7 +471,7 @@ def sample_numpyro_nuts( random_seed=random_seed, ) - logp_fn = get_jaxified_logp(model) + logp_fn = get_jaxified_logp(model, negative_logp=False) if nuts_kwargs is None: nuts_kwargs = {} @@ -298,6 +541,10 @@ def sample_numpyro_nuts( else: log_likelihood = None + attrs = { + "sampling_time": (tic3 - tic2).total_seconds(), + } + posterior = mcmc_samples az_trace = az.from_dict( posterior=posterior, @@ -306,7 +553,7 @@ def sample_numpyro_nuts( sample_stats=_sample_stats_to_xarray(pmap_numpyro), coords=coords, dims=dims, - attrs={"sampling_time": (tic3 - tic2).total_seconds()}, + attrs=make_attrs(attrs, library=numpyro), **idata_kwargs, ) diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index 2c81ac7d7a..60bd1e9783 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -14,11 +14,19 @@ _replace_shared_variables, get_jaxified_graph, get_jaxified_logp, + sample_blackjax_nuts, sample_numpyro_nuts, ) -def test_transform_samples(): +@pytest.mark.parametrize( + "sampler", + [ + sample_blackjax_nuts, + sample_numpyro_nuts, + ], +) +def test_transform_samples(sampler): aesara.config.on_opt_error = "raise" np.random.seed(13244) @@ -29,7 +37,7 @@ def test_transform_samples(): sigma = pm.HalfNormal("sigma") b = pm.Normal("b", a, sigma=sigma, observed=obs_at) - trace = sample_numpyro_nuts(chains=1, random_seed=1322, keep_untransformed=True) + trace = sampler(chains=1, random_seed=1322, keep_untransformed=True) log_vals = trace.posterior["sigma_log__"].values @@ -41,13 +49,20 @@ def test_transform_samples(): obs_at.set_value(-obs) with model: - trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=False) + trace = sampler(chains=2, random_seed=1322, keep_untransformed=False) assert -11 < trace.posterior["a"].mean() < -8 assert 1.5 < trace.posterior["sigma"].mean() < 2.5 -def test_deterministic_samples(): +@pytest.mark.parametrize( + "sampler", + [ + sample_blackjax_nuts, + sample_numpyro_nuts, + ], +) +def test_deterministic_samples(sampler): aesara.config.on_opt_error = "raise" np.random.seed(13244) @@ -58,7 +73,7 @@ def test_deterministic_samples(): b = pm.Deterministic("b", a / 2.0) c = pm.Normal("c", a, sigma=1.0, observed=obs_at) - trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=True) + trace = sampler(chains=2, random_seed=1322, keep_untransformed=True) assert 8 < trace.posterior["a"].mean() < 11 assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2) @@ -116,6 +131,13 @@ def test_get_jaxified_logp(): assert not np.isinf(jax_fn((np.array(5000.0), np.array(5000.0)))) +@pytest.mark.parametrize( + "sampler", + [ + sample_blackjax_nuts, + sample_numpyro_nuts, + ], +) @pytest.mark.parametrize( "idata_kwargs", [ @@ -123,11 +145,11 @@ def test_get_jaxified_logp(): dict(log_likelihood=False), ], ) -def test_idata_kwargs(idata_kwargs): +def test_idata_kwargs(sampler, idata_kwargs): with pm.Model() as m: x = pm.Normal("x") y = pm.Normal("y", x, observed=0) - idata = sample_numpyro_nuts( + idata = sampler( tune=50, draws=50, chains=1, diff --git a/scripts/generate_pip_deps_from_conda.py b/scripts/generate_pip_deps_from_conda.py index 0f717564e5..2d29fc3193 100755 --- a/scripts/generate_pip_deps_from_conda.py +++ b/scripts/generate_pip_deps_from_conda.py @@ -52,6 +52,7 @@ "numba", "python-graphviz", "blas", + "jax", } RENAME = {}