Skip to content

Commit 94fc1b1

Browse files
committed
Adding recover_marginals utility function
1 parent f1ece1c commit 94fc1b1

File tree

2 files changed

+182
-3
lines changed

2 files changed

+182
-3
lines changed

pymc_experimental/model/marginal_model.py

+142-3
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,29 @@
22
from typing import Sequence, Tuple, Union
33

44
import numpy as np
5+
import pymc
56
import pytensor.tensor as pt
7+
from arviz import dict_to_dataset
68
from pymc import SymbolicRandomVariable
9+
from pymc.backends.arviz import coords_and_dims_for_inferencedata
710
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
811
from pymc.distributions.transforms import Chain
912
from pymc.logprob.abstract import _logprob
1013
from pymc.logprob.basic import conditional_logp
1114
from pymc.logprob.transforms import IntervalTransform
1215
from pymc.model import Model
13-
from pymc.pytensorf import constant_fold, inputvars
16+
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
17+
from pymc.util import dataset_to_point_list, treedict
1418
from pytensor import Mode
1519
from pytensor.compile import SharedVariable
1620
from pytensor.compile.builders import OpFromGraph
17-
from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
21+
from pytensor.graph import (
22+
Constant,
23+
FunctionGraph,
24+
ancestors,
25+
clone_replace,
26+
vectorize_graph,
27+
)
1828
from pytensor.scan import map as scan_map
1929
from pytensor.tensor import TensorVariable
2030
from pytensor.tensor.elemwise import Elemwise
@@ -205,8 +215,9 @@ def clone(self):
205215
vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
206216
cloned_vars = clone_replace(vars)
207217
vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
218+
m.vars_to_clone = vars_to_clone
208219

209-
m.named_vars = {name: vars_to_clone[var] for name, var in self.named_vars.items()}
220+
m.named_vars = treedict({name: vars_to_clone[var] for name, var in self.named_vars.items()})
210221
m.named_vars_to_dims = self.named_vars_to_dims
211222
m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()}
212223
m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()}
@@ -244,6 +255,134 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
244255
# Raise errors and warnings immediately
245256
self.clone()._marginalize(user_warnings=True)
246257

258+
def unmarginalize(self, rvs_to_unmarginalize):
    """Put previously marginalized variables back into the model.

    Every variable is removed from ``self.marginalized_rvs`` and then
    registered again through ``self.register_rv`` under its own ``name``.

    Parameters
    ----------
    rvs_to_unmarginalize : iterable
        Variables currently stored in ``self.marginalized_rvs``. Passing
        ``self.marginalized_rvs`` itself is supported.

    Raises
    ------
    ValueError
        Presumably raised by ``self.marginalized_rvs.remove`` when a variable
        is not marginalized (assumes ``marginalized_rvs`` is a list).
    """
    # Iterate over a snapshot: the caller may hand us ``self.marginalized_rvs``
    # itself, and removing from a list while iterating over it silently skips
    # every other element.
    for rv in tuple(rvs_to_unmarginalize):
        self.marginalized_rvs.remove(rv)
        self.register_rv(rv, name=rv.name)
262+
263+
def recover_marginals(
    self, idata, var_names=None, include_samples=False, extend_inferencedata=True
):
    """Compute log-probabilities of marginalized variables given posterior draws.

    For each requested marginalized variable, the variable is put back into a
    clone of the model and the model's joint log-probability is evaluated at
    every value of the variable's finite domain, once per posterior draw in
    ``idata``. The values are the (unnormalized) joint log-probabilities, i.e.
    they are not normalized over the domain.

    When there are multiple marginalized variables, each one is conditioned on
    the parameters while the other variables stay marginalized.

    Parameters
    ----------
    idata : InferenceData
        InferenceData with a ``posterior`` group containing draws of the
        non-marginalized free variables.
    var_names : sequence of str or variables, optional
        Marginalized variables, or their names, to recover. Defaults to all
        variables in ``self.marginalized_rvs``.
    include_samples : bool, default False
        Also draw one value of each recovered variable per posterior draw,
        using the computed log-probabilities as categorical logits.
    extend_inferencedata : bool, default True
        If True, return the posterior dataset with the new variables added.
        If False, return a dataset holding only the new variables.

    Returns
    -------
    xarray.Dataset
        With ``extend_inferencedata=True``: the result of
        ``idata.posterior.assign(...)``, i.e. the posterior variables plus
        ``lp_<name>`` (and ``<name>`` when ``include_samples`` is True).
        ``idata`` itself is not modified. Otherwise: a dataset built by
        ``dict_to_dataset`` containing only the new variables.

    Raises
    ------
    ValueError
        If a name in ``var_names`` does not match a marginalized variable.
    """
    if var_names is None:
        var_names = self.marginalized_rvs
    elif isinstance(var_names, str):
        # A bare string would otherwise be iterated character by character.
        var_names = [var_names]
    var_names = list(var_names)

    # Names are resolved against the marginalized variables; variables are
    # passed through untouched so existing callers keep working.
    marginalized_by_name = {rv.name: rv for rv in self.marginalized_rvs}
    unknown_names = [
        v for v in var_names if isinstance(v, str) and v not in marginalized_by_name
    ]
    if unknown_names:
        raise ValueError(
            f"var_names {unknown_names} do not match any marginalized variable"
        )
    rvs_to_recover = [
        marginalized_by_name[v] if isinstance(v, str) else v for v in var_names
    ]

    # Remove Deterministics: keep only draws of non-marginalized free variables
    posterior_values = idata.posterior[
        [rv.name for rv in self.free_RVs if rv not in self.marginalized_rvs]
    ]

    sample_dims = ("chain", "draw")
    posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
    sample_shape = tuple(len(coord) for coord in stacked_dims.values())

    rv_dict = {}
    rv_dims_dict = {}

    for rv_to_recover in rvs_to_recover:
        # Work on a clone so that ``self`` keeps the variable marginalized.
        m = self.clone()
        rv = m.vars_to_clone[rv_to_recover]
        m.unmarginalize([rv])
        joint_logp = m.logp()

        rv_shape = constant_fold(tuple(rv.shape))
        rv_domain = get_domain_of_finite_discrete_rv(rv)
        # Leading axis enumerates the domain; entry ``i`` is filled with
        # ``rv_domain[i]``. ``moveaxis`` (rather than ``swapaxes``) keeps the
        # variable's own axes in their original order for ndim >= 2.
        rv_domain_tensor = pt.moveaxis(
            pt.full(
                (*rv_shape, len(rv_domain)),
                rv_domain,
                dtype=rv.dtype,
            ),
            -1,
            0,
        )

        marginalized_value = m.rvs_to_values[rv]
        other_values = [v for v in m.value_vars if v is not marginalized_value]

        # TODO: Handle constants
        # TODO: Handle transformed variables

        joint_logps = vectorize_graph(
            joint_logp,
            replace={marginalized_value: rv_domain_tensor},
        )

        outputs = [joint_logps]
        if include_samples:
            # The categorical draw is an index into the domain; map it back to
            # the domain value so domains that do not start at 0 are reported
            # correctly.
            sampled_index = pymc.Categorical.dist(logit_p=joint_logps)
            domain_values = pt.as_tensor_variable(np.asarray(rv_domain))
            outputs.append(domain_values[sampled_index])

        rv_loglike_fn = compile_pymc(
            inputs=other_values,
            outputs=outputs,
            on_unused_input="ignore",
        )

        # Posterior points are passed by keyword, so input names must match
        # the names of the posterior variables.
        results = [rv_loglike_fn(**point) for point in posterior_pts]

        if include_samples:
            samples = np.array([result[1] for result in results])
            rv_dict[rv.name] = np.reshape(samples, sample_shape)
            rv_dims_dict[rv.name] = sample_dims

        logps = np.array([result[0] for result in results])
        lp_name = "lp_" + rv.name
        rv_dict[lp_name] = np.reshape(logps, sample_shape + logps.shape[1:])
        rv_dims_dict[lp_name] = sample_dims + (lp_name + "_dims",)

    if extend_inferencedata:
        return idata.posterior.assign(
            **{name: (rv_dims_dict[name], values) for name, values in rv_dict.items()}
        )

    # Only build the standalone dataset when it is actually returned.
    coords, dims = coords_and_dims_for_inferencedata(self)
    return dict_to_dataset(
        rv_dict,
        library=pymc,
        dims=dims,
        coords=coords,
        default_dims=list(sample_dims),
        skip_event_dims=True,
    )
385+
247386

248387
class MarginalRV(SymbolicRandomVariable):
249388
"""Base class for Marginalized RVs"""

pymc_experimental/tests/model/test_marginal_model.py

+40
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44
import numpy as np
55
import pandas as pd
66
import pymc as pm
7+
import pytensor
78
import pytensor.tensor as pt
89
import pytest
10+
from arviz import InferenceData, dict_to_dataset
911
from pymc import ImputationWarning, inputvars
1012
from pymc.distributions import transforms
1113
from pymc.logprob.abstract import _logprob
1214
from pymc.util import UNSET
15+
from pytensor.graph import vectorize_graph
1316
from scipy.special import logsumexp
17+
from scipy.stats import norm
1418

1519
from pymc_experimental.model.marginal_model import (
1620
FiniteDiscreteMarginalRV,
@@ -251,6 +255,42 @@ def test_marginalized_change_point_model_sampling(disaster_model):
251255
)
252256

253257

258+
def test_recover_marginals_basic():
    """Recover a marginalized Categorical from prior draws and check its log-probabilities."""
    weights = np.array([0.5, 0.2, 0.3])
    component_means = np.array([-3.0, 0.0, 3.0])

    with MarginalModel() as m:
        k = pm.Categorical("k", p=weights)
        pm.Normal("y", mu=pt.as_tensor_variable(component_means)[k])

    m.marginalize([k])

    rng = np.random.default_rng(211)

    with m:
        prior = pm.sample_prior_predictive(
            samples=20,
            random_seed=rng,
            return_inferencedata=False,
        )
        idata = InferenceData(posterior=dict_to_dataset(prior))

    recovered = m.recover_marginals(idata, include_samples=True)
    for expected_name in ("k", "lp_k"):
        assert expected_name in recovered
    assert recovered.k.shape == recovered.y.shape
    assert recovered.lp_k.shape == recovered.k.shape + (len(weights),)

    def true_logp(y):
        # One column per mixture component, each holding a copy of ``y``.
        y_per_component = np.repeat(y[:, None], len(weights), axis=1)
        return np.log(weights) + norm.logpdf(y_per_component, loc=component_means)

    expected = true_logp(recovered.y.values.flatten())
    np.testing.assert_almost_equal(expected, recovered.lp_k[0].values)
292+
293+
254294
@pytest.mark.filterwarnings("error")
255295
def test_not_supported_marginalized():
256296
"""Marginalized graphs with non-Elemwise Operations are not supported as they

0 commit comments

Comments
 (0)