diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index c08eb068ac..26b784b36f 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -42,6 +42,26 @@ Var = Any # pylint: disable=invalid-name +def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]: + """If there are observations available, return them as a dictionary.""" + if model is None: + return None + + observations = {} + for obs in model.observed_RVs: + aux_obs = getattr(obs.tag, "observations", None) + if aux_obs is not None: + try: + obs_data = extract_obs_data(aux_obs) + observations[obs.name] = obs_data + except TypeError: + warnings.warn(f"Could not extract data from symbolic observation {obs}") + else: + warnings.warn(f"No data for observation {obs}") + + return observations + + class _DefaultTrace: """ Utility for collecting samples into a dictionary. @@ -196,25 +216,7 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray: self.dims = {**model_dims, **self.dims} self.density_dist_obs = density_dist_obs - self.observations = self.find_observations() - - def find_observations(self) -> Optional[Dict[str, Var]]: - """If there are observations available, return them as a dictionary.""" - if self.model is None: - return None - observations = {} - for obs in self.model.observed_RVs: - aux_obs = getattr(obs.tag, "observations", None) - if aux_obs is not None: - try: - obs_data = extract_obs_data(aux_obs) - observations[obs.name] = obs_data - except TypeError: - warnings.warn(f"Could not extract data from symbolic observation {obs}") - else: - warnings.warn(f"No data for observation {obs}") - - return observations + self.observations = find_observations(self.model) def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]: """Split MultiTrace object into posterior and warmup. diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 6fdbd64e73..efb975c2f0 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -26,7 +26,9 @@ from aesara.link.jax.dispatch import jax_funcify from pymc import Model, modelcontext -from pymc.aesaraf import compile_rv_inplace, inputvars +from pymc.aesaraf import compile_rv_inplace +from pymc.backends.arviz import find_observations +from pymc.distributions import logpt from pymc.util import get_default_varnames warnings.warn("This module is experimental.") @@ -95,6 +97,39 @@ def logp_fn_wrap(x): return logp_fn_wrap +# Adopted from arviz numpyro extractor +def _sample_stats_to_xarray(posterior): + """Extract sample_stats from NumPyro posterior.""" + rename_key = { + "potential_energy": "lp", + "adapt_state.step_size": "step_size", + "num_steps": "n_steps", + "accept_prob": "acceptance_rate", + } + data = {} + for stat, value in posterior.get_extra_fields(group_by_chain=True).items(): + if isinstance(value, (dict, tuple)): + continue + name = rename_key.get(stat, stat) + value = value.copy() + data[name] = value + if stat == "num_steps": + data["tree_depth"] = np.log2(value).astype(int) + 1 + return data + + +def _get_log_likelihood(model, samples): + "Compute log-likelihood for all observations" + data = {} + for v in model.observed_RVs: + logp_v = replace_shared_variables([logpt(v)]) + fgraph = FunctionGraph(model.value_vars, logp_v, clone=False) + jax_fn = jax_funcify(fgraph) + result = jax.vmap(jax.vmap(jax_fn))(*samples)[0] + data[v.name] = result + return data + + def sample_numpyro_nuts( draws=1000, tune=1000, @@ -151,9 +186,23 @@ def sample_numpyro_nuts( map_seed = jax.random.split(seed, chains) if chains == 1: - pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",)) + init_params = init_state + map_seed = seed else: - pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",)) + init_params = init_state_batched + + pmap_numpyro.run( + map_seed, + init_params=init_params, + extra_fields=( + "num_steps", + "potential_energy", + "energy", + "adapt_state.step_size", + "accept_prob", + "diverging", + ), + ) raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) @@ -172,6 +221,11 @@ def sample_numpyro_nuts( print("Transformation time = ", tic4 - tic3, file=sys.stdout) posterior = mcmc_samples - az_trace = az.from_dict(posterior=posterior) + az_posterior = az.from_dict(posterior=posterior) + + az_obs = az.from_dict(observed_data=find_observations(model)) + az_stats = az.from_dict(sample_stats=_sample_stats_to_xarray(pmap_numpyro)) + az_ll = az.from_dict(log_likelihood=_get_log_likelihood(model, raw_mcmc_samples)) + az_trace = az.concat(az_posterior, az_ll, az_obs, az_stats) return az_trace diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index 3fd04059c0..172eceb4d0 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -9,6 +9,7 @@ import pymc as pm from pymc.sampling_jax import ( + _get_log_likelihood, get_jaxified_logp, replace_shared_variables, sample_numpyro_nuts, @@ -61,6 +62,24 @@ def test_deterministic_samples(): assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2) +def test_get_log_likelihood(): + obs = np.random.normal(10, 2, size=100) + obs_at = aesara.shared(obs, borrow=True, name="obs") + with pm.Model() as model: + a = pm.Normal("a", 0, 2) + sigma = pm.HalfNormal("sigma") + b = pm.Normal("b", a, sigma=sigma, observed=obs_at) + + trace = pm.sample(tune=10, draws=10, chains=2, random_seed=1322) + + b_true = trace.log_likelihood.b.values + a = np.array(trace.posterior.a) + sigma_log_ = np.log(np.array(trace.posterior.sigma)) + b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"] + + assert np.allclose(b_jax.reshape(-1), b_true.reshape(-1)) + + def test_replace_shared_variables(): x = aesara.shared(5, name="shared_x")