
Commit fe2d101

zaxtax and ricardoV94 authored
Adding log_likelihood, observed_data, and sample_stats to numpyro sampler (#5189)
* Adding observed_data and sample_stats to numpyro sampler
* Refactor find_observations
* Add log likelihoods to trace object

Co-authored-by: Ricardo Vieira <[email protected]>
1 parent c22859d commit fe2d101

File tree

3 files changed: +98, -23 lines

pymc/backends/arviz.py

Lines changed: 21 additions & 19 deletions
@@ -42,6 +42,26 @@
 Var = Any  # pylint: disable=invalid-name


+def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
+    """If there are observations available, return them as a dictionary."""
+    if model is None:
+        return None
+
+    observations = {}
+    for obs in model.observed_RVs:
+        aux_obs = getattr(obs.tag, "observations", None)
+        if aux_obs is not None:
+            try:
+                obs_data = extract_obs_data(aux_obs)
+                observations[obs.name] = obs_data
+            except TypeError:
+                warnings.warn(f"Could not extract data from symbolic observation {obs}")
+        else:
+            warnings.warn(f"No data for observation {obs}")
+
+    return observations
+
+
 class _DefaultTrace:
     """
     Utility for collecting samples into a dictionary.
@@ -196,25 +216,7 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
         self.dims = {**model_dims, **self.dims}

         self.density_dist_obs = density_dist_obs
-        self.observations = self.find_observations()
-
-    def find_observations(self) -> Optional[Dict[str, Var]]:
-        """If there are observations available, return them as a dictionary."""
-        if self.model is None:
-            return None
-        observations = {}
-        for obs in self.model.observed_RVs:
-            aux_obs = getattr(obs.tag, "observations", None)
-            if aux_obs is not None:
-                try:
-                    obs_data = extract_obs_data(aux_obs)
-                    observations[obs.name] = obs_data
-                except TypeError:
-                    warnings.warn(f"Could not extract data from symbolic observation {obs}")
-            else:
-                warnings.warn(f"No data for observation {obs}")
-
-        return observations
+        self.observations = find_observations(self.model)

     def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
         """Split MultiTrace object into posterior and warmup.

pymc/sampling_jax.py

Lines changed: 58 additions & 4 deletions
@@ -26,7 +26,9 @@
 from aesara.link.jax.dispatch import jax_funcify

 from pymc import Model, modelcontext
-from pymc.aesaraf import compile_rv_inplace, inputvars
+from pymc.aesaraf import compile_rv_inplace
+from pymc.backends.arviz import find_observations
+from pymc.distributions import logpt
 from pymc.util import get_default_varnames

 warnings.warn("This module is experimental.")
@@ -95,6 +97,39 @@ def logp_fn_wrap(x):
     return logp_fn_wrap


+# Adopted from arviz numpyro extractor
+def _sample_stats_to_xarray(posterior):
+    """Extract sample_stats from NumPyro posterior."""
+    rename_key = {
+        "potential_energy": "lp",
+        "adapt_state.step_size": "step_size",
+        "num_steps": "n_steps",
+        "accept_prob": "acceptance_rate",
+    }
+    data = {}
+    for stat, value in posterior.get_extra_fields(group_by_chain=True).items():
+        if isinstance(value, (dict, tuple)):
+            continue
+        name = rename_key.get(stat, stat)
+        value = value.copy()
+        data[name] = value
+        if stat == "num_steps":
+            data["tree_depth"] = np.log2(value).astype(int) + 1
+    return data
+
+
+def _get_log_likelihood(model, samples):
+    "Compute log-likelihood for all observations"
+    data = {}
+    for v in model.observed_RVs:
+        logp_v = replace_shared_variables([logpt(v)])
+        fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
+        jax_fn = jax_funcify(fgraph)
+        result = jax.vmap(jax.vmap(jax_fn))(*samples)[0]
+        data[v.name] = result
+    return data
+
+
 def sample_numpyro_nuts(
     draws=1000,
     tune=1000,
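Two quick standalone sketches of what the added helpers compute. First, the num_steps to tree_depth conversion in _sample_stats_to_xarray, assuming num_steps follows the 2**depth - 1 pattern of a full NUTS tree (numbers are illustrative):

import numpy as np

# Hypothetical per-draw leapfrog step counts (2**depth - 1 for full trees).
num_steps = np.array([1, 3, 7, 15, 1023])
tree_depth = np.log2(num_steps).astype(int) + 1
print(tree_depth)  # [ 1  2  3  4 10]

Second, the double jax.vmap in _get_log_likelihood maps a pointwise log-density function over the draw axis and then the chain axis. A shape sketch with a stand-in for jax_fn (the real one is the jax_funcify'd logp graph; this stand-in only mimics its tuple-returning signature):

import jax
import jax.numpy as jnp

# Stand-in for the compiled per-draw logp function: takes scalar values of the
# model's value variables and returns a tuple with one array of pointwise
# log-density terms (length 100, mimicking 100 observations).
def fake_jax_fn(a, sigma_log_):
    obs = jnp.zeros(100)
    return (-0.5 * (obs - a) ** 2 * jnp.exp(-2 * sigma_log_),)

chains, draws = 2, 500
a_samples = jnp.zeros((chains, draws))
sigma_log_samples = jnp.zeros((chains, draws))

# Inner vmap maps over draws, outer vmap over chains, so the (chain, draw)
# leading axes of the posterior samples are preserved in the output.
result = jax.vmap(jax.vmap(fake_jax_fn))(a_samples, sigma_log_samples)[0]
print(result.shape)  # (2, 500, 100)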
@@ -151,9 +186,23 @@ def sample_numpyro_nuts(
     map_seed = jax.random.split(seed, chains)

     if chains == 1:
-        pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",))
+        init_params = init_state
+        map_seed = seed
     else:
-        pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
+        init_params = init_state_batched
+
+    pmap_numpyro.run(
+        map_seed,
+        init_params=init_params,
+        extra_fields=(
+            "num_steps",
+            "potential_energy",
+            "energy",
+            "adapt_state.step_size",
+            "accept_prob",
+            "diverging",
+        ),
+    )

     raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

@@ -172,6 +221,11 @@ def sample_numpyro_nuts(
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

     posterior = mcmc_samples
-    az_trace = az.from_dict(posterior=posterior)
+    az_posterior = az.from_dict(posterior=posterior)
+
+    az_obs = az.from_dict(observed_data=find_observations(model))
+    az_stats = az.from_dict(sample_stats=_sample_stats_to_xarray(pmap_numpyro))
+    az_ll = az.from_dict(log_likelihood=_get_log_likelihood(model, raw_mcmc_samples))
+    az_trace = az.concat(az_posterior, az_ll, az_obs, az_stats)

     return az_trace
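With the four groups concatenated into one InferenceData, downstream ArviZ diagnostics that previously lacked inputs become available. A hedged usage sketch, assuming sample_numpyro_nuts picks up the model from the context manager the way pm.sample does and that `model` is a PyMC model defined elsewhere:

import arviz as az

from pymc.sampling_jax import sample_numpyro_nuts

with model:  # `model` is any PyMC model defined elsewhere
    idata = sample_numpyro_nuts(draws=500, tune=500, chains=2)

print(idata.groups())  # now includes log_likelihood, sample_stats, observed_data
az.summary(idata)      # posterior summary
az.loo(idata)          # leave-one-out CV, needs the new log_likelihood group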

pymc/tests/test_sampling_jax.py

Lines changed: 19 additions & 0 deletions
@@ -9,6 +9,7 @@
 import pymc as pm

 from pymc.sampling_jax import (
+    _get_log_likelihood,
     get_jaxified_logp,
     replace_shared_variables,
     sample_numpyro_nuts,
@@ -61,6 +62,24 @@ def test_deterministic_samples():
     assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2)


+def test_get_log_likelihood():
+    obs = np.random.normal(10, 2, size=100)
+    obs_at = aesara.shared(obs, borrow=True, name="obs")
+    with pm.Model() as model:
+        a = pm.Normal("a", 0, 2)
+        sigma = pm.HalfNormal("sigma")
+        b = pm.Normal("b", a, sigma=sigma, observed=obs_at)
+
+        trace = pm.sample(tune=10, draws=10, chains=2, random_seed=1322)
+
+    b_true = trace.log_likelihood.b.values
+    a = np.array(trace.posterior.a)
+    sigma_log_ = np.log(np.array(trace.posterior.sigma))
+    b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"]
+
+    assert np.allclose(b_jax.reshape(-1), b_true.reshape(-1))
+
+
 def test_replace_shared_variables():
     x = aesara.shared(5, name="shared_x")
