Skip to content

Commit c21ae69

Browse files
committed
Adding recover_marginals utility function
1 parent f1ece1c commit c21ae69

File tree

2 files changed

+284
-6
lines changed

2 files changed

+284
-6
lines changed

pymc_experimental/model/marginal_model.py

+186-5
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,37 @@
22
from typing import Sequence, Tuple, Union
33

44
import numpy as np
5+
import pymc
56
import pytensor.tensor as pt
7+
from arviz import dict_to_dataset
68
from pymc import SymbolicRandomVariable
9+
from pymc.backends.arviz import coords_and_dims_for_inferencedata
710
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
811
from pymc.distributions.transforms import Chain
912
from pymc.logprob.abstract import _logprob
1013
from pymc.logprob.basic import conditional_logp
1114
from pymc.logprob.transforms import IntervalTransform
1215
from pymc.model import Model
13-
from pymc.pytensorf import constant_fold, inputvars
16+
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
17+
from pymc.util import dataset_to_point_list, treedict
1418
from pytensor import Mode
1519
from pytensor.compile import SharedVariable
1620
from pytensor.compile.builders import OpFromGraph
17-
from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
21+
from pytensor.graph import (
22+
Constant,
23+
FunctionGraph,
24+
ancestors,
25+
clone_replace,
26+
vectorize_graph,
27+
)
1828
from pytensor.scan import map as scan_map
1929
from pytensor.tensor import TensorVariable
2030
from pytensor.tensor.elemwise import Elemwise
31+
from pytensor.tensor.shape import Shape
32+
from scipy.special import log_softmax
2133

2234
__all__ = ["MarginalModel"]
2335

24-
from pytensor.tensor.shape import Shape
25-
2636

2737
class MarginalModel(Model):
2838
"""Subclass of PyMC Model that implements functionality for automatic
@@ -205,8 +215,9 @@ def clone(self):
205215
vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
206216
cloned_vars = clone_replace(vars)
207217
vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
218+
m.vars_to_clone = vars_to_clone
208219

209-
m.named_vars = {name: vars_to_clone[var] for name, var in self.named_vars.items()}
220+
m.named_vars = treedict({name: vars_to_clone[var] for name, var in self.named_vars.items()})
210221
m.named_vars_to_dims = self.named_vars_to_dims
211222
m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()}
212223
m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()}
@@ -244,6 +255,176 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
244255
# Raise errors and warnings immediately
245256
self.clone()._marginalize(user_warnings=True)
246257

258+
def _to_transformed(self):
    """Build a function mapping free-RV draws to their transformed-space values.

    Returns
    -------
    fn : compiled function
        Maps values of the model's free RVs to the transformed space
        (RVs without a transform are passed through unchanged).
    names : list of str
        One name per output: the value-variable name when a transform was
        applied, the RV's own name otherwise.
    """
    outputs = []
    names = []
    for free_rv in self.free_RVs:
        transform = self.rvs_to_transforms.get(free_rv)
        if transform is not None:
            # Forward-transform needs the RV's distribution parameters
            outputs.append(transform.forward(free_rv, *free_rv.owner.inputs))
            names.append(self.rvs_to_values[free_rv].name)
        else:
            outputs.append(free_rv)
            names.append(free_rv.name)
    return self.compile_fn(inputs=self.free_RVs, outs=outputs), names
275+
276+
def unmarginalize(self, rvs_to_unmarginalize):
    """Undo marginalization for the given RVs, restoring them as model RVs."""
    for marginalized_rv in rvs_to_unmarginalize:
        # Drop from the marginalized set, then register as an ordinary RV again
        self.marginalized_rvs.remove(marginalized_rv)
        self.register_rv(marginalized_rv, name=marginalized_rv.name)
280+
281+
def recover_marginals(
    self, idata, var_names=None, return_samples=True, extend_inferencedata=True
):
    """Compute normalized posterior probabilities of marginalized variables
    conditioned on parameters of the model given InferenceData with posterior group

    When there are multiple marginalized variables, each marginalized variable is
    conditioned on both the parameters and the other variables still marginalized

    All log-probabilities are within the transformed space

    Parameters
    ----------
    idata : InferenceData
        InferenceData with posterior group
    var_names : sequence of str or TensorVariable, optional
        Marginalized variables (or their names) whose posterior probabilities
        should be recovered. Defaults to all marginalized variables
    return_samples : bool, default True
        If True, also return samples of the marginalized variables
    extend_inferencedata : bool, default True
        Whether to extend the original InferenceData or return a new one

    Returns
    -------
    idata : InferenceData
        InferenceData with var_names added to posterior

    """
    # Accept both RV objects and their names; default to every marginalized RV.
    # (Previously, passing names as documented crashed on `rv.owner.op`.)
    if var_names is None:
        vars_to_recover = list(self.marginalized_rvs)
    else:
        name_to_rv = {marg_rv.name: marg_rv for marg_rv in self.marginalized_rvs}
        vars_to_recover = [
            name_to_rv[var] if isinstance(var, str) else var for var in var_names
        ]

    posterior = idata.posterior

    # Remove Deterministics
    posterior_values = posterior[
        [rv.name for rv in self.free_RVs if rv not in self.marginalized_rvs]
    ]

    sample_dims = ("chain", "draw")
    posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)

    # Handle Transforms
    transform_fn, transform_names = self._to_transformed()

    def transform_input(inputs):
        return dict(zip(transform_names, transform_fn(inputs)))

    posterior_pts = [transform_input(vs) for vs in posterior_pts]

    rv_dict = {}
    rv_dims_dict = {}

    for rv in vars_to_recover:
        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
        if not isinstance(rv.owner.op, supported_dists):
            raise NotImplementedError(
                f"RV with distribution {rv.owner.op} cannot be marginalized. "
                f"Supported distributions include {supported_dists}"
            )

        # Work on a clone so that unmarginalizing does not mutate `self`
        m = self.clone()
        rv = m.vars_to_clone[rv]
        m.unmarginalize([rv])
        joint_logp = m.logp()

        # Tensor enumerating rv's full discrete domain, with the domain laid
        # out along the leading axis so vectorize_graph batches over it
        rv_shape = constant_fold(tuple(rv.shape))
        rv_domain = get_domain_of_finite_discrete_rv(rv)
        rv_domain_tensor = pt.swapaxes(
            pt.full(
                (*rv_shape, len(rv_domain)),
                rv_domain,
                dtype=rv.dtype,
            ),
            axis1=0,
            axis2=-1,
        )

        marginalized_value = m.rvs_to_values[rv]

        other_values = [v for v in m.value_vars if v is not marginalized_value]

        # TODO: Handle constants
        joint_logps = vectorize_graph(
            joint_logp,
            replace={marginalized_value: rv_domain_tensor},
        )

        if return_samples:
            # Also draw the recovered variable from its (unnormalized) conditional
            sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
            rv_loglike_fn = compile_pymc(
                inputs=other_values,
                outputs=[joint_logps, sample_rv_outs],
                on_unused_input="ignore",
            )
        else:
            rv_loglike_fn = compile_pymc(
                inputs=other_values,
                outputs=joint_logps,
                on_unused_input="ignore",
            )

        logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]

        sample_shape = tuple(len(coord) for coord in stacked_dims.values())
        if return_samples:
            logps, samples = zip(*logvs)
            logps = np.array(logps)
            rv_dict[rv.name] = np.reshape(samples, sample_shape)
            rv_dims_dict[rv.name] = sample_dims
        else:
            logps = np.array(logvs)

        # Normalize over the domain axis (the first axis after the sample dims)
        rv_dict["lp_" + rv.name] = log_softmax(
            np.reshape(logps, sample_shape + logps.shape[1:]),
            axis=len(stacked_dims),
        )
        rv_dims_dict["lp_" + rv.name] = sample_dims + ("lp_" + rv.name + "_dims",)

    coords, dims = coords_and_dims_for_inferencedata(self)
    rv_dataset = dict_to_dataset(
        rv_dict,
        library=pymc,
        dims=dims,
        coords=coords,
        default_dims=list(sample_dims),
        skip_event_dims=True,
    )

    if extend_inferencedata:
        # NOTE(review): this returns the assigned posterior Dataset, not the
        # full InferenceData object as the docstring suggests — confirm
        # callers (and the tests, which access `idata.k` etc.) expect this
        rv_dict = {k: (rv_dims_dict[k], v) for (k, v) in rv_dict.items()}
        idata = idata.posterior.assign(**rv_dict)
        return idata
    else:
        return rv_dataset
427+
247428

248429
class MarginalRV(SymbolicRandomVariable):
249430
"""Base class for Marginalized RVs"""

pymc_experimental/tests/model/test_marginal_model.py

+98-1
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
import pymc as pm
77
import pytensor.tensor as pt
88
import pytest
9+
from arviz import InferenceData, dict_to_dataset
910
from pymc import ImputationWarning, inputvars
1011
from pymc.distributions import transforms
1112
from pymc.logprob.abstract import _logprob
1213
from pymc.util import UNSET
13-
from scipy.special import logsumexp
14+
from scipy.special import log_softmax, logsumexp
15+
from scipy.stats import halfnorm, norm
1416

1517
from pymc_experimental.model.marginal_model import (
1618
FiniteDiscreteMarginalRV,
@@ -251,6 +253,101 @@ def test_marginalized_change_point_model_sampling(disaster_model):
251253
)
252254

253255

256+
def test_recover_marginals_basic():
    """Recovered marginal of a Categorical mixture index matches the analytic posterior."""
    probs = np.array([0.5, 0.2, 0.3])
    locs = np.array([-3.0, 0.0, 3.0])

    with MarginalModel() as model:
        noise = pm.HalfNormal("sigma")
        component = pm.Categorical("k", p=probs)
        pm.Normal("y", mu=pt.as_tensor_variable(locs)[component], sigma=noise)

    model.marginalize([component])

    seed = np.random.default_rng(211)
    with model:
        prior_draws = pm.sample_prior_predictive(
            samples=20,
            random_seed=seed,
            return_inferencedata=False,
        )
        idata = InferenceData(posterior=dict_to_dataset(prior_draws))

    idata = model.recover_marginals(idata, return_samples=True)

    assert "k" in idata
    assert "lp_k" in idata
    assert idata.k.shape == idata.y.shape
    assert idata.lp_k.shape == idata.k.shape + (len(probs),)

    def expected_logp(y_vals, sigma_vals):
        # Broadcast each draw across the three mixture components
        n_comp = len(probs)
        y_mat = y_vals.repeat(n_comp).reshape(len(y_vals), -1)
        sigma_mat = sigma_vals.repeat(n_comp).reshape(len(sigma_vals), -1)
        unnormalized = (
            np.log(probs)
            + norm.logpdf(y_mat, loc=locs, scale=sigma_mat)
            + halfnorm.logpdf(sigma_mat)
            + np.log(sigma_mat)
        )
        return log_softmax(unnormalized, axis=1)

    np.testing.assert_almost_equal(
        expected_logp(idata.y.values.flatten(), idata.sigma.values.flatten()),
        idata.lp_k[0].values,
    )
298+
299+
300+
def test_nested_recover_marginals():
    """Test that marginalization works when there are nested marginalized RVs"""
    with MarginalModel() as model:
        outer = pm.Bernoulli("idx", p=0.75)
        inner = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(outer, 0), 0.15, 0.95))
        pm.Normal("y", mu=outer + inner, sigma=1.0)

    model.marginalize([outer, inner])

    seed = np.random.default_rng(211)
    with model:
        prior_draws = pm.sample_prior_predictive(
            samples=20,
            random_seed=seed,
            return_inferencedata=False,
        )
        idata = InferenceData(posterior=dict_to_dataset(prior_draws))

    idata = model.recover_marginals(idata, return_samples=True)

    assert "idx" in idata
    assert "lp_idx" in idata
    assert idata.idx.shape == idata.y.shape
    assert idata.lp_idx.shape == idata.idx.shape + (2,)
    assert "sub_idx" in idata
    assert "lp_sub_idx" in idata
    assert idata.sub_idx.shape == idata.y.shape
    assert idata.lp_sub_idx.shape == idata.sub_idx.shape + (2,)

    def expected_idx_logp(y):
        # P(idx=0|y) and P(idx=1|y), marginalizing sub_idx out
        off = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1))
        on = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
        return log_softmax(np.stack([off, on]).T, axis=1)

    def expected_sub_idx_logp(y):
        # P(sub_idx=0|y) and P(sub_idx=1|y), marginalizing idx out
        off = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1))
        on = np.log(0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
        return log_softmax(np.stack([off, on]).T, axis=1)

    np.testing.assert_almost_equal(
        expected_idx_logp(idata.y.values.flatten()),
        idata.lp_idx[0].values,
    )
    np.testing.assert_almost_equal(
        expected_sub_idx_logp(idata.y.values.flatten()),
        idata.lp_sub_idx[0].values,
    )
349+
350+
254351
@pytest.mark.filterwarnings("error")
255352
def test_not_supported_marginalized():
256353
"""Marginalized graphs with non-Elemwise Operations are not supported as they

0 commit comments

Comments
 (0)