Commit 5ff9bbc

Add experimental JAX samplers (#4247)
* Add JAX NUTS samplers from TFP and numpyro. With @junpenglao.
* Add missing import.
* Remove JAX as default linker.
* Add experimental warning and clean up imports.
* Add JAX nb.
* Add NB to toc.
* Black and isort.
* Change title.
* Remove comma.
* Typo.
* nbqa NB.
* Run pre-commit.
* Disable pylint.
* Add to release-notes.
1 parent 22c079c commit 5ff9bbc

File tree

4 files changed: +570 −1 lines

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ pip install theano-pymc
 This new version of `Theano-PyMC` comes with an experimental JAX backend which, when combined with the new and experimental JAX samplers in PyMC3, can greatly speed up sampling in your model. As this is still very new, please do not use it in production yet, but do test it out and let us know if anything breaks and what results you are seeing, especially speed-wise.
 
 ### New features
+- New experimental JAX samplers in `pymc3.sampling_jax` (see [notebook](https://docs.pymc.io/notebooks/GLM-hierarchical-jax.html) and [#4247](https://github.com/pymc-devs/pymc3/pull/4247)). Requires JAX and either TFP or numpyro.
 - Add MLDA, a new stepper for multilevel sampling. MLDA can be used when a hierarchy of approximate posteriors of varying accuracy is available, offering improved sampling efficiency especially in high-dimensional problems and/or where gradients are not available (see [#3926](https://github.com/pymc-devs/pymc3/pull/3926)).
 - Add Bayesian Additive Regression Trees (BARTs) (see [#4183](https://github.com/pymc-devs/pymc3/pull/4183)).
 - Added `pymc3.gp.cov.Circular` kernel for Gaussian Processes on circular domains, e.g. the unit circle (see [#4082](https://github.com/pymc-devs/pymc3/pull/4082)).
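For orientation, a minimal usage sketch of the new samplers (the toy model here is hypothetical; `sample_numpyro_nuts` is defined in `pymc3/sampling_jax.py` below and requires JAX plus numpyro):

import numpy as np
import pymc3 as pm
from pymc3 import sampling_jax

data = np.random.randn(100)
with pm.Model():
    mu = pm.Normal("mu", mu=0.0, sigma=1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=data)
    # Returns an arviz.InferenceData built from the posterior draws.
    trace = sampling_jax.sample_numpyro_nuts(draws=1000, tune=1000, chains=4)

`sample_tfp_nuts` is called the same way, given TFP's JAX substrate installed instead of numpyro.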

docs/source/notebooks/GLM-hierarchical-jax.ipynb

Lines changed: 384 additions & 0 deletions

docs/source/notebooks/table_of_contents_examples.js

Lines changed: 2 additions & 1 deletion
@@ -64,5 +64,6 @@ Gallery.contents = {
   "MLDA_introduction": "MCMC",
   "MLDA_simple_linear_regression": "MCMC",
   "MLDA_gravity_surveying": "MCMC",
-  "MLDA_variance_reduction_linear_regression": "MCMC"
+  "MLDA_variance_reduction_linear_regression": "MCMC",
+  "GLM-hierarchical-jax": "MCMC"
 }

pymc3/sampling_jax.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
# pylint: skip-file
import os
import re
import warnings

# Ask XLA to expose many virtual host (CPU) devices so that jax.pmap can run
# one chain per device; any pre-existing device-count flag is replaced, while
# other XLA flags are preserved.
xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)] + xla_flags)

import arviz as az
import jax
import numpy as np
import pandas as pd
import theano
import theano.sandbox.jax_linker
import theano.sandbox.jaxify

import pymc3 as pm

from pymc3 import modelcontext

warnings.warn("This module is experimental.")

# Disable C compilation by default
# theano.config.cxx = ""
# This will make the JAX Linker the default
# theano.config.mode = "JAX"
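Why the device-count flag matters: JAX exposes a single host device by default, so `jax.pmap` over more than one chain would fail on CPU. Raising the virtual device count up front lets the pmap-ed samplers below run chains in parallel. A quick sanity check, assuming a CPU-only host (the flag must be set before JAX is first imported; the exact device repr varies by JAX version):

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax  # must come after XLA_FLAGS is set

print(jax.device_count())   # 8
print(jax.local_devices())  # [CpuDevice(id=0), ..., CpuDevice(id=7)]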
def sample_tfp_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    num_tuning_epoch=2,
    num_compute_step_size=500,
):
    from tensorflow_probability.substrates import jax as tfp

    model = modelcontext(model)

    seed = jax.random.PRNGKey(random_seed)

    # Compile the model's log-posterior into a JAX function via the
    # Theano-PyMC JAX backend.
    fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
    fns = theano.sandbox.jaxify.jax_funcify(fgraph)
    logp_fn_jax = fns[0]

    rv_names = [rv.name for rv in model.free_RVs]
    init_state = [model.test_point[rv_name] for rv_name in rv_names]
    # Stack one copy of the initial point per chain along a new leading axis,
    # which jax.pmap maps over.
    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

    @jax.pmap
    def _sample(init_state, seed):
        def gen_kernel(step_size):
            hmc = tfp.mcmc.NoUTurnSampler(target_log_prob_fn=logp_fn_jax, step_size=step_size)
            return tfp.mcmc.DualAveragingStepSizeAdaptation(
                hmc, tune // num_tuning_epoch, target_accept_prob=target_accept
            )

        def trace_fn(_, pkr):
            return pkr.new_step_size

        def get_tuned_stepsize(samples, step_size):
            return step_size[-1] * jax.numpy.std(samples[-num_compute_step_size:])

        # Tune the step size over the first num_tuning_epoch - 1 warmup epochs.
        step_size = jax.tree_map(jax.numpy.ones_like, init_state)
        for i in range(num_tuning_epoch - 1):
            tuning_hmc = gen_kernel(step_size)
            init_samples, tuning_result, kernel_results = tfp.mcmc.sample_chain(
                num_results=tune // num_tuning_epoch,
                current_state=init_state,
                kernel=tuning_hmc,
                trace_fn=trace_fn,
                return_final_kernel_results=True,
                seed=seed,
            )

            step_size = jax.tree_multimap(get_tuned_stepsize, list(init_samples), tuning_result)
            init_state = [x[-1] for x in init_samples]

        # Run inference
        sample_kernel = gen_kernel(step_size)
        mcmc_samples, leapfrog_num = tfp.mcmc.sample_chain(
            num_results=draws,
            num_burnin_steps=tune // num_tuning_epoch,
            current_state=init_state,
            kernel=sample_kernel,
            trace_fn=lambda _, pkr: pkr.inner_results.leapfrogs_taken,
            seed=seed,
        )

        return mcmc_samples, leapfrog_num

    print("Compiling...")
    tic2 = pd.Timestamp.now()
    map_seed = jax.random.split(seed, chains)
    mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed)
    tic3 = pd.Timestamp.now()
    print("Compilation + sampling time = ", tic3 - tic2)

    # map_seed = jax.random.split(seed, chains)
    # mcmc_samples = _sample(init_state_batched, map_seed)
    # tic4 = pd.Timestamp.now()
    # print("Sampling time = ", tic4 - tic3)

    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}

    az_trace = az.from_dict(posterior=posterior)
    return az_trace  # , leapfrog_num, tic3 - tic2
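The batching convention here is worth spelling out: `_sample` is written for a single chain, and `jax.pmap` maps it over the leading axis of `init_state_batched` and `map_seed`, running one chain per XLA device. A stripped-down sketch of the same pattern (the toy `one_chain` function is a stand-in, not the sampler):

import os
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=2")  # ensure >= 2 host devices

import jax
import numpy as np

def one_chain(state, seed):
    # Stand-in for a per-chain computation such as one NUTS trajectory.
    return state + jax.random.normal(seed, state.shape)

chains = 2  # must not exceed jax.device_count()
init_state = np.zeros(3)
init_state_batched = np.repeat(init_state[None, ...], chains, axis=0)  # shape (chains, 3)
seeds = jax.random.split(jax.random.PRNGKey(0), chains)                # shape (chains, 2)

samples = jax.pmap(one_chain)(init_state_batched, seeds)               # shape (chains, 3)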
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    progress_bar=True,
):
    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    seed = jax.random.PRNGKey(random_seed)

    # Compile the model's log-posterior into a JAX function via the
    # Theano-PyMC JAX backend.
    fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
    fns = theano.sandbox.jaxify.jax_funcify(fgraph)
    logp_fn_jax = fns[0]

    rv_names = [rv.name for rv in model.free_RVs]
    init_state = [model.test_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

    @jax.jit
    def _sample(current_state, seed):
        # numpyro expects a potential function (negative log posterior),
        # hence the sign flip.
        nuts_kernel = NUTS(
            potential_fn=lambda x: -logp_fn_jax(*x),
            # model=model,
            target_accept_prob=target_accept,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
        )

        pmap_numpyro = MCMC(
            nuts_kernel,
            num_warmup=tune,
            num_samples=draws,
            num_chains=chains,
            postprocess_fn=None,
            chain_method="parallel",
            progress_bar=progress_bar,
        )

        pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
        samples = pmap_numpyro.get_samples(group_by_chain=True)
        leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
        return samples, leapfrogs_taken

    print("Compiling...")
    tic2 = pd.Timestamp.now()
    map_seed = jax.random.split(seed, chains)
    mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed)
    tic3 = pd.Timestamp.now()
    print("Compilation + sampling time = ", tic3 - tic2)

    # map_seed = jax.random.split(seed, chains)
    # mcmc_samples = _sample(init_state_batched, map_seed)
    # tic4 = pd.Timestamp.now()
    # print("Sampling time = ", tic4 - tic3)

    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}

    az_trace = az.from_dict(posterior=posterior)
    return az_trace  # , leapfrogs_taken, tic3 - tic2
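One numpyro-specific detail deserves a note: numpyro's NUTS takes a potential function, i.e. the negative log posterior, which is why the wrapper above negates `logp_fn_jax`. A self-contained sketch of that convention on a standard normal (assumes only numpyro and JAX; no PyMC3 involved):

import jax
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS

def potential_fn(params):
    # Negative log density of a standard normal, up to an additive constant.
    z = params[0]
    return 0.5 * jnp.sum(z ** 2)

kernel = NUTS(potential_fn=potential_fn)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1, progress_bar=False)
mcmc.run(jax.random.PRNGKey(0), init_params=[jnp.zeros(3)])
samples = mcmc.get_samples()  # list mirroring init_params; samples[0] has shape (500, 3)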
