Commit 04e1271

brandonwillard authored and twiecki committed
Create a NumPyro sampler Op for better JAX backend integration
1 parent de74ff6 commit 04e1271

2 files changed (+182 / -158 lines)

pymc3/sampling_jax.py

Lines changed: 162 additions & 152 deletions
@@ -3,150 +3,133 @@
 import re
 import warnings
 
-from collections import defaultdict
-
 xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
 xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
 os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)])
 
 import aesara.graph.fg
+import aesara.tensor as at
 import arviz as az
 import jax
 import numpy as np
 import pandas as pd
 
-from aesara.link.jax.jax_dispatch import jax_funcify
-
-import pymc3 as pm
+from aesara.compile import SharedVariable
+from aesara.graph.basic import Apply, Constant, clone, graph_inputs
+from aesara.graph.fg import FunctionGraph
+from aesara.graph.op import Op
+from aesara.graph.opt import MergeOptimizer
+from aesara.link.jax.dispatch import jax_funcify
+from aesara.tensor.type import TensorType
 
 from pymc3 import modelcontext
 
 warnings.warn("This module is experimental.")
 
-# Disable C compilation by default
-# aesara.config.cxx = ""
-# This will make the JAX Linker the default
-# aesara.config.mode = "JAX"
 
+class NumPyroNUTS(Op):
+    def __init__(
+        self,
+        inputs,
+        outputs,
+        target_accept=0.8,
+        draws=1000,
+        tune=1000,
+        chains=4,
+        seed=None,
+        progress_bar=True,
+    ):
+        self.draws = draws
+        self.tune = tune
+        self.chains = chains
+        self.target_accept = target_accept
+        self.progress_bar = progress_bar
+        self.seed = seed
 
-def sample_tfp_nuts(
-    draws=1000,
-    tune=1000,
-    chains=4,
-    target_accept=0.8,
-    random_seed=10,
-    model=None,
-    num_tuning_epoch=2,
-    num_compute_step_size=500,
-):
-    import jax
+        self.inputs, self.outputs = clone(inputs, outputs, copy_inputs=False)
+        self.inputs_type = tuple([input.type for input in inputs])
+        self.outputs_type = tuple([output.type for output in outputs])
+        self.nin = len(inputs)
+        self.nout = len(outputs)
+        self.nshared = len([v for v in inputs if isinstance(v, SharedVariable)])
+        self.samples_bcast = [self.chains == 1, self.draws == 1]
 
-    from tensorflow_probability.substrates import jax as tfp
+        self.fgraph = FunctionGraph(self.inputs, self.outputs, clone=False)
+        MergeOptimizer().optimize(self.fgraph)
 
-    model = modelcontext(model)
+        super().__init__()
 
-    seed = jax.random.PRNGKey(random_seed)
+    def make_node(self, *inputs):
 
-    fgraph = model.logp.f.maker.fgraph
-    fns = jax_funcify(fgraph)
-    logp_fn_jax = fns[0]
+        # The samples for each variable
+        outputs = [
+            TensorType(v.dtype, self.samples_bcast + list(v.broadcastable))() for v in inputs
+        ]
 
-    rv_names = [rv.name for rv in model.free_RVs]
-    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
-    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
+        # The leapfrog statistics
+        outputs += [TensorType("int64", self.samples_bcast)()]
 
-    @jax.pmap
-    def _sample(init_state, seed):
-        def gen_kernel(step_size):
-            hmc = tfp.mcmc.NoUTurnSampler(target_log_prob_fn=logp_fn_jax, step_size=step_size)
-            return tfp.mcmc.DualAveragingStepSizeAdaptation(
-                hmc, tune // num_tuning_epoch, target_accept_prob=target_accept
-            )
+        all_inputs = list(inputs)
+        if self.nshared > 0:
+            all_inputs += self.inputs[-self.nshared :]
 
-        def trace_fn(_, pkr):
-            return pkr.new_step_size
-
-        def get_tuned_stepsize(samples, step_size):
-            return step_size[-1] * jax.numpy.std(samples[-num_compute_step_size:])
-
-        step_size = jax.tree_map(jax.numpy.ones_like, init_state)
-        for i in range(num_tuning_epoch - 1):
-            tuning_hmc = gen_kernel(step_size)
-            init_samples, tuning_result, kernel_results = tfp.mcmc.sample_chain(
-                num_results=tune // num_tuning_epoch,
-                current_state=init_state,
-                kernel=tuning_hmc,
-                trace_fn=trace_fn,
-                return_final_kernel_results=True,
-                seed=seed,
-            )
+        return Apply(self, all_inputs, outputs)
 
-        step_size = jax.tree_multimap(get_tuned_stepsize, list(init_samples), tuning_result)
-        init_state = [x[-1] for x in init_samples]
-
-        # Run inference
-        sample_kernel = gen_kernel(step_size)
-        mcmc_samples, leapfrog_num = tfp.mcmc.sample_chain(
-            num_results=draws,
-            num_burnin_steps=tune // num_tuning_epoch,
-            current_state=init_state,
-            kernel=sample_kernel,
-            trace_fn=lambda _, pkr: pkr.inner_results.leapfrogs_taken,
-            seed=seed,
-        )
+    def do_constant_folding(self, *args):
+        return False
 
-        return mcmc_samples, leapfrog_num
+    def perform(self, node, inputs, outputs):
+        raise NotImplementedError()
 
-    print("Compiling...")
-    tic2 = pd.Timestamp.now()
-    map_seed = jax.random.split(seed, chains)
-    mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed)
-
-    # map_seed = jax.random.split(seed, chains)
-    # mcmc_samples = _sample(init_state_batched, map_seed)
-    # tic4 = pd.Timestamp.now()
-    # print("Sampling time = ", tic4 - tic3)
-
-    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
 
-    az_trace = az.from_dict(posterior=posterior)
-    tic3 = pd.Timestamp.now()
-    print("Compilation + sampling time = ", tic3 - tic2)
-    return az_trace  # , leapfrog_num, tic3 - tic2
-
-
-def sample_numpyro_nuts(
-    draws=1000,
-    tune=1000,
-    chains=4,
-    target_accept=0.8,
-    random_seed=10,
-    model=None,
-    progress_bar=True,
-    keep_untransformed=False,
-):
+@jax_funcify.register(NumPyroNUTS)
+def jax_funcify_NumPyroNUTS(op, node, **kwargs):
     from numpyro.infer import MCMC, NUTS
 
-    from pymc3 import modelcontext
+    draws = op.draws
+    tune = op.tune
+    chains = op.chains
+    target_accept = op.target_accept
+    progress_bar = op.progress_bar
+    seed = op.seed
+
+    # Compile the "inner" log-likelihood function. This will have extra shared
+    # variable inputs as the last arguments
+    logp_fn = jax_funcify(op.fgraph, **kwargs)
+
+    if isinstance(logp_fn, (list, tuple)):
+        # This handles the new JAX backend, which always returns a tuple
+        logp_fn = logp_fn[0]
+
+    def _sample(*inputs):
+
+        if op.nshared > 0:
+            current_state = inputs[: -op.nshared]
+            shared_inputs = tuple(op.fgraph.inputs[-op.nshared :])
+        else:
+            current_state = inputs
+            shared_inputs = ()
+
+        def log_fn_wrap(x):
+            res = logp_fn(
+                *(
+                    x
+                    # We manually obtain the shared values and added them
+                    # as arguments to our compiled "inner" function
+                    + tuple(
+                        v.get_value(borrow=True, return_internal_type=True) for v in shared_inputs
+                    )
+                )
+            )
 
-    model = modelcontext(model)
+            if isinstance(res, (list, tuple)):
+                # This handles the new JAX backend, which always returns a tuple
+                res = res[0]
 
-    seed = jax.random.PRNGKey(random_seed)
+            return -res
 
-    fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
-    fns = jax_funcify(fgraph)
-    logp_fn_jax = fns[0]
-
-    rv_names = [rv.name for rv in model.free_RVs]
-    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
-    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
-
-    @jax.jit
-    def _sample(current_state, seed):
-        step_size = jax.tree_map(jax.numpy.ones_like, init_state)
         nuts_kernel = NUTS(
-            potential_fn=lambda x: -logp_fn_jax(*x),
-            # model=model,
+            potential_fn=log_fn_wrap,
             target_accept_prob=target_accept,
             adapt_step_size=True,
             adapt_mass_matrix=True,
@@ -166,60 +149,87 @@ def _sample(current_state, seed):
         pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
         samples = pmap_numpyro.get_samples(group_by_chain=True)
         leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
-        return samples, leapfrogs_taken
-
-    print("Compiling...")
-    tic2 = pd.Timestamp.now()
-    map_seed = jax.random.split(seed, chains)
-    mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed)
-    # map_seed = jax.random.split(seed, chains)
-    # mcmc_samples = _sample(init_state_batched, map_seed)
-    # tic4 = pd.Timestamp.now()
-    # print("Sampling time = ", tic4 - tic3)
+        return tuple(samples) + (leapfrogs_taken,)
 
-    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
-    tic3 = pd.Timestamp.now()
-    posterior = _transform_samples(posterior, model, keep_untransformed=keep_untransformed)
-    tic4 = pd.Timestamp.now()
+    return _sample
 
-    az_trace = az.from_dict(posterior=posterior)
-    print("Compilation + sampling time = ", tic3 - tic2)
-    print("Transformation time = ", tic4 - tic3)
 
-    return az_trace  # , leapfrogs_taken, tic3 - tic2
+def sample_numpyro_nuts(
+    draws=1000,
+    tune=1000,
+    chains=4,
+    target_accept=0.8,
+    random_seed=10,
+    model=None,
+    progress_bar=True,
+    keep_untransformed=False,
+):
+    model = modelcontext(model)
 
+    seed = jax.random.PRNGKey(random_seed)
 
-def _transform_samples(samples, model, keep_untransformed=False):
+    rv_names = [rv.name for rv in model.value_vars]
+    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
+    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
+    init_state_batched_at = [at.as_tensor(v) for v in init_state_batched]
 
-    # Find out which RVs we need to compute:
-    free_rv_names = {x.name for x in model.free_RVs}
-    unobserved_names = {x.name for x in model.unobserved_RVs}
+    nuts_inputs = sorted(
+        [v for v in graph_inputs([model.logpt]) if not isinstance(v, Constant)],
+        key=lambda x: isinstance(x, SharedVariable),
+    )
+    map_seed = jax.random.split(seed, chains)
+    numpyro_samples = NumPyroNUTS(
+        nuts_inputs,
+        [model.logpt],
+        target_accept=target_accept,
+        draws=draws,
+        tune=tune,
+        chains=chains,
+        seed=map_seed,
+        progress_bar=progress_bar,
+    )(*init_state_batched_at)
+
+    # Un-transform the transformed variables in JAX
+    sample_outputs = []
+    for i, (value_var, rv_samples) in enumerate(zip(model.value_vars, numpyro_samples[:-1])):
+        rv = model.values_to_rvs[value_var]
+        transform = getattr(value_var.tag, "transform", None)
+        if transform is not None:
+            untrans_value_var = transform.backward(rv, rv_samples)
+            untrans_value_var.name = rv.name
+            sample_outputs.append(untrans_value_var)
+
+            if keep_untransformed:
+                rv_samples.name = value_var.name
+                sample_outputs.append(rv_samples)
+        else:
+            rv_samples.name = rv.name
+            sample_outputs.append(rv_samples)
 
-    names_to_compute = unobserved_names - free_rv_names
-    ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute]
+    print("Compiling...")
 
-    # Create function graph for these:
-    fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)
+    tic1 = pd.Timestamp.now()
+    _sample = aesara.function(
+        [],
+        sample_outputs + [numpyro_samples[-1]],
+        allow_input_downcast=True,
+        on_unused_input="ignore",
+        accept_inplace=True,
+        mode="JAX",
+    )
+    tic2 = pd.Timestamp.now()
 
-    # Jaxify, which returns a list of functions, one for each op
-    jax_fns = jax_funcify(fgraph)
+    print("Compilation time = ", tic2 - tic1)
 
-    # Put together the inputs
-    inputs = [samples[x.name] for x in model.free_RVs]
+    print("Sampling...")
 
-    for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns):
+    *mcmc_samples, leapfrogs_taken = _sample()
+    tic3 = pd.Timestamp.now()
 
-        # We need a function taking a single argument to run vmap, while the
-        # jax_fn takes a list, so:
-        result = jax.vmap(jax.vmap(cur_jax_fn))(*inputs)
+    print("Sampling time = ", tic3 - tic2)
 
-        # Add to sample dict
-        samples[cur_op.name] = result
+    posterior = {k.name: v for k, v in zip(sample_outputs, mcmc_samples)}
 
-    # Discard unwanted transformed variables, if desired:
-    vars_to_keep = set(
-        pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed)
-    )
-    samples = {x: y for x, y in samples.items() if x in vars_to_keep}
+    az_trace = az.from_dict(posterior=posterior)
 
-    return samples
+    return az_trace
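For context, a minimal usage sketch of the new sample_numpyro_nuts entry point (not part of this commit; the toy model, variable names, and data below are purely illustrative and assume a PyMC3 build that ships this module together with Aesara's JAX backend):

import numpy as np
import pymc3 as pm
from pymc3 import sampling_jax

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=np.random.randn(100))

    # Builds a NumPyroNUTS Op over model.logpt, compiles the whole graph with
    # mode="JAX" (sampling plus back-transformation of transformed variables),
    # and returns an ArviZ trace built from the posterior samples.
    trace = sampling_jax.sample_numpyro_nuts(draws=1000, tune=1000, chains=4)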
