Skip to content

Commit 1fc5d55

Browse files
committed
Adding unmarginalize
1 parent 00d7a2b commit 1fc5d55

File tree

2 files changed

+153
-3
lines changed

2 files changed

+153
-3
lines changed

pymc_experimental/model/marginal_model.py

+127-2
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@
22
from typing import Sequence, Tuple, Union
33

44
import numpy as np
5+
import pymc
56
import pytensor.tensor as pt
7+
from arviz import dict_to_dataset
68
from pymc import SymbolicRandomVariable
9+
from pymc.backends.arviz import coords_and_dims_for_inferencedata
710
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
811
from pymc.distributions.transforms import Chain
912
from pymc.logprob.abstract import _logprob
1013
from pymc.logprob.basic import conditional_logp
1114
from pymc.logprob.transforms import IntervalTransform
1215
from pymc.model import Model
13-
from pymc.pytensorf import constant_fold, inputvars
16+
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
17+
from pymc.util import dataset_to_point_list, treedict
1418
from pytensor import Mode
1519
from pytensor.compile import SharedVariable
1620
from pytensor.compile.builders import OpFromGraph
@@ -206,7 +210,7 @@ def clone(self):
206210
cloned_vars = clone_replace(vars)
207211
vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
208212

209-
m.named_vars = {name: vars_to_clone[var] for name, var in self.named_vars.items()}
213+
m.named_vars = treedict({name: vars_to_clone[var] for name, var in self.named_vars.items()})
210214
m.named_vars_to_dims = self.named_vars_to_dims
211215
m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()}
212216
m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()}
@@ -244,6 +248,127 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
244248
# Raise errors and warnings immediately
245249
self.clone()._marginalize(user_warnings=True)
246250

251+
def unmarginalize(
252+
self, idata, var_names=None, include_samples=False, extend_inferencedata=True
253+
):
254+
"""Computes log-likelihoods of marginalized variables conditioned on parameters
255+
of the model given InferenceData with posterior group
256+
257+
Parameters
258+
----------
259+
idata : InferenceData
260+
InferenceData with posterior group
261+
var_names : sequence of str, optional
262+
List of Observed variable names for which to compute log_likelihood. Defaults to all observed variables
263+
include_samples : bool, default False
264+
Include samples of the marginalized variables
265+
extend_inferencedata : bool, default True
266+
Whether to extend the original InferenceData or return a new one
267+
268+
Returns
269+
-------
270+
idata : InferenceData
271+
InferenceData with var_names added to posterior
272+
273+
"""
274+
if var_names is None:
275+
var_names = self.marginalized_rvs
276+
277+
joint_logp = self.logp()
278+
posterior = idata.posterior
279+
280+
# Remove Deterministics
281+
posterior_values = posterior[
282+
[rv.name for rv in mm.free_RVs if rv not in self.marginalized_rvs]
283+
]
284+
285+
sample_dims = ("chain", "draw")
286+
posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
287+
rv_dict = {}
288+
rv_dims_dict = {}
289+
290+
for rv in var_names:
291+
m = self.clone()
292+
m.register_rv(rv, name=rv.name)
293+
m.marginalized_rvs = [r for r in m.marginalized_rvs if r is not rv]
294+
295+
rv_shape = constant_fold(tuple(rv.shape))
296+
rv_domain = get_domain_of_finite_discrete_rv(rv)
297+
rv_domain_tensor = pt.swapaxes(
298+
pt.full(
299+
(*rv_shape, len(rv_domain)),
300+
rv_domain,
301+
dtype=rv.dtype,
302+
),
303+
axis1=0,
304+
axis2=-1,
305+
)
306+
307+
marginalized_value = m.rvs_to_values[rv]
308+
309+
other_values = [v for v in m.value_vars if v is not marginalized_value]
310+
311+
# TODO: Handle constants
312+
# TODO: Handle transformed variables
313+
joint_logp_op = OpFromGraph(
314+
[marginalized_value] + other_values, [joint_logp], inline=True
315+
)
316+
joint_logps = [
317+
joint_logp_op(rv_domain_tensor[i], *other_values) for i in range(len(rv_domain))
318+
]
319+
320+
rv_loglike_fn = None
321+
if include_samples:
322+
sample_rv_outs = pm.Categorical.dist(logit_p=joint_logps)
323+
rv_loglike_fn = compile_pymc(
324+
inputs=other_values,
325+
outputs=[pt.stack(joint_logps, 0), sample_rv_outs],
326+
on_unused_input="ignore",
327+
)
328+
else:
329+
rv_loglike_fn = compile_pymc(
330+
inputs=other_values,
331+
outputs=pt.stack(joint_logps, 0),
332+
on_unused_input="ignore",
333+
)
334+
335+
logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]
336+
337+
if include_samples:
338+
logps, samples = zip(*logvs)
339+
logps = np.array(logps)
340+
rv_dict[rv.name] = np.reshape(
341+
samples, tuple(len(coord) for coord in stacked_dims.values())
342+
)
343+
rv_dims_dict[rv.name] = sample_dims
344+
rv_dict["lp_" + rv.name] = np.reshape(
345+
logps, tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:]
346+
)
347+
rv_dims_dict["lp_" + rv.name] = sample_dims + ("lp_" + rv.name + "_dims",)
348+
else:
349+
logps = np.array(logvs)
350+
rv_dict["lp_" + rv.name] = np.reshape(
351+
logps, tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:]
352+
)
353+
rv_dims_dict["lp_" + rv.name] = sample_dims + ("lp_" + rv.name + "_dims",)
354+
355+
coords, dims = coords_and_dims_for_inferencedata(model)
356+
rv_dataset = dict_to_dataset(
357+
rv_dict,
358+
library=pymc,
359+
dims=dims,
360+
coords=coords,
361+
default_dims=list(sample_dims),
362+
skip_event_dims=True,
363+
)
364+
365+
if extend_inferencedata:
366+
rv_dict = {k: (rv_dims_dict[k], v) for (k, v) in rv_dict.items()}
367+
idata = idata.posterior.assign(**rv_dict)
368+
return idata
369+
else:
370+
return rv_dataset
371+
247372

248373
class MarginalRV(SymbolicRandomVariable):
249374
"""Base class for Marginalized RVs"""

pymc_experimental/tests/model/test_marginal_model.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,12 @@ def test_marginalized_bernoulli_logp():
5252

5353
idx = pm.Bernoulli.dist(0.7, name="idx")
5454
y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y")
55-
marginal_rv_node = FiniteDiscreteMarginalRV([mu], [idx, y], ndim_supp=None, n_updates=0,)(
55+
marginal_rv_node = FiniteDiscreteMarginalRV(
56+
[mu],
57+
[idx, y],
58+
ndim_supp=None,
59+
n_updates=0,
60+
)(
5661
mu
5762
)[0].owner
5863

@@ -251,6 +256,26 @@ def test_marginalized_change_point_model_sampling(disaster_model):
251256
)
252257

253258

259+
@pytest.mark.slow
260+
@pytest.mark.filterwarnings("error")
261+
def test_unmarginalized_basic(disaster_model):
262+
m, years = disaster_model
263+
264+
with pytest.warns(UserWarning, match="There are multiple dependent variables"):
265+
m.marginalize([m["switchpoint"]])
266+
267+
rng = np.random.default_rng(211)
268+
269+
with m:
270+
idata = pm.sample(chains=2, random_seed=rng).posterior.stack(sample=("draw", "chain"))
271+
272+
idata = m.unmarginalize(idata, include_samples=True)
273+
assert "switchpoint" in idata
274+
assert "lp_switchpoint" in idata
275+
assert idata.switchpoint.shape == idata.early_mean.shape
276+
assert idata.lp_switchpoint.shape == idata.switchpoint.shape + (len(years),)
277+
278+
254279
@pytest.mark.filterwarnings("error")
255280
def test_not_supported_marginalized():
256281
"""Marginalized graphs with non-Elemwise Operations are not supported as they

0 commit comments

Comments
 (0)