Adding NUTS sampler from blackjax to sampling_jax #5477
Changes from 13 commits
@@ -4,6 +4,7 @@
 import sys
 import warnings

+from functools import partial
 from typing import Callable, List, Optional

 xla_flags = os.getenv("XLA_FLAGS", "")
@@ -23,6 +24,7 @@
 from aesara.link.jax.dispatch import jax_funcify
 from aesara.raise_op import Assert
 from aesara.tensor import TensorVariable
+from arviz.data.base import make_attrs

 from pymc import Model, modelcontext
 from pymc.backends.arviz import find_observations
@@ -94,14 +96,14 @@ def get_jaxified_graph(
     return jax_funcify(fgraph)


-def get_jaxified_logp(model: Model) -> Callable:
-    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model.logpt()])
+def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
+    model_logpt = model.logpt()
+    if not negative_logp:
+        model_logpt = -model_logpt
+    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])

     def logp_fn_wrap(x):
-        # NumPyro expects a scalar potential with the opposite sign of model.logpt
-        res = logp_fn(*x)[0]
-        return -res
+        return logp_fn(*x)[0]

     return logp_fn_wrap
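The point of the `negative_logp` flag is the sign convention: NumPyro wants a potential (the negated log density), while BlackJAX wants the log density itself. A stdlib-only sketch of the same wrapper logic, with a toy `logp` standing in for the compiled graph (`make_logp_wrapper` and `toy_logp` are illustrative names, not PyMC API):

```python
from typing import Callable

def make_logp_wrapper(logp, negative_logp=True) -> Callable:
    # Mirrors the diff: with negative_logp=False the sign is flipped once
    # up front, so the wrapper itself no longer needs to negate.
    if not negative_logp:
        wrapped = lambda x: -logp(x)
    else:
        wrapped = logp

    def logp_fn_wrap(x):
        return wrapped(x)

    return logp_fn_wrap

toy_logp = lambda x: -0.5 * x * x  # stand-in for a model log density

blackjax_style = make_logp_wrapper(toy_logp)                      # wants logp
numpyro_style = make_logp_wrapper(toy_logp, negative_logp=False)  # wants -logp

print(blackjax_style(2.0))  # -2.0
print(numpyro_style(2.0))   # 2.0
```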
@@ -138,6 +140,200 @@ def _get_log_likelihood(model, samples):
     return data


+@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
+def _blackjax_inference_loop(
+    seed,
+    init_position,
+    logprob_fn,
+    draws,
+    tune,
+    target_accept,
+    algorithm=None,
+):
+    import blackjax
+
+    if algorithm is None:
+        algorithm = blackjax.nuts
+
+    adapt = blackjax.window_adaptation(
+        algorithm=algorithm,
+        logprob_fn=logprob_fn,
+        num_steps=tune,
+        target_acceptance_rate=target_accept,
+    )
+    last_state, kernel, _ = adapt.run(seed, init_position)
+
+    def inference_loop(rng_key, initial_state):
+        def one_step(state, rng_key):
+            state, info = kernel(rng_key, state)
+            return state, (state, info)
+
+        keys = jax.random.split(rng_key, draws)
+        _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
+
+        return states, infos
+
+    return inference_loop(seed, last_state)
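The `inference_loop` relies on `jax.lax.scan` to thread the sampler state through `draws` steps while stacking the per-step outputs. A minimal pure-Python emulation of that pattern (a hypothetical `scan` helper and a toy `one_step`, no JAX required) shows what the loop computes:

```python
def scan(f, init, xs):
    # Minimal emulation of jax.lax.scan: thread a carry through f,
    # collecting each step's second output.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

def one_step(state, x):
    # Toy "kernel": advance the state and also emit it, just as the
    # diff's one_step emits (state, info) for every draw.
    state = state + x
    return state, state

final, trajectory = scan(one_step, 0, [1, 2, 3, 4])
print(final)       # 10
print(trajectory)  # [1, 3, 6, 10]
```

In the real code the per-step outputs are NUTS states and diagnostics rather than integers, and `jax.lax.scan` stacks them into arrays instead of a list.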
+def sample_blackjax_nuts(

    [Review thread]
    - This function needs to be documented and added to the API docs (where the numpyro one might also be missing?). It probably makes more sense to add it in the first block of https://github.com/pymc-devs/pymc/blob/main/docs/source/api/samplers.rst
    - Good call.
    - I'm a bit rusty with sphinx, where do I specify that the functions are in
    - by that point
+    draws=1000,
+    tune=1000,
+    chains=4,
+    target_accept=0.8,
+    random_seed=10,
+    model=None,
+    var_names=None,
+    progress_bar=True,  # FIXME: Unused for now
+    keep_untransformed=False,
+    chain_method="parallel",
+    idata_kwargs=None,
+):
+    """
+    Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
+
    [Review thread]
    - Can you update with the advice on https://pymc-data-umbrella.xyz/en/latest/sprint/docstring_tutorial.html#edit-the-docstring and I'll take a 2nd look later?
    - I must be missing something but I think I'm style compliant. I actually copied the docstrings from
    - Basically none of the docstrings are currently style compliant, and most aren't rendered correctly either because of that. We have an issue open for this: #5459. I can give a quick go at one of the docstrings. I think you'll also need to rebase on main for the docs preview to work and CI to pass.
    - added some comments, but they are not exhaustive
    - Oh ok, it was how I'm specifying defaults and optional.

+    Parameters
+    ----------
+    draws : int, default 1000
+        The number of samples to draw. Tuned samples are discarded by default.
+    tune : int, default 1000
+        Number of iterations to tune. Samplers adjust the step sizes, scalings or
+        similar during tuning. Tuning samples will be drawn in addition to the number
+        specified in the ``draws`` argument.
+    chains : int, default 4
+        The number of chains to sample.
+    target_accept : float in [0, 1]
+        The step size is tuned such that we approximate this acceptance rate. Higher
+        values like 0.9 or 0.95 often work better for problematic posteriors.
+    random_seed : int, default 10
+        Random seed used by the sampling steps.
+    model : Model, optional
+        Model to sample from. The model needs to have free random variables. When inside
+        a ``with`` model context, it defaults to that model; otherwise the model must be
+        passed explicitly.
+    var_names : iterable of str, optional
+        Names of variables for which to compute the posterior samples. Defaults to all
+        variables in the posterior.
+    progress_bar : bool, default True
+        Whether or not to display a progress bar in the command line. The bar shows the
+        percentage of completion, the sampling speed in samples per second (SPS), and the
+        estimated remaining time until completion ("expected time of arrival"; ETA).
+    keep_untransformed : bool, default False
+        Include untransformed variables in the posterior samples.
+    chain_method : str, default "parallel"
+        Specify how samples should be drawn. The choices include "parallel" and "vectorized".
+    idata_kwargs : dict, optional
+        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
+        for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
+        not be included in the returned object.
+
+    Returns
+    -------
+    InferenceData
+        ArviZ ``InferenceData`` object that contains the posterior samples, together with
+        their respective sample stats and pointwise log likelihood values (unless skipped
+        with ``idata_kwargs``).
+    """
+    import blackjax
+
+    model = modelcontext(model)
+
+    if var_names is None:
+        var_names = model.unobserved_value_vars
+
+    vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
+
+    coords = {
+        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
+        for cname, cvals in model.coords.items()
+        if cvals is not None
+    }
+
+    if hasattr(model, "RV_dims"):
+        dims = {
+            var_name: [dim for dim in dims if dim is not None]
+            for var_name, dims in model.RV_dims.items()
+        }
+    else:
+        dims = {}
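The coords/dims plumbing above is plain dictionary filtering: tuple-valued coordinates are converted to arrays and `None` entries are dropped. A stdlib-only sketch with hypothetical `model_coords`/`rv_dims` inputs (using `list` in place of `np.array`):

```python
# Hypothetical stand-ins for model.coords and model.RV_dims.
model_coords = {"school": ("A", "B", "C"), "obs_id": [0, 1, 2], "unused": None}
rv_dims = {"theta": ["school", None], "y": ["obs_id"]}

# Tuples become mutable sequences usable as coordinates; None entries are dropped.
coords = {
    cname: list(cvals) if isinstance(cvals, tuple) else cvals
    for cname, cvals in model_coords.items()
    if cvals is not None
}

# None placeholders inside each variable's dims list are filtered out.
dims = {var_name: [d for d in ds if d is not None] for var_name, ds in rv_dims.items()}

print(coords)  # {'school': ['A', 'B', 'C'], 'obs_id': [0, 1, 2]}
print(dims)    # {'theta': ['school'], 'y': ['obs_id']}
```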
+
+    tic1 = pd.Timestamp.now()
+    print("Compiling...", file=sys.stdout)
+
+    rv_names = [rv.name for rv in model.value_vars]
+    initial_point = model.compute_initial_point()
+    init_state = [initial_point[rv_name] for rv_name in rv_names]
+    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
+
+    logprob_fn = get_jaxified_logp(model)
+
+    seed = jax.random.PRNGKey(random_seed)
+    keys = jax.random.split(seed, chains)
+
+    get_posterior_samples = partial(
+        _blackjax_inference_loop,
+        logprob_fn=logprob_fn,
+        tune=tune,
+        draws=draws,
+        target_accept=target_accept,
+    )
+
+    tic2 = pd.Timestamp.now()
+    print("Compilation time = ", tic2 - tic1, file=sys.stdout)
+
+    print("Sampling...", file=sys.stdout)
+
+    # Adapted from numpyro
+    if chain_method == "parallel":
+        map_fn = jax.pmap
+    elif chain_method == "vectorized":
+        map_fn = jax.vmap
+    else:
+        raise ValueError(
+            "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
+        )
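The `chain_method` branch is a string-to-function dispatch, with an error for anything else. A toy sketch of the same shape (hypothetical `pick_map_fn`, with trivial stand-ins for `jax.pmap`/`jax.vmap`, which differ only in execution strategy, not in result):

```python
def pick_map_fn(chain_method):
    # Toy stand-ins: both map a function over a batch; in JAX the
    # "parallel" one runs on multiple devices, the "vectorized" one fuses.
    map_parallel = lambda f: (lambda xs: [f(x) for x in xs])
    map_vectorized = lambda f: (lambda xs: [f(x) for x in xs])
    if chain_method == "parallel":
        return map_parallel
    elif chain_method == "vectorized":
        return map_vectorized
    raise ValueError(
        'Only supporting the following methods to draw chains: "parallel" or "vectorized"'
    )

double_all = pick_map_fn("parallel")(lambda x: 2 * x)
print(double_all([1, 2, 3]))  # [2, 4, 6]
```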
+
+    states, _ = map_fn(get_posterior_samples)(keys, init_state_batched)
+    raw_mcmc_samples = states.position
+
+    tic3 = pd.Timestamp.now()
+    print("Sampling time = ", tic3 - tic2, file=sys.stdout)
+
+    print("Transforming variables...", file=sys.stdout)
+    mcmc_samples = {}
+    for v in vars_to_sample:
+        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
+        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
+        mcmc_samples[v.name] = result
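The `jax.vmap(jax.vmap(jax_fn))` composition maps the transform over both batch axes: chains on the outside, draws on the inside. A stdlib emulation with a toy `vmap` (lists in place of arrays) makes the nesting concrete:

```python
def vmap(f):
    # Toy stand-in for jax.vmap: map f over the leading axis (a list).
    return lambda xs: [f(x) for x in xs]

# Raw samples laid out as [chains][draws]: 2 chains x 3 draws.
raw = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]

# Stand-in for the jaxified deterministic graph applied to one draw.
transform = lambda x: x * 10

result = vmap(vmap(transform))(raw)
print(result)  # [[0.0, 10.0, 20.0], [30.0, 40.0, 50.0]]
```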
+
+    tic4 = pd.Timestamp.now()
+    print("Transformation time = ", tic4 - tic3, file=sys.stdout)
+
+    if idata_kwargs is None:
+        idata_kwargs = {}
+    else:
+        idata_kwargs = idata_kwargs.copy()
+
+    if idata_kwargs.pop("log_likelihood", True):
+        log_likelihood = _get_log_likelihood(model, raw_mcmc_samples)
+    else:
+        log_likelihood = None
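The `idata_kwargs` handling copies the dict before popping the special `log_likelihood` key, so the caller's dict is not mutated and the remaining kwargs can be forwarded to `arviz.from_dict`. A stdlib sketch of that pattern (hypothetical `split_log_likelihood` helper):

```python
def split_log_likelihood(idata_kwargs):
    # Copy first so the caller's dict is left intact, then pop the key:
    # a truthy (or absent) value means "compute the pointwise log likelihood".
    if idata_kwargs is None:
        idata_kwargs = {}
    else:
        idata_kwargs = idata_kwargs.copy()
    want_loglik = idata_kwargs.pop("log_likelihood", True)
    return want_loglik, idata_kwargs

print(split_log_likelihood(None))                       # (True, {})
print(split_log_likelihood({"log_likelihood": False}))  # (False, {})

kwargs = {"log_likelihood": False, "dims": {}}
print(split_log_likelihood(kwargs))                     # (False, {'dims': {}})
print(kwargs)                                           # caller's dict unchanged
```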
+
+    attrs = {
+        "sampling_time": (tic3 - tic2).total_seconds(),
+    }
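The `sampling_time` attribute is just the wall-clock delta converted to seconds; `pd.Timestamp` differences behave like stdlib `timedelta`s here, so the computation can be checked with `datetime` alone (timestamps below are made up for illustration):

```python
from datetime import datetime, timedelta

tic2 = datetime(2022, 2, 1, 12, 0, 0)
tic3 = tic2 + timedelta(minutes=1, seconds=30)  # pretend sampling took 1m30s

# Same expression as the diff: a timedelta reduced to a float of seconds.
attrs = {"sampling_time": (tic3 - tic2).total_seconds()}
print(attrs)  # {'sampling_time': 90.0}
```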
+
+    posterior = mcmc_samples
+    az_trace = az.from_dict(
+        posterior=posterior,
+        log_likelihood=log_likelihood,
+        observed_data=find_observations(model),
+        coords=coords,
+        dims=dims,
+        attrs=make_attrs(attrs, library=blackjax),
+        **idata_kwargs,
+    )
+
+    return az_trace
 def sample_numpyro_nuts(
     draws=1000,
     tune=1000,

@@ -151,6 +347,51 @@ def sample_numpyro_nuts(
     chain_method="parallel",
     idata_kwargs=None,
 ):
+    """
+    Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
+
+    Parameters
+    ----------
+    draws : int, default 1000
+        The number of samples to draw. Tuned samples are discarded by default.
+    tune : int, default 1000
+        Number of iterations to tune. Samplers adjust the step sizes, scalings or
+        similar during tuning. Tuning samples will be drawn in addition to the number
+        specified in the ``draws`` argument.
+    chains : int, default 4
+        The number of chains to sample.
+    target_accept : float in [0, 1]
+        The step size is tuned such that we approximate this acceptance rate. Higher
+        values like 0.9 or 0.95 often work better for problematic posteriors.
+    random_seed : int, default 10
+        Random seed used by the sampling steps.
+    model : Model, optional
+        Model to sample from. The model needs to have free random variables. When inside
+        a ``with`` model context, it defaults to that model; otherwise the model must be
+        passed explicitly.
+    var_names : iterable of str, optional
+        Names of variables for which to compute the posterior samples. Defaults to all
+        variables in the posterior.
+    progress_bar : bool, default True
+        Whether or not to display a progress bar in the command line. The bar shows the
+        percentage of completion, the sampling speed in samples per second (SPS), and the
+        estimated remaining time until completion ("expected time of arrival"; ETA).
+    keep_untransformed : bool, default False
+        Include untransformed variables in the posterior samples.
+    chain_method : str, default "parallel"
+        Specify how samples should be drawn. The choices include "sequential", "parallel",
+        and "vectorized".
+    idata_kwargs : dict, optional
+        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
+        for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
+        not be included in the returned object.
+
+    Returns
+    -------
+    InferenceData
+        ArviZ ``InferenceData`` object that contains the posterior samples, together with
+        their respective sample stats and pointwise log likelihood values (unless skipped
+        with ``idata_kwargs``).
+    """
+
+    import numpyro
+
+    from numpyro.infer import MCMC, NUTS

     model = modelcontext(model)
@@ -182,7 +423,7 @@ def sample_numpyro_nuts(
     init_state = [initial_point[rv_name] for rv_name in rv_names]
     init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

-    logp_fn = get_jaxified_logp(model)
+    logp_fn = get_jaxified_logp(model, negative_logp=False)

     nuts_kernel = NUTS(
         potential_fn=logp_fn,
@@ -254,6 +495,10 @@ def sample_numpyro_nuts(
     else:
         log_likelihood = None

+    attrs = {
+        "sampling_time": (tic3 - tic2).total_seconds(),
+    }
+
     posterior = mcmc_samples
     az_trace = az.from_dict(
         posterior=posterior,
@@ -262,7 +507,7 @@ def sample_numpyro_nuts(
         sample_stats=_sample_stats_to_xarray(pmap_numpyro),
         coords=coords,
         dims=dims,
-        attrs={"sampling_time": (tic3 - tic2).total_seconds()},
+        attrs=make_attrs(attrs, library=numpyro),
         **idata_kwargs,
     )
[Review thread]
- I am not sure we should add jax as a dependency. By that logic we should also add numpyro and blackjax... The jaxtests workflow installs it manually.
- Oh, this was for documentation... Yeah, it seems like we need to have an environment for docs... feel free to ignore my comment then.
- I don't think you need them in the ``dev`` environments, only for ``test`` (until we add a docs env).
- We need them for the docs to render properly.
- Are the docs using the ``dev`` environments?
- Yes, the Python 3.8 dev one, just out of convenience I think. We should just create a specific environment for the docs...
- I see, shouldn't we have it use the ``test`` env instead? I guess if we start a docs env it doesn't really matter, but maybe they would end up identical?
- I think the docs require quite a few more dependencies than the tests do.