
Adding NUTS sampler from blackjax to sampling_jax #5477


Merged on Mar 11, 2022 (19 commits)
Changes from 13 commits
1 change: 1 addition & 0 deletions .github/workflows/jaxtests.yml
@@ -71,6 +71,7 @@ jobs:
run: |
conda activate pymc-test-py39
pip install "numpyro>=0.8.0"
pip install git+https://github.com/blackjax-devs/blackjax.git@main
- name: Run tests
run: |
python -m pytest -vv --cov=pymc --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET
1 change: 1 addition & 0 deletions conda-envs/environment-dev-py37.yml
@@ -13,6 +13,7 @@ dependencies:
- fastprogress>=0.2.0
- h5py>=2.7
- ipython>=7.16
- jax
Member:

I am not sure we should add jax as a dependency. By that logic we should also add numpyro and blackjax...

The jaxtests workflow installs it manually

Member:

Oh this was for documentation... Yeah it seems like we need to have an environment for docs... feel free to ignore my comment then

Member:

I don't think you need them in the dev environments, only for test (until we add a docs env).

Member (@ricardoV94, Feb 19, 2022):

We need it for the docs to render properly.

Member:

Are the docs using the dev environments?

Member (@ricardoV94, Feb 19, 2022):

Yes, the python 3.8 dev one, just out of convenience I think. We should just create a specific environment for the docs...

Member:

I see, shouldn't we have it use the test env instead? I guess if we start a doc env it doesn't really matter, but maybe they would end up identical?

Member:

I think the docs require quite a few more dependencies than the tests do.

- myst-nb
- numpy>=1.15.0
- numpydoc<1.2
1 change: 1 addition & 0 deletions conda-envs/environment-dev-py38.yml
@@ -13,6 +13,7 @@ dependencies:
- fastprogress>=0.2.0
- h5py>=2.7
- ipython>=7.16
- jax
- myst-nb
- numpy>=1.15.0
- numpydoc<1.2
1 change: 1 addition & 0 deletions conda-envs/environment-dev-py39.yml
@@ -13,6 +13,7 @@ dependencies:
- fastprogress>=0.2.0
- h5py>=2.7
- ipython>=7.16
- jax
- myst-nb
- numpy>=1.15.0
- numpydoc<1.2
1 change: 1 addition & 0 deletions conda-envs/environment-test-py37.yml
@@ -13,6 +13,7 @@ dependencies:
- fastprogress>=0.2.0
- h5py>=2.7
- ipython>=7.16
- jax
- libblas=*=*mkl
- mkl-service
- numpy>=1.15.0
1 change: 1 addition & 0 deletions conda-envs/environment-test-py38.yml
@@ -13,6 +13,7 @@ dependencies:
- fastprogress>=0.2.0
- h5py>=2.7
- ipython>=7.16
- jax
- libblas=*=*mkl
- mkl-service
- numpy>=1.15.0
1 change: 1 addition & 0 deletions conda-envs/environment-test-py39.yml
@@ -13,6 +13,7 @@ dependencies:
- fastprogress>=0.2.0
- h5py>=2.7
- ipython>=7.16
- jax
- libblas=*=*mkl
- mkl-service
- numpy>=1.15.0
2 changes: 2 additions & 0 deletions docs/source/api/samplers.rst
@@ -13,6 +13,8 @@ This submodule contains functions for MCMC and forward sampling.
sample_prior_predictive
sample_posterior_predictive
sample_posterior_predictive_w
sampling_jax.sample_blackjax_nuts
sampling_jax.sample_numpyro_nuts
iter_sample
init_nuts
draw
261 changes: 253 additions & 8 deletions pymc/sampling_jax.py
@@ -4,6 +4,7 @@
import sys
import warnings

from functools import partial
from typing import Callable, List, Optional

xla_flags = os.getenv("XLA_FLAGS", "")
@@ -23,6 +24,7 @@
from aesara.link.jax.dispatch import jax_funcify
from aesara.raise_op import Assert
from aesara.tensor import TensorVariable
from arviz.data.base import make_attrs

from pymc import Model, modelcontext
from pymc.backends.arviz import find_observations
@@ -94,14 +96,14 @@ def get_jaxified_graph(
return jax_funcify(fgraph)


def get_jaxified_logp(model: Model) -> Callable:

logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model.logpt()])
def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
model_logpt = model.logpt()
if not negative_logp:
model_logpt = -model_logpt
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])

def logp_fn_wrap(x):
# NumPyro expects a scalar potential with the opposite sign of model.logpt
res = logp_fn(*x)[0]
return -res
return logp_fn(*x)[0]

return logp_fn_wrap
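As a side note on the sign flip handled above: blackjax's NUTS consumes the log-density directly, while NumPyro's ``potential_fn`` expects its negative, which is what the ``negative_logp`` flag toggles. A dependency-free sketch of the convention (``toy_logp`` and ``wrap_logp`` are illustrative names, not part of this PR):

```python
def toy_logp(x):
    # standard-normal log-density, up to an additive constant
    return -0.5 * x * x

def wrap_logp(logp, negative_logp=True):
    # mirrors get_jaxified_logp: the default keeps the log-density as-is
    # (blackjax), while negative_logp=False flips the sign so the result
    # can serve as a NumPyro-style potential
    if not negative_logp:
        return lambda x: -logp(x)
    return logp

blackjax_fn = wrap_logp(toy_logp)                      # log-density
numpyro_fn = wrap_logp(toy_logp, negative_logp=False)  # potential = -log-density

assert blackjax_fn(2.0) == -2.0
assert numpyro_fn(2.0) == 2.0
```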

@@ -138,6 +140,200 @@ def _get_log_likelihood(model, samples):
return data


@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
def _blackjax_inference_loop(
seed,
init_position,
logprob_fn,
draws,
tune,
target_accept,
algorithm=None,
):
import blackjax

if algorithm is None:
algorithm = blackjax.nuts

adapt = blackjax.window_adaptation(
algorithm=algorithm,
logprob_fn=logprob_fn,
num_steps=tune,
target_acceptance_rate=target_accept,
)
last_state, kernel, _ = adapt.run(seed, init_position)

def inference_loop(rng_key, initial_state):
def one_step(state, rng_key):
state, info = kernel(rng_key, state)
return state, (state, info)

keys = jax.random.split(rng_key, draws)
_, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

return states, infos

return inference_loop(seed, last_state)
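The ``inference_loop`` above is the standard ``jax.lax.scan`` sampling pattern: the kernel threads the chain state through ``one_step`` while stacking every ``(state, info)`` pair. A pure-Python stand-in (the counter ``one_step`` is a toy, not a real NUTS kernel):

```python
def scan(step, init, xs):
    # minimal stand-in for jax.lax.scan: thread a carry, stack the outputs
    carry, ys = init, []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, ys

def one_step(state, key):
    # toy "kernel": advance the state and report it alongside an info dict
    new_state = state + key
    return new_state, (new_state, {"is_accepted": True})

final_state, history = scan(one_step, 0, [1, 2, 3])
assert final_state == 6
assert [s for s, _ in history] == [1, 3, 6]
```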


def sample_blackjax_nuts(
Member:

This function needs to be documented and added to the API docs (where the numpyro one might also be missing?). It probably makes more sense to add it in the first block of https://github.com/pymc-devs/pymc/blob/main/docs/source/api/samplers.rst.

Contributor Author:

Good call.

Contributor Author (@zaxtax, Feb 16, 2022):

I'm a bit rusty with Sphinx; where do I specify that the functions are in sampling_jax?

Member:

By that point ``currentmodule`` is set to ``pymc``, so Sphinx assumes functions live there. If they are in ``pymc.sampling_jax.<function>``, you'll need to write ``sampling_jax.<function>``: the ``pymc`` prefix is implied by the ``currentmodule`` setting.

draws=1000,
tune=1000,
chains=4,
target_accept=0.8,
random_seed=10,
model=None,
var_names=None,
progress_bar=True, # FIXME: Unused for now
keep_untransformed=False,
chain_method="parallel",
idata_kwargs=None,
):
"""
Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

Parameters
Member:

Can you update with the advice on https://pymc-data-umbrella.xyz/en/latest/sprint/docstring_tutorial.html#edit-the-docstring and I'll take a second look later?

Contributor Author:

I must be missing something, but I think I'm style-compliant. I actually copied the docstrings from ``sample`` as my starting point.

Member:

Basically none of the docstrings are currently style-compliant, and most aren't rendered correctly because of that. We have an issue open for this: #5459. I can give one of the docstrings a quick go. I think you'll also need to rebase on main for the docs preview to work and CI to pass.

Member:

Added some comments, but they are not exhaustive.

Contributor Author:

Ah, OK. It was how I was specifying defaults and optional parameters.

----------
draws : int, default 1000
The number of samples to draw. The tuning samples are discarded by default.
tune : int, default 1000
Number of iterations to tune. Samplers adjust the step sizes, scalings or
similar during tuning. Tuning samples will be drawn in addition to the number specified in
the ``draws`` argument.
chains : int, default 4
The number of chains to sample.
target_accept : float in [0, 1], default 0.8
The step size is tuned such that we approximate this acceptance rate. Higher values like
0.9 or 0.95 often work better for problematic posteriors.
random_seed : int, default 10
Random seed used by the sampling steps.
model : Model, optional
Model to sample from. The model needs to have free random variables. When inside a ``with`` model
context, it defaults to that model, otherwise the model must be passed explicitly.
var_names : iterable of str, optional
Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior.
progress_bar : bool, default True
Whether or not to display a progress bar in the command line. The bar shows the percentage
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
time until completion ("expected time of arrival"; ETA).
keep_untransformed : bool, default False
Include untransformed variables in the posterior samples. Defaults to False.
chain_method : str, default "parallel"
Specify how samples should be drawn. The choices include "parallel", and "vectorized".
idata_kwargs : dict, optional
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
not be included in the returned object.

Returns
-------
InferenceData
ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
pointwise log likelihood values (unless skipped with ``idata_kwargs``).
"""
import blackjax

model = modelcontext(model)

if var_names is None:
var_names = model.unobserved_value_vars

vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))

coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in model.coords.items()
if cvals is not None
}

if hasattr(model, "RV_dims"):
dims = {
var_name: [dim for dim in dims if dim is not None]
for var_name, dims in model.RV_dims.items()
}
else:
dims = {}

tic1 = pd.Timestamp.now()
print("Compiling...", file=sys.stdout)

rv_names = [rv.name for rv in model.value_vars]
initial_point = model.compute_initial_point()
init_state = [initial_point[rv_name] for rv_name in rv_names]
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
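The ``jax.tree_map``/``np.repeat`` call above gives every chain its own copy of the initial point by adding a leading chain axis. A NumPy-only sketch of that batching (``batch_for_chains`` is an illustrative name):

```python
import numpy as np

def batch_for_chains(init_state, chains):
    # add a leading chain axis and tile each initial value across it,
    # mirroring jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), ...)
    return [np.repeat(np.asarray(x)[None, ...], chains, axis=0) for x in init_state]

init_state = [np.zeros(3), np.array(1.5)]  # e.g. a vector and a scalar value var
batched = batch_for_chains(init_state, chains=4)
assert batched[0].shape == (4, 3)
assert batched[1].shape == (4,)
```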

logprob_fn = get_jaxified_logp(model)

seed = jax.random.PRNGKey(random_seed)
keys = jax.random.split(seed, chains)

get_posterior_samples = partial(
_blackjax_inference_loop,
logprob_fn=logprob_fn,
tune=tune,
draws=draws,
target_accept=target_accept,
)

tic2 = pd.Timestamp.now()
print("Compilation time = ", tic2 - tic1, file=sys.stdout)

print("Sampling...", file=sys.stdout)

# Adapted from numpyro
if chain_method == "parallel":
map_fn = jax.pmap
elif chain_method == "vectorized":
map_fn = jax.vmap
else:
raise ValueError(
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
)

states, _ = map_fn(get_posterior_samples)(keys, init_state_batched)
raw_mcmc_samples = states.position

tic3 = pd.Timestamp.now()
print("Sampling time = ", tic3 - tic2, file=sys.stdout)

print("Transforming variables...", file=sys.stdout)
mcmc_samples = {}
for v in vars_to_sample:
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
mcmc_samples[v.name] = result

tic4 = pd.Timestamp.now()
print("Transformation time = ", tic4 - tic3, file=sys.stdout)
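The nested ``jax.vmap(jax.vmap(jax_fn))`` above maps the deterministic graph first over the chain axis and then over the draw axis. A NumPy stand-in (two nested loops over a toy transform) shows the shape contract:

```python
import numpy as np

def double_vmap(fn, samples):
    # emulate jax.vmap(jax.vmap(fn)): apply fn once per (chain, draw) pair
    chains, draws = samples.shape[:2]
    return np.array([[fn(samples[c, d]) for d in range(draws)] for c in range(chains)])

raw = np.ones((2, 5, 3))                   # (chains, draws, dim)
out = double_vmap(lambda x: x.sum(), raw)  # toy transform: sum over dim
assert out.shape == (2, 5)
assert float(out[0, 0]) == 3.0
```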

if idata_kwargs is None:
idata_kwargs = {}
else:
idata_kwargs = idata_kwargs.copy()

if idata_kwargs.pop("log_likelihood", True):
log_likelihood = _get_log_likelihood(model, raw_mcmc_samples)
else:
log_likelihood = None
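The ``idata_kwargs`` handling above treats ``log_likelihood`` as an opt-out switch: the key is popped from a copy so it never reaches ``arviz.from_dict`` and the caller's dict is left untouched. A small sketch of the pattern (``split_idata_kwargs`` and the ``save_warmup`` key are illustrative):

```python
def split_idata_kwargs(idata_kwargs):
    # copy before popping so the caller's dict is never mutated
    kwargs = {} if idata_kwargs is None else idata_kwargs.copy()
    compute_log_likelihood = kwargs.pop("log_likelihood", True)
    return compute_log_likelihood, kwargs

user_kwargs = {"log_likelihood": False, "save_warmup": True}
compute_ll, passthrough = split_idata_kwargs(user_kwargs)
assert compute_ll is False
assert passthrough == {"save_warmup": True}
assert user_kwargs["log_likelihood"] is False  # original untouched
```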

attrs = {
"sampling_time": (tic3 - tic2).total_seconds(),
}

posterior = mcmc_samples
az_trace = az.from_dict(
posterior=posterior,
log_likelihood=log_likelihood,
observed_data=find_observations(model),
coords=coords,
dims=dims,
attrs=make_attrs(attrs, library=blackjax),
**idata_kwargs,
)

return az_trace


def sample_numpyro_nuts(
draws=1000,
tune=1000,
@@ -151,6 +347,51 @@ def sample_numpyro_nuts(
chain_method="parallel",
idata_kwargs=None,
):
"""
Draw samples from the posterior using the NUTS method from the ``numpyro`` library.

Parameters
----------
draws : int, default 1000
The number of samples to draw. The tuning samples are discarded by default.
tune : int, default 1000
Number of iterations to tune. Samplers adjust the step sizes, scalings or
similar during tuning. Tuning samples will be drawn in addition to the number specified in
the ``draws`` argument.
chains : int, default 4
The number of chains to sample.
target_accept : float in [0, 1], default 0.8
The step size is tuned such that we approximate this acceptance rate. Higher values like
0.9 or 0.95 often work better for problematic posteriors.
random_seed : int, default 10
Random seed used by the sampling steps.
model : Model, optional
Model to sample from. The model needs to have free random variables. When inside a ``with`` model
context, it defaults to that model, otherwise the model must be passed explicitly.
var_names : iterable of str, optional
Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior.
progress_bar : bool, default True
Whether or not to display a progress bar in the command line. The bar shows the percentage
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
time until completion ("expected time of arrival"; ETA).
keep_untransformed : bool, default False
Include untransformed variables in the posterior samples. Defaults to False.
chain_method : str, default "parallel"
Specify how samples should be drawn. The choices include "sequential", "parallel", and "vectorized".
idata_kwargs : dict, optional
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
not be included in the returned object.

Returns
-------
InferenceData
ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
pointwise log likelihood values (unless skipped with ``idata_kwargs``).
"""

import numpyro

from numpyro.infer import MCMC, NUTS

model = modelcontext(model)
@@ -182,7 +423,7 @@ def sample_numpyro_nuts(
init_state = [initial_point[rv_name] for rv_name in rv_names]
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

logp_fn = get_jaxified_logp(model)
logp_fn = get_jaxified_logp(model, negative_logp=False)

nuts_kernel = NUTS(
potential_fn=logp_fn,
@@ -254,6 +495,10 @@ def sample_numpyro_nuts(
else:
log_likelihood = None

attrs = {
"sampling_time": (tic3 - tic2).total_seconds(),
}

posterior = mcmc_samples
az_trace = az.from_dict(
posterior=posterior,
@@ -262,7 +507,7 @@ def sample_numpyro_nuts(
sample_stats=_sample_stats_to_xarray(pmap_numpyro),
coords=coords,
dims=dims,
attrs={"sampling_time": (tic3 - tic2).total_seconds()},
attrs=make_attrs(attrs, library=numpyro),
**idata_kwargs,
)
