From 5a36685cd4fc7733fa493f7f9522a0ba108236dc Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 02:44:41 +0000 Subject: [PATCH 01/16] Adding NUTS sampler from blackjax to sampling_jax --- pymc/sampling_jax.py | 155 +++++++++++++++++++++++++++++++- pymc/tests/test_sampling_jax.py | 33 +++++-- 2 files changed, 178 insertions(+), 10 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index d26c1553dd..23e4dad14b 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -4,6 +4,7 @@ import sys import warnings +from functools import partial from typing import Callable, List, Optional xla_flags = os.getenv("XLA_FLAGS", "") @@ -94,14 +95,17 @@ def get_jaxified_graph( return jax_funcify(fgraph) -def get_jaxified_logp(model: Model) -> Callable: +def get_jaxified_logp(model: Model, sampler=None) -> Callable: logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model.logpt()]) def logp_fn_wrap(x): # NumPyro expects a scalar potential with the opposite sign of model.logpt res = logp_fn(*x)[0] - return -res + if sampler == "numpyro": + return -res + else: + return res return logp_fn_wrap @@ -138,6 +142,151 @@ def _get_log_likelihood(model, samples): return data +@partial(jax.jit, static_argnums=(2,3,4,5,6)) +def _blackjax_inference_loop( + seed, + init_position, + logprob_fn, + draws, + tune, + target_accept, + algorithm=None, +): + import blackjax + + if algorithm is None: + algorithm = blackjax.nuts + + adapt = blackjax.window_adaptation( + algorithm=algorithm, + logprob_fn=logprob_fn, + num_steps=tune, + target_acceptance_rate=target_accept, + ) + last_state, kernel, _ = adapt.run(seed, init_position) + + def inference_loop(rng_key, initial_state): + def one_step(state, rng_key): + state, info = kernel(rng_key, state) + return state, (state, info) + + keys = jax.random.split(rng_key, draws) + _, (states, infos) = jax.lax.scan(one_step, initial_state, keys) + + return states, infos + + return inference_loop(seed, last_state) + + +def sample_blackjax_nuts( + draws=1000, + tune=1000, + chains=4, + target_accept=0.8, + random_seed=10, + model=None, + var_names=None, + progress_bar=True, # FIXME: Unused for now + keep_untransformed=False, + chain_method="parallel", + idata_kwargs=None, +): + model = modelcontext(model) + + if var_names is None: + var_names = model.unobserved_value_vars + + vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) + + coords = { + cname: np.array(cvals) if isinstance(cvals, tuple) else cvals + for cname, cvals in model.coords.items() + if cvals is not None + } + + if hasattr(model, "RV_dims"): + dims = { + var_name: [dim for dim in dims if dim is not None] + for var_name, dims in model.RV_dims.items() + } + else: + dims = {} + + tic1 = pd.Timestamp.now() + print("Compiling...", file=sys.stdout) + + rv_names = [rv.name for rv in model.value_vars] + initial_point = model.compute_initial_point() + init_state = [initial_point[rv_name] for rv_name in rv_names] + init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state) + + logprob_fn = get_jaxified_logp(model, sampler="blackjax") + + seed = jax.random.PRNGKey(random_seed) + keys = jax.random.split(seed, chains) + + get_posterior_samples = partial( + _blackjax_inference_loop, + logprob_fn=logprob_fn, + tune=tune, + draws=draws, + target_accept=target_accept, + ) + + tic2 = pd.Timestamp.now() + print("Compilation time = ", tic2 - tic1, file=sys.stdout) + + print("Sampling...", file=sys.stdout) + + # 
Adapted from numpyro + if chain_method == 'parallel': + map_fn = jax.pmap + elif chain_method == 'vectorized': + map_fn = jax.vmap + else: + raise ValueError('Only supporting the following methods to draw chains:' + ' "parallel" or "vectorized"') + + states, _ = map_fn(get_posterior_samples)(keys, init_state_batched) + raw_mcmc_samples = states.position + + tic3 = pd.Timestamp.now() + print("Sampling time = ", tic3 - tic2, file=sys.stdout) + + print("Transforming variables...", file=sys.stdout) + mcmc_samples = {} + for v in vars_to_sample: + jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v]) + result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0] + mcmc_samples[v.name] = result + + tic4 = pd.Timestamp.now() + print("Transformation time = ", tic4 - tic3, file=sys.stdout) + + if idata_kwargs is None: + idata_kwargs = {} + else: + idata_kwargs = idata_kwargs.copy() + + if idata_kwargs.pop("log_likelihood", True): + log_likelihood = _get_log_likelihood(model, raw_mcmc_samples) + else: + log_likelihood = None + + posterior = mcmc_samples + az_trace = az.from_dict( + posterior=posterior, + log_likelihood=log_likelihood, + observed_data=find_observations(model), + coords=coords, + dims=dims, + attrs={"sampling_time": (tic3 - tic2).total_seconds()}, + **idata_kwargs, + ) + + return az_trace + + def sample_numpyro_nuts( draws=1000, tune=1000, @@ -182,7 +331,7 @@ def sample_numpyro_nuts( init_state = [initial_point[rv_name] for rv_name in rv_names] init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state) - logp_fn = get_jaxified_logp(model) + logp_fn = get_jaxified_logp(model, sampler="numpyro") nuts_kernel = NUTS( potential_fn=logp_fn, diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index 6dad177c18..a44b780b8d 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -14,10 +14,17 @@ get_jaxified_graph, get_jaxified_logp, sample_numpyro_nuts, + sample_blackjax_nuts, ) -def test_transform_samples(): +@pytest.mark.parametrize( + "sampler", + [ + sample_blackjax_nuts, + sample_numpyro_nuts, + ]) +def test_transform_samples(sampler): aesara.config.on_opt_error = "raise" np.random.seed(13244) @@ -28,7 +35,7 @@ def test_transform_samples(): sigma = pm.HalfNormal("sigma") b = pm.Normal("b", a, sigma=sigma, observed=obs_at) - trace = sample_numpyro_nuts(chains=1, random_seed=1322, keep_untransformed=True) + trace = sampler(chains=1, random_seed=1322, keep_untransformed=True) log_vals = trace.posterior["sigma_log__"].values @@ -40,13 +47,19 @@ def test_transform_samples(): obs_at.set_value(-obs) with model: - trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=False) + trace = sampler(chains=2, random_seed=1322, keep_untransformed=False) assert -11 < trace.posterior["a"].mean() < -8 assert 1.5 < trace.posterior["sigma"].mean() < 2.5 -def test_deterministic_samples(): +@pytest.mark.parametrize( + "sampler", + [ + sample_blackjax_nuts, + sample_numpyro_nuts, + ]) +def test_deterministic_samples(sampler): aesara.config.on_opt_error = "raise" np.random.seed(13244) @@ -57,7 +70,7 @@ def test_deterministic_samples(): b = pm.Deterministic("b", a / 2.0) c = pm.Normal("c", a, sigma=1.0, observed=obs_at) - trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=True) + trace = sampler(chains=2, random_seed=1322, keep_untransformed=True) assert 8 < trace.posterior["a"].mean() < 11 assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 
2) @@ -115,6 +128,12 @@ def test_get_jaxified_logp(): assert not np.isinf(jax_fn((np.array(5000.0), np.array(5000.0)))) +@pytest.mark.parametrize( + "sampler", + [ + sample_blackjax_nuts, + sample_numpyro_nuts, + ]) @pytest.mark.parametrize( "idata_kwargs", [ @@ -122,11 +141,11 @@ def test_get_jaxified_logp(): dict(log_likelihood=False), ], ) -def test_idata_kwargs(idata_kwargs): +def test_idata_kwargs(sampler, idata_kwargs): with pm.Model() as m: x = pm.Normal("x") y = pm.Normal("y", x, observed=0) - idata = sample_numpyro_nuts( + idata = sampler( tune=50, draws=50, chains=1, From 41425fbb8819fdd4798ed936d1257bd74010f806 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 02:53:27 +0000 Subject: [PATCH 02/16] Lint fixes --- pymc/sampling_jax.py | 27 ++++++++++++++------------- pymc/tests/test_sampling_jax.py | 11 +++++++---- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 23e4dad14b..0b32fecfb9 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -142,15 +142,15 @@ def _get_log_likelihood(model, samples): return data -@partial(jax.jit, static_argnums=(2,3,4,5,6)) +@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6)) def _blackjax_inference_loop( - seed, - init_position, - logprob_fn, - draws, - tune, - target_accept, - algorithm=None, + seed, + init_position, + logprob_fn, + draws, + tune, + target_accept, + algorithm=None, ): import blackjax @@ -186,7 +186,7 @@ def sample_blackjax_nuts( random_seed=10, model=None, var_names=None, - progress_bar=True, # FIXME: Unused for now + progress_bar=True, # FIXME: Unused for now keep_untransformed=False, chain_method="parallel", idata_kwargs=None, @@ -239,13 +239,14 @@ def sample_blackjax_nuts( print("Sampling...", file=sys.stdout) # Adapted from numpyro - if chain_method == 'parallel': + if chain_method == "parallel": map_fn = jax.pmap - elif chain_method == 'vectorized': + elif chain_method == "vectorized": map_fn = jax.vmap else: - raise ValueError('Only supporting the following methods to draw chains:' - ' "parallel" or "vectorized"') + raise ValueError( + "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"' + ) states, _ = map_fn(get_posterior_samples)(keys, init_state_batched) raw_mcmc_samples = states.position diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index a44b780b8d..93782e9d43 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -13,8 +13,8 @@ _replace_shared_variables, get_jaxified_graph, get_jaxified_logp, - sample_numpyro_nuts, sample_blackjax_nuts, + sample_numpyro_nuts, ) @@ -23,7 +23,8 @@ [ sample_blackjax_nuts, sample_numpyro_nuts, - ]) + ], +) def test_transform_samples(sampler): aesara.config.on_opt_error = "raise" np.random.seed(13244) @@ -58,7 +59,8 @@ def test_transform_samples(sampler): [ sample_blackjax_nuts, sample_numpyro_nuts, - ]) + ], +) def test_deterministic_samples(sampler): aesara.config.on_opt_error = "raise" np.random.seed(13244) @@ -133,7 +135,8 @@ def test_get_jaxified_logp(): [ sample_blackjax_nuts, sample_numpyro_nuts, - ]) + ], +) @pytest.mark.parametrize( "idata_kwargs", [ From 1a9e8d64c8ed3f28e969baa6c2851826048e2449 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 13:00:59 +0000 Subject: [PATCH 03/16] Refactor get_jaxified_logp --- pymc/sampling_jax.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 
0b32fecfb9..bc8aa87dcf 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -95,17 +95,16 @@ def get_jaxified_graph( return jax_funcify(fgraph) -def get_jaxified_logp(model: Model, sampler=None) -> Callable: - - logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model.logpt()]) +def get_jaxified_logp(model: Model, negative_logp=True) -> Callable: + model_logpt = model.logpt() + if not negative_logp: + model_logpt = -model_logpt + logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt]) def logp_fn_wrap(x): # NumPyro expects a scalar potential with the opposite sign of model.logpt res = logp_fn(*x)[0] - if sampler == "numpyro": - return -res - else: - return res + return res return logp_fn_wrap @@ -220,7 +219,7 @@ def sample_blackjax_nuts( init_state = [initial_point[rv_name] for rv_name in rv_names] init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state) - logprob_fn = get_jaxified_logp(model, sampler="blackjax") + logprob_fn = get_jaxified_logp(model) seed = jax.random.PRNGKey(random_seed) keys = jax.random.split(seed, chains) @@ -332,7 +331,7 @@ def sample_numpyro_nuts( init_state = [initial_point[rv_name] for rv_name in rv_names] init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state) - logp_fn = get_jaxified_logp(model, sampler="numpyro") + logp_fn = get_jaxified_logp(model, negative_logp=False) nuts_kernel = NUTS( potential_fn=logp_fn, From ec62c899e086dd3a4137ced6dd61b5da98f1b33c Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 13:02:30 +0000 Subject: [PATCH 04/16] Install blackjax in workflows --- .github/workflows/jaxtests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/jaxtests.yml b/.github/workflows/jaxtests.yml index cf1f5b64df..c4c3e5ec79 100644 --- a/.github/workflows/jaxtests.yml +++ b/.github/workflows/jaxtests.yml @@ -71,6 +71,7 @@ jobs: run: | conda activate pymc-test-py39 pip install "numpyro>=0.8.0" + pip install https://github.com/blackjax-devs/blackjax@main - name: Run tests run: | python -m pytest -vv --cov=pymc --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET From 2394537bab321d08e340b7ef70f923cf2cefbdd0 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 13:18:20 +0000 Subject: [PATCH 05/16] Fix url --- .github/workflows/jaxtests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/jaxtests.yml b/.github/workflows/jaxtests.yml index c4c3e5ec79..0a470acd9a 100644 --- a/.github/workflows/jaxtests.yml +++ b/.github/workflows/jaxtests.yml @@ -71,7 +71,7 @@ jobs: run: | conda activate pymc-test-py39 pip install "numpyro>=0.8.0" - pip install https://github.com/blackjax-devs/blackjax@main + pip install git+https://github.com/blackjax-devs/blackjax.git@main - name: Run tests run: | python -m pytest -vv --cov=pymc --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET From 0e01e26b3d217a3f460ff9c3f79857ec3539ef45 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 15:39:25 +0000 Subject: [PATCH 06/16] Simplify function --- pymc/sampling_jax.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index bc8aa87dcf..d61f7d5efe 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -102,9 +102,7 @@ def get_jaxified_logp(model: Model, negative_logp=True) -> Callable: logp_fn = get_jaxified_graph(inputs=model.value_vars, 
outputs=[model_logpt]) def logp_fn_wrap(x): - # NumPyro expects a scalar potential with the opposite sign of model.logpt - res = logp_fn(*x)[0] - return res + return logp_fn(*x)[0] return logp_fn_wrap From 75df8d2e95bed981a178866acb7a2e9ed90a0fdd Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 17:48:30 +0000 Subject: [PATCH 07/16] Add documentation --- docs/source/api/samplers.rst | 2 + pymc/sampling_jax.py | 78 ++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/docs/source/api/samplers.rst b/docs/source/api/samplers.rst index 2e4a4cad13..cb3bad126e 100644 --- a/docs/source/api/samplers.rst +++ b/docs/source/api/samplers.rst @@ -30,6 +30,8 @@ HMC family NUTS HamiltonianMC + sampling_jax.sample_blackjax_nuts + sampling_jax.sample_numpyro_nuts Metropolis family ----------------- diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index d61f7d5efe..1cc6f0112e 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -188,6 +188,45 @@ def sample_blackjax_nuts( chain_method="parallel", idata_kwargs=None, ): + """Draw samples from the posterior using the NUTS method from the blackjax library. + + Parameters + ---------- + draws : int + The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded + by default. + tune : int + Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or + similar during tuning. Tuning samples will be drawn in addition to the number specified in + the ``draws`` argument. + chains : int + The number of chains to sample. Defaults to 4. + target_accept : float in [0, 1]. + The step size is tuned such that we approximate this acceptance rate. Higher values like + 0.9 or 0.95 often work better for problematic posteriors. + random_seed : int + Random seed used by the sampling steps. + model : Model (optional if in ``with`` context) + Model to sample from. The model needs to have free random variables. + var_names : Iterable[str] + Names of variables for which to compute the posterior samples. + progress_bar : bool, optional default=True + Whether or not to display a progress bar in the command line. The bar shows the percentage + of completion, the sampling speed in samples per second (SPS), and the estimated remaining + time until completion ("expected time of arrival"; ETA). + keep_untransformed : bool, optional default=False + Include untransformed variables in the posterior samples. Defaults to False. + chain_method : str + Specify how samples should be drawn. The choices include "parallel", and + "vectorized". Defaults to "parallel". + idata_kwargs : dict, optional + Keyword arguments for :func:`pymc.to_inference_data` + Returns + ------- + trace : arviz.InferenceData + ArviZ ``InferenceData`` object that contains the samples. + """ + model = modelcontext(model) if var_names is None: @@ -298,6 +337,45 @@ def sample_numpyro_nuts( chain_method="parallel", idata_kwargs=None, ): + """Draw samples from the posterior using the NUTS method from the numpyro library. + + Parameters + ---------- + draws : int + The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded + by default. + tune : int + Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or + similar during tuning. Tuning samples will be drawn in addition to the number specified in + the ``draws`` argument. + chains : int + The number of chains to sample. Defaults to 4. + target_accept : float in [0, 1]. 
+ The step size is tuned such that we approximate this acceptance rate. Higher values like + 0.9 or 0.95 often work better for problematic posteriors. + random_seed : int + Random seed used by the sampling steps. + model : Model (optional if in ``with`` context) + Model to sample from. The model needs to have free random variables. + var_names : Iterable[str] + Names of variables for which to compute the posterior samples. + progress_bar : bool, optional default=True + Whether or not to display a progress bar in the command line. The bar shows the percentage + of completion, the sampling speed in samples per second (SPS), and the estimated remaining + time until completion ("expected time of arrival"; ETA). + keep_untransformed : bool, optional default=False + Include untransformed variables in the posterior samples. Defaults to False. + chain_method : str + Specify how samples should be drawn. The choices include "sequential", "parallel", and + "vectorized". Defaults to "parallel". + idata_kwargs : dict, optional + Keyword arguments for :func:`pymc.to_inference_data` + Returns + ------- + trace : arviz.InferenceData + ArviZ ``InferenceData`` object that contains the samples. + """ + from numpyro.infer import MCMC, NUTS model = modelcontext(model) From 3b5c9487ade758267c7d0a863d8576a4680079ca Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 19:40:00 +0000 Subject: [PATCH 08/16] Add library versions --- pymc/sampling_jax.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 1cc6f0112e..00cfc121fb 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -24,6 +24,7 @@ from aesara.link.jax.dispatch import jax_funcify from aesara.raise_op import Assert from aesara.tensor import TensorVariable +from arviz.data.base import make_attrs from pymc import Model, modelcontext from pymc.backends.arviz import find_observations @@ -188,7 +189,8 @@ def sample_blackjax_nuts( chain_method="parallel", idata_kwargs=None, ): - """Draw samples from the posterior using the NUTS method from the blackjax library. + """ + Draw samples from the posterior using the NUTS method from the blackjax library. Parameters ---------- @@ -226,6 +228,7 @@ def sample_blackjax_nuts( trace : arviz.InferenceData ArviZ ``InferenceData`` object that contains the samples. """ + import blackjax model = modelcontext(model) @@ -310,6 +313,10 @@ def sample_blackjax_nuts( else: log_likelihood = None + attrs = { + "sampling_time": (tic3 - tic2).total_seconds(), + } + posterior = mcmc_samples az_trace = az.from_dict( posterior=posterior, @@ -317,7 +324,7 @@ def sample_blackjax_nuts( observed_data=find_observations(model), coords=coords, dims=dims, - attrs={"sampling_time": (tic3 - tic2).total_seconds()}, + attrs=make_attrs(attrs, library=blackjax), **idata_kwargs, ) @@ -337,7 +344,8 @@ def sample_numpyro_nuts( chain_method="parallel", idata_kwargs=None, ): - """Draw samples from the posterior using the NUTS method from the numpyro library. + """ + Draw samples from the posterior using the NUTS method from the numpyro library. Parameters ---------- @@ -376,6 +384,8 @@ def sample_numpyro_nuts( ArviZ ``InferenceData`` object that contains the samples. 
""" + import numpyro + from numpyro.infer import MCMC, NUTS model = modelcontext(model) @@ -479,6 +489,10 @@ def sample_numpyro_nuts( else: log_likelihood = None + attrs = { + "sampling_time": (tic3 - tic2).total_seconds(), + } + posterior = mcmc_samples az_trace = az.from_dict( posterior=posterior, @@ -487,7 +501,7 @@ def sample_numpyro_nuts( sample_stats=_sample_stats_to_xarray(pmap_numpyro), coords=coords, dims=dims, - attrs={"sampling_time": (tic3 - tic2).total_seconds()}, + attrs=make_attrs(attrs, library=numpyro), **idata_kwargs, ) From 8a27c1fa2b80a43afacdbd9e843b19353afb33ed Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 19:44:32 +0000 Subject: [PATCH 09/16] Move to more appropriate section --- docs/source/api/samplers.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/api/samplers.rst b/docs/source/api/samplers.rst index cb3bad126e..614ac47d6e 100644 --- a/docs/source/api/samplers.rst +++ b/docs/source/api/samplers.rst @@ -13,6 +13,8 @@ This submodule contains functions for MCMC and forward sampling. sample_prior_predictive sample_posterior_predictive sample_posterior_predictive_w + sampling_jax.sample_blackjax_nuts + sampling_jax.sample_numpyro_nuts iter_sample init_nuts draw @@ -30,8 +32,6 @@ HMC family NUTS HamiltonianMC - sampling_jax.sample_blackjax_nuts - sampling_jax.sample_numpyro_nuts Metropolis family ----------------- From ea5874bdc78c3de7c345a6e776bbe24af618f3e7 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 21:24:06 +0000 Subject: [PATCH 10/16] Fix docstrings --- pymc/sampling_jax.py | 90 +++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 42 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 00cfc121fb..4a9d51e5f2 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -190,43 +190,46 @@ def sample_blackjax_nuts( idata_kwargs=None, ): """ - Draw samples from the posterior using the NUTS method from the blackjax library. + Draw samples from the posterior using the NUTS method from the ``blackjax`` library. Parameters ---------- - draws : int - The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded - by default. - tune : int - Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or + draws : int, default 1000 + The number of samples to draw. The number of tuned samples are discarded by default. + tune : int, default 1000 + Number of iterations to tune. Samplers adjust the step sizes, scalings or similar during tuning. Tuning samples will be drawn in addition to the number specified in the ``draws`` argument. - chains : int - The number of chains to sample. Defaults to 4. + chains : int, default 4 + The number of chains to sample. target_accept : float in [0, 1]. The step size is tuned such that we approximate this acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic posteriors. - random_seed : int + random_seed : int, default 10 Random seed used by the sampling steps. - model : Model (optional if in ``with`` context) - Model to sample from. The model needs to have free random variables. - var_names : Iterable[str] - Names of variables for which to compute the posterior samples. - progress_bar : bool, optional default=True + model : Model, optional + Model to sample from. The model needs to have free random variables. When inside a ``with`` model + context, it defaults to that model, otherwise the model must be passed explicitly. 
+ var_names : iterable of str, optional + Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior + progress_bar : bool, default True Whether or not to display a progress bar in the command line. The bar shows the percentage of completion, the sampling speed in samples per second (SPS), and the estimated remaining time until completion ("expected time of arrival"; ETA). - keep_untransformed : bool, optional default=False + keep_untransformed : bool, default False Include untransformed variables in the posterior samples. Defaults to False. - chain_method : str - Specify how samples should be drawn. The choices include "parallel", and - "vectorized". Defaults to "parallel". + chain_method : str, default "parallel" + Specify how samples should be drawn. The choices include "parallel", and "vectorized". idata_kwargs : dict, optional - Keyword arguments for :func:`pymc.to_inference_data` + Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value + for the ``log_likelihood`` key to indicate that the pointwise log likelihood should + not be included in the returned object. + Returns ------- - trace : arviz.InferenceData - ArviZ ``InferenceData`` object that contains the samples. + InferenceData + ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and + pointwise log likeihood values (unless skipped with ``idata_kwargs``). """ import blackjax @@ -345,43 +348,46 @@ def sample_numpyro_nuts( idata_kwargs=None, ): """ - Draw samples from the posterior using the NUTS method from the numpyro library. + Draw samples from the posterior using the NUTS method from the ``numpyro`` library. Parameters ---------- - draws : int - The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded - by default. - tune : int - Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or + draws : int, default 1000 + The number of samples to draw. The number of tuned samples are discarded by default. + tune : int, default 1000 + Number of iterations to tune. Samplers adjust the step sizes, scalings or similar during tuning. Tuning samples will be drawn in addition to the number specified in the ``draws`` argument. - chains : int - The number of chains to sample. Defaults to 4. + chains : int, default 4 + The number of chains to sample. target_accept : float in [0, 1]. The step size is tuned such that we approximate this acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic posteriors. - random_seed : int + random_seed : int, default 10 Random seed used by the sampling steps. - model : Model (optional if in ``with`` context) - Model to sample from. The model needs to have free random variables. - var_names : Iterable[str] - Names of variables for which to compute the posterior samples. - progress_bar : bool, optional default=True + model : Model, optional + Model to sample from. The model needs to have free random variables. When inside a ``with`` model + context, it defaults to that model, otherwise the model must be passed explicitly. + var_names : iterable of str, optional + Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior + progress_bar : bool, default True Whether or not to display a progress bar in the command line. 
The bar shows the percentage of completion, the sampling speed in samples per second (SPS), and the estimated remaining time until completion ("expected time of arrival"; ETA). - keep_untransformed : bool, optional default=False + keep_untransformed : bool, default False Include untransformed variables in the posterior samples. Defaults to False. - chain_method : str - Specify how samples should be drawn. The choices include "sequential", "parallel", and - "vectorized". Defaults to "parallel". + chain_method : str, default "parallel" + Specify how samples should be drawn. The choices include "sequential", "parallel", and "vectorized". idata_kwargs : dict, optional - Keyword arguments for :func:`pymc.to_inference_data` + Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value + for the ``log_likelihood`` key to indicate that the pointwise log likelihood should + not be included in the returned object. + Returns ------- - trace : arviz.InferenceData - ArviZ ``InferenceData`` object that contains the samples. + InferenceData + ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and + pointwise log likeihood values (unless skipped with ``idata_kwargs``). """ import numpyro From 659828c61df04e9ec062b281bb018c1328813b82 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 22:59:32 +0000 Subject: [PATCH 11/16] Add jax to dev environment --- conda-envs/environment-dev-py37.yml | 1 + conda-envs/environment-dev-py38.yml | 1 + conda-envs/environment-dev-py39.yml | 1 + conda-envs/environment-test-py37.yml | 1 + conda-envs/environment-test-py38.yml | 1 + conda-envs/environment-test-py39.yml | 1 + conda-envs/windows-environment-dev-py38.yml | 1 + conda-envs/windows-environment-test-py38.yml | 1 + requirements-dev.txt | 1 + 9 files changed, 9 insertions(+) diff --git a/conda-envs/environment-dev-py37.yml b/conda-envs/environment-dev-py37.yml index 1cf01e6cca..8e96923234 100644 --- a/conda-envs/environment-dev-py37.yml +++ b/conda-envs/environment-dev-py37.yml @@ -13,6 +13,7 @@ dependencies: - fastprogress>=0.2.0 - h5py>=2.7 - ipython>=7.16 +- jax - myst-nb - numpy>=1.15.0 - numpydoc<1.2 diff --git a/conda-envs/environment-dev-py38.yml b/conda-envs/environment-dev-py38.yml index 1cb5abf637..621a7329ae 100644 --- a/conda-envs/environment-dev-py38.yml +++ b/conda-envs/environment-dev-py38.yml @@ -13,6 +13,7 @@ dependencies: - fastprogress>=0.2.0 - h5py>=2.7 - ipython>=7.16 +- jax - myst-nb - numpy>=1.15.0 - numpydoc<1.2 diff --git a/conda-envs/environment-dev-py39.yml b/conda-envs/environment-dev-py39.yml index 4c683a70a7..d037fb62c2 100644 --- a/conda-envs/environment-dev-py39.yml +++ b/conda-envs/environment-dev-py39.yml @@ -13,6 +13,7 @@ dependencies: - fastprogress>=0.2.0 - h5py>=2.7 - ipython>=7.16 +- jax - myst-nb - numpy>=1.15.0 - numpydoc<1.2 diff --git a/conda-envs/environment-test-py37.yml b/conda-envs/environment-test-py37.yml index 2c93125602..3f60e053ef 100644 --- a/conda-envs/environment-test-py37.yml +++ b/conda-envs/environment-test-py37.yml @@ -13,6 +13,7 @@ dependencies: - fastprogress>=0.2.0 - h5py>=2.7 - ipython>=7.16 +- jax - libblas=*=*mkl - mkl-service - numpy>=1.15.0 diff --git a/conda-envs/environment-test-py38.yml b/conda-envs/environment-test-py38.yml index 5d87f33945..b0bc35ef62 100644 --- a/conda-envs/environment-test-py38.yml +++ b/conda-envs/environment-test-py38.yml @@ -13,6 +13,7 @@ dependencies: - fastprogress>=0.2.0 - h5py>=2.7 - ipython>=7.16 +- jax - libblas=*=*mkl - 
mkl-service - numpy>=1.15.0 diff --git a/conda-envs/environment-test-py39.yml b/conda-envs/environment-test-py39.yml index 02d16cb22d..d11f9ef383 100644 --- a/conda-envs/environment-test-py39.yml +++ b/conda-envs/environment-test-py39.yml @@ -13,6 +13,7 @@ dependencies: - fastprogress>=0.2.0 - h5py>=2.7 - ipython>=7.16 +- jax - libblas=*=*mkl - mkl-service - numpy>=1.15.0 diff --git a/conda-envs/windows-environment-dev-py38.yml b/conda-envs/windows-environment-dev-py38.yml index f165a7a5b6..8746ddd30b 100644 --- a/conda-envs/windows-environment-dev-py38.yml +++ b/conda-envs/windows-environment-dev-py38.yml @@ -21,6 +21,7 @@ dependencies: - typing-extensions>=3.7.4 # Extra stuff for dev, testing and docs build - ipython>=7.16 +- jax - myst-nb - numpydoc<1.2 - pre-commit>=2.8.0 diff --git a/conda-envs/windows-environment-test-py38.yml b/conda-envs/windows-environment-test-py38.yml index 3adac2cd4a..ca917bc15a 100644 --- a/conda-envs/windows-environment-test-py38.yml +++ b/conda-envs/windows-environment-test-py38.yml @@ -12,6 +12,7 @@ dependencies: - cloudpickle - fastprogress>=0.2.0 - h5py>=2.7 +- jax - libpython - mkl==2020.4 - mkl-service==2.3.0 diff --git a/requirements-dev.txt b/requirements-dev.txt index 55568d4a27..ced8514db3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,6 +9,7 @@ cloudpickle fastprogress>=0.2.0 h5py>=2.7 ipython>=7.16 +jax myst-nb numpy>=1.15.0 numpydoc<1.2 From 6460a169dcb22bf1ec160c9d3f09ad91099390aa Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 23:18:00 +0000 Subject: [PATCH 12/16] Jax doesn't work on windows, remove it --- conda-envs/windows-environment-dev-py38.yml | 1 - conda-envs/windows-environment-test-py38.yml | 1 - requirements-dev.txt | 1 - 3 files changed, 3 deletions(-) diff --git a/conda-envs/windows-environment-dev-py38.yml b/conda-envs/windows-environment-dev-py38.yml index 8746ddd30b..f165a7a5b6 100644 --- a/conda-envs/windows-environment-dev-py38.yml +++ b/conda-envs/windows-environment-dev-py38.yml @@ -21,7 +21,6 @@ dependencies: - typing-extensions>=3.7.4 # Extra stuff for dev, testing and docs build - ipython>=7.16 -- jax - myst-nb - numpydoc<1.2 - pre-commit>=2.8.0 diff --git a/conda-envs/windows-environment-test-py38.yml b/conda-envs/windows-environment-test-py38.yml index ca917bc15a..3adac2cd4a 100644 --- a/conda-envs/windows-environment-test-py38.yml +++ b/conda-envs/windows-environment-test-py38.yml @@ -12,7 +12,6 @@ dependencies: - cloudpickle - fastprogress>=0.2.0 - h5py>=2.7 -- jax - libpython - mkl==2020.4 - mkl-service==2.3.0 diff --git a/requirements-dev.txt b/requirements-dev.txt index ced8514db3..55568d4a27 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,7 +9,6 @@ cloudpickle fastprogress>=0.2.0 h5py>=2.7 ipython>=7.16 -jax myst-nb numpy>=1.15.0 numpydoc<1.2 From d8a3898b533f0570b971a20c77b5a55fb238c1c8 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 16 Feb 2022 23:26:42 +0000 Subject: [PATCH 13/16] Exclude jax --- scripts/generate_pip_deps_from_conda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/generate_pip_deps_from_conda.py b/scripts/generate_pip_deps_from_conda.py index 0f717564e5..2d29fc3193 100755 --- a/scripts/generate_pip_deps_from_conda.py +++ b/scripts/generate_pip_deps_from_conda.py @@ -52,6 +52,7 @@ "numba", "python-graphviz", "blas", + "jax", } RENAME = {} From 6ba2c31e9ffc84478b8ad80aa3b376eeb800e190 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 28 Feb 2022 12:24:29 +0000 Subject: [PATCH 14/16] Fix merge --- pymc/sampling_jax.py | 31 
++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 129b028b8d..f6fd077843 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -155,7 +155,7 @@ def _get_batched_jittered_initial_points( Each item has shape `(chains, *var.shape)` """ if isinstance(random_seed, (int, np.integer)): - random_seed = np.random.default_rng(random_seed).integers(2**30, size=chains) + random_seed = np.random.default_rng(random_seed).integers(2 ** 30, size=chains) elif not isinstance(random_seed, (list, tuple, np.ndarray)): raise ValueError(f"The `seeds` must be int or array-like. Got {type(random_seed)} instead.") @@ -218,6 +218,7 @@ def sample_blackjax_nuts( chains=4, target_accept=0.8, random_seed=10, + initvals=None, model=None, var_names=None, progress_bar=True, # FIXME: Unused for now @@ -290,13 +291,19 @@ def sample_blackjax_nuts( else: dims = {} - tic1 = pd.Timestamp.now() + tic1 = datetime.now() print("Compiling...", file=sys.stdout) - rv_names = [rv.name for rv in model.value_vars] - initial_point = model.compute_initial_point() - init_state = [initial_point[rv_name] for rv_name in rv_names] - init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state) + init_params = _get_batched_jittered_initial_points( + model=model, + chains=chains, + initvals=initvals, + random_seed=random_seed, + ) + + if chains == 1: + init_params = [np.stack(init_params)] + init_params = [np.stack(init_state) for init_state in zip(*init_params)] logprob_fn = get_jaxified_logp(model) @@ -311,7 +318,7 @@ def sample_blackjax_nuts( target_accept=target_accept, ) - tic2 = pd.Timestamp.now() + tic2 = datetime.now() print("Compilation time = ", tic2 - tic1, file=sys.stdout) print("Sampling...", file=sys.stdout) @@ -326,10 +333,10 @@ def sample_blackjax_nuts( "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"' ) - states, _ = map_fn(get_posterior_samples)(keys, init_state_batched) + states, _ = map_fn(get_posterior_samples)(keys, init_params) raw_mcmc_samples = states.position - tic3 = pd.Timestamp.now() + tic3 = datetime.now() print("Sampling time = ", tic3 - tic2, file=sys.stdout) print("Transforming variables...", file=sys.stdout) @@ -339,7 +346,7 @@ def sample_blackjax_nuts( result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0] mcmc_samples[v.name] = result - tic4 = pd.Timestamp.now() + tic4 = datetime.now() print("Transformation time = ", tic4 - tic3, file=sys.stdout) if idata_kwargs is None: @@ -368,8 +375,6 @@ def sample_blackjax_nuts( ) return az_trace -======= ->>>>>>> main def sample_numpyro_nuts( @@ -457,7 +462,7 @@ def sample_numpyro_nuts( if random_seed is None: random_seed = model.rng_seeder.randint( - 2**30, dtype=np.int64, size=chains if chains > 1 else None + 2 ** 30, dtype=np.int64, size=chains if chains > 1 else None ) tic1 = datetime.now() From 1bf85f0e8bab058542471dacbf97928551d84cf9 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 28 Feb 2022 12:34:38 +0000 Subject: [PATCH 15/16] Remove progress bar functionality for now --- pymc/sampling_jax.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index f6fd077843..b1d4edb7a5 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -221,7 +221,6 @@ def sample_blackjax_nuts( initvals=None, model=None, var_names=None, - progress_bar=True, # FIXME: Unused for now keep_untransformed=False, chain_method="parallel", idata_kwargs=None, @@ 
-249,10 +248,6 @@ def sample_blackjax_nuts( context, it defaults to that model, otherwise the model must be passed explicitly. var_names : iterable of str, optional Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior - progress_bar : bool, default True - Whether or not to display a progress bar in the command line. The bar shows the percentage - of completion, the sampling speed in samples per second (SPS), and the estimated remaining - time until completion ("expected time of arrival"; ETA). keep_untransformed : bool, default False Include untransformed variables in the posterior samples. Defaults to False. chain_method : str, default "parallel" From b1436d95fbdf5b586f9881eb1628746af70f2928 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 28 Feb 2022 12:42:02 +0000 Subject: [PATCH 16/16] Pre-commit fix --- pymc/sampling_jax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index b1d4edb7a5..d23cdb55c5 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -155,7 +155,7 @@ def _get_batched_jittered_initial_points( Each item has shape `(chains, *var.shape)` """ if isinstance(random_seed, (int, np.integer)): - random_seed = np.random.default_rng(random_seed).integers(2 ** 30, size=chains) + random_seed = np.random.default_rng(random_seed).integers(2**30, size=chains) elif not isinstance(random_seed, (list, tuple, np.ndarray)): raise ValueError(f"The `seeds` must be int or array-like. Got {type(random_seed)} instead.") @@ -457,7 +457,7 @@ def sample_numpyro_nuts( if random_seed is None: random_seed = model.rng_seeder.randint( - 2 ** 30, dtype=np.int64, size=chains if chains > 1 else None + 2**30, dtype=np.int64, size=chains if chains > 1 else None ) tic1 = datetime.now()
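
Usage sketch (illustrative, not part of the patches above): assuming this branch is installed together with jax, jaxlib and blackjax, the new sampler can be called much like sample_numpyro_nuts. The keyword arguments follow the sample_blackjax_nuts signature as of PATCH 14/16; the model and data below are made up for illustration.

    import numpy as np
    import pymc as pm
    import pymc.sampling_jax as sampling_jax  # module extended by these patches

    # Synthetic data for a simple normal model
    rng = np.random.default_rng(0)
    y = rng.normal(loc=1.0, scale=2.0, size=100)

    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 10.0)
        sigma = pm.HalfNormal("sigma", 5.0)
        pm.Normal("obs", mu, sigma=sigma, observed=y)

        # NUTS via blackjax; chain_method="vectorized" uses jax.vmap, so it does not
        # need one accelerator device per chain the way "parallel" (jax.pmap) does.
        idata = sampling_jax.sample_blackjax_nuts(
            draws=1000,
            tune=1000,
            chains=2,
            target_accept=0.9,
            chain_method="vectorized",
        )

    print(idata.posterior["mu"].mean(), idata.posterior["sigma"].mean())

The return value is an arviz.InferenceData, so the usual ArviZ workflow (az.summary, az.plot_trace, etc.) applies unchanged.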