
Commit 4f75687

Implement utility to recover marginalized variables from MarginalModel (#285)
Adding recover_marginals utility function
1 parent 99f30aa commit 4f75687
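For orientation, here is a minimal usage sketch of the new utility, adapted from the docstring example added in this commit (the model, variable names, and observed values are illustrative):

    import pymc as pm
    from pymc_experimental import MarginalModel

    with MarginalModel() as m:
        p = pm.Beta("p", 1, 1)
        x = pm.Bernoulli("x", p=p, shape=(3,))
        y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

        # Marginalize the discrete variable so a gradient-based sampler can be used ...
        m.marginalize([x])
        idata = pm.sample()

        # ... then recover it afterwards; this adds "x" and "lp_x" to idata.posterior.
        m.recover_marginals(idata, var_names=["x"])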

2 files changed (+422 -10 lines)


pymc_experimental/model/marginal_model.py (+249 -9)
@@ -2,27 +2,37 @@
 from typing import Sequence, Tuple, Union
 
 import numpy as np
+import pymc
 import pytensor.tensor as pt
+from arviz import dict_to_dataset
 from pymc import SymbolicRandomVariable
+from pymc.backends.arviz import coords_and_dims_for_inferencedata
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
 from pymc.logprob.basic import conditional_logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
-from pymc.pytensorf import constant_fold, inputvars
+from pymc.pytensorf import compile_pymc, constant_fold, inputvars
+from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
 from pytensor import Mode
 from pytensor.compile import SharedVariable
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+from pytensor.graph import (
+    Constant,
+    FunctionGraph,
+    ancestors,
+    clone_replace,
+    vectorize_graph,
+)
 from pytensor.scan import map as scan_map
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.shape import Shape
+from pytensor.tensor.special import log_softmax
 
 __all__ = ["MarginalModel"]
 
-from pytensor.tensor.shape import Shape
-
 
 class MarginalModel(Model):
     """Subclass of PyMC Model that implements functionality for automatic
@@ -74,6 +84,7 @@ class MarginalModel(Model):
     def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.marginalized_rvs = []
+        self._marginalized_named_vars_to_dims = treedict()
 
     def _delete_rv_mappings(self, rv: TensorVariable) -> None:
         """Remove all model mappings referring to rv
@@ -205,8 +216,9 @@ def clone(self):
         vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
         cloned_vars = clone_replace(vars)
         vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
+        m.vars_to_clone = vars_to_clone
 
-        m.named_vars = {name: vars_to_clone[var] for name, var in self.named_vars.items()}
+        m.named_vars = treedict({name: vars_to_clone[var] for name, var in self.named_vars.items()})
         m.named_vars_to_dims = self.named_vars_to_dims
         m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()}
         m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()}
@@ -220,11 +232,18 @@ def clone(self):
         m.deterministics = [vars_to_clone[det] for det in self.deterministics]
 
         m.marginalized_rvs = [vars_to_clone[rv] for rv in self.marginalized_rvs]
+        m._marginalized_named_vars_to_dims = self._marginalized_named_vars_to_dims
         return m
 
-    def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]):
+    def marginalize(
+        self,
+        rvs_to_marginalize: Union[TensorVariable, str, Sequence[TensorVariable], Sequence[str]],
+    ):
         if not isinstance(rvs_to_marginalize, Sequence):
             rvs_to_marginalize = (rvs_to_marginalize,)
+        rvs_to_marginalize = [
+            self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
+        ]
 
         supported_dists = (Bernoulli, Categorical, DiscreteUniform)
         for rv_to_marginalize in rvs_to_marginalize:
@@ -238,12 +257,233 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
                     f"Supported distribution include {supported_dists}"
                 )
 
+            if rv_to_marginalize.name in self.named_vars_to_dims:
+                dims = self.named_vars_to_dims[rv_to_marginalize.name]
+                self._marginalized_named_vars_to_dims[rv_to_marginalize.name] = dims
+
             self._delete_rv_mappings(rv_to_marginalize)
             self.marginalized_rvs.append(rv_to_marginalize)
 
         # Raise errors and warnings immediately
         self.clone()._marginalize(user_warnings=True)
 
+    def _to_transformed(self):
+        "Create a function from the untransformed space to the transformed space"
+        transformed_rvs = []
+        transformed_names = []
+
+        for rv in self.free_RVs:
+            transform = self.rvs_to_transforms.get(rv)
+            if transform is None:
+                transformed_rvs.append(rv)
+                transformed_names.append(rv.name)
+            else:
+                transformed_rv = transform.forward(rv, *rv.owner.inputs)
+                transformed_rvs.append(transformed_rv)
+                transformed_names.append(self.rvs_to_values[rv].name)
+
+        fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
+        return fn, transformed_names
+
+    def unmarginalize(self, rvs_to_unmarginalize):
+        for rv in rvs_to_unmarginalize:
+            self.marginalized_rvs.remove(rv)
+            if rv.name in self._marginalized_named_vars_to_dims:
+                dims = self._marginalized_named_vars_to_dims.pop(rv.name)
+            else:
+                dims = None
+            self.register_rv(rv, name=rv.name, dims=dims)
+
+    def recover_marginals(
+        self,
+        idata,
+        var_names=None,
+        return_samples=True,
+        extend_inferencedata=True,
+        random_seed=None,
+    ):
+        """Computes posterior log-probabilities and samples of marginalized variables
+        conditioned on parameters of the model given InferenceData with posterior group
+
+        When there are multiple marginalized variables, each marginalized variable is
+        conditioned on both the parameters and the other variables still marginalized
+
+        All log-probabilities are within the transformed space
+
+        Parameters
+        ----------
+        idata : InferenceData
+            InferenceData with posterior group
+        var_names : sequence of str, optional
+            List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables
+        return_samples : bool, default True
+            If True, also return samples of the marginalized variables
+        extend_inferencedata : bool, default True
+            Whether to extend the original InferenceData or return a new one
+        random_seed: int, array-like of int or SeedSequence, optional
+            Seed used to generating samples
+
+        Returns
+        -------
+        idata : InferenceData
+            InferenceData with where a lp_{varname} and {varname} for each marginalized variable in var_names added to the posterior group
+
+        .. code-block:: python
+
+            import pymc as pm
+            from pymc_experimental import MarginalModel
+
+            with MarginalModel() as m:
+                p = pm.Beta("p", 1, 1)
+                x = pm.Bernoulli("x", p=p, shape=(3,))
+                y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
+
+                m.marginalize([x])
+
+                idata = pm.sample()
+                m.recover_marginals(idata, var_names=["x"])
+
+
+        """
+        if var_names is None:
+            var_names = [var.name for var in self.marginalized_rvs]
+
+        var_names = [var if isinstance(var, str) else var.name for var in var_names]
+        vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names]
+        missing_names = [v.name for v in vars_to_recover if v not in self.marginalized_rvs]
+        if missing_names:
+            raise ValueError(f"Unrecognized var_names: {missing_names}")
+
+        if return_samples and random_seed is not None:
+            seeds = _get_seeds_per_chain(random_seed, len(vars_to_recover))
+        else:
+            seeds = [None] * len(vars_to_recover)
+
+        posterior = idata.posterior
+
+        # Remove Deterministics
+        posterior_values = posterior[
+            [rv.name for rv in self.free_RVs if rv not in self.marginalized_rvs]
+        ]
+
+        sample_dims = ("chain", "draw")
+        posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
+
+        # Handle Transforms
+        transform_fn, transform_names = self._to_transformed()
+
+        def transform_input(inputs):
+            return dict(zip(transform_names, transform_fn(inputs)))
+
+        posterior_pts = [transform_input(vs) for vs in posterior_pts]
+
+        rv_dict = {}
+        rv_dims = {}
+        for seed, rv in zip(seeds, vars_to_recover):
+            supported_dists = (Bernoulli, Categorical, DiscreteUniform)
+            if not isinstance(rv.owner.op, supported_dists):
+                raise NotImplementedError(
+                    f"RV with distribution {rv.owner.op} cannot be recovered. "
+                    f"Supported distribution include {supported_dists}"
+                )
+
+            m = self.clone()
+            rv = m.vars_to_clone[rv]
+            m.unmarginalize([rv])
+            dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs)
+            joint_logps = m.logp(vars=dependent_vars + [rv], sum=False)
+
+            marginalized_value = m.rvs_to_values[rv]
+            other_values = [v for v in m.value_vars if v is not marginalized_value]
+
+            # Handle batch dims for marginalized value and its dependent RVs
+            joint_logp = joint_logps[-1]
+            for dv in joint_logps[:-1]:
+                dbcast = dv.type.broadcastable
+                mbcast = marginalized_value.type.broadcastable
+                mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
+                values_axis_bcast = [
+                    i for i, (m, v) in enumerate(zip(mbcast, dbcast)) if m and not v
+                ]
+                joint_logp += dv.sum(values_axis_bcast)
+
+            rv_shape = constant_fold(tuple(rv.shape))
+            rv_domain = get_domain_of_finite_discrete_rv(rv)
+            rv_domain_tensor = pt.moveaxis(
+                pt.full(
+                    (*rv_shape, len(rv_domain)),
+                    rv_domain,
+                    dtype=rv.dtype,
+                ),
+                -1,
+                0,
+            )
+
+            joint_logps = vectorize_graph(
+                joint_logp,
+                replace={marginalized_value: rv_domain_tensor},
+            )
+            joint_logps = pt.moveaxis(joint_logps, 0, -1)
+
+            rv_loglike_fn = None
+            joint_logps_norm = log_softmax(joint_logps, axis=-1)
+            if return_samples:
+                sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
+                if isinstance(rv.owner.op, DiscreteUniform):
+                    sample_rv_outs += rv_domain[0]
+
+                rv_loglike_fn = compile_pymc(
+                    inputs=other_values,
+                    outputs=[joint_logps_norm, sample_rv_outs],
+                    on_unused_input="ignore",
+                    random_seed=seed,
+                )
+            else:
+                rv_loglike_fn = compile_pymc(
+                    inputs=other_values,
+                    outputs=joint_logps_norm,
+                    on_unused_input="ignore",
+                    random_seed=seed,
+                )
+
+            logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]
+
+            logps = None
+            samples = None
+            if return_samples:
+                logps, samples = zip(*logvs)
+                logps = np.array(logps)
+                samples = np.array(samples)
+                rv_dict[rv.name] = samples.reshape(
+                    tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
+                )
+            else:
+                logps = np.array(logvs)
+
+            rv_dict["lp_" + rv.name] = logps.reshape(
+                tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
+            )
+            if rv.name in m.named_vars_to_dims:
+                rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name])
+                rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"]
+
+        coords, dims = coords_and_dims_for_inferencedata(self)
+        dims.update(rv_dims)
+        rv_dataset = dict_to_dataset(
+            rv_dict,
+            library=pymc,
+            dims=dims,
+            coords=coords,
+            default_dims=list(sample_dims),
+            skip_event_dims=True,
+        )
+
+        if extend_inferencedata:
+            idata.posterior = idata.posterior.assign(rv_dataset)
+            return idata
+        else:
+            return rv_dataset
+
 
 class MarginalRV(SymbolicRandomVariable):
     """Base class for Marginalized RVs"""
@@ -444,14 +684,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     # PyMC does not allow RVs in the logp graph, even if we are just using the shape
     marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape))
     marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
-    marginalized_rv_domain_tensor = pt.swapaxes(
+    marginalized_rv_domain_tensor = pt.moveaxis(
         pt.full(
             (*marginalized_rv_shape, len(marginalized_rv_domain)),
             marginalized_rv_domain,
            dtype=marginalized_rv.dtype,
        ),
-        axis1=0,
-        axis2=-1,
+        -1,
+        0,
    )
 
     # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
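The last hunk replaces pt.swapaxes(..., axis1=0, axis2=-1) with pt.moveaxis(..., -1, 0). A small sketch of the difference, using NumPy for illustration (the diff itself uses the analogous PyTensor functions); the two calls only disagree once the marginalized variable has more than one batch dimension:

    import numpy as np

    # Shape (*rv_shape, domain_size) with a two-dimensional batch shape (2, 3)
    # and a domain of size 4, as built by pt.full above.
    full = np.empty((2, 3, 4))

    # swapaxes(0, -1) exchanges the first and last axes, so the batch
    # dimensions come out reordered -> (4, 3, 2).
    print(np.swapaxes(full, 0, -1).shape)

    # moveaxis(-1, 0) moves the domain axis to the front and keeps the batch
    # dimensions in their original order -> (4, 2, 3).
    print(np.moveaxis(full, -1, 0).shape)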

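As an aside on how the new recover_marginals method above works: for each posterior draw of the remaining parameters, the marginalized value is replaced (via vectorize_graph) by a tensor enumerating its whole finite domain, the joint log-probability is evaluated at every domain value, log_softmax normalizes over that axis to give the posterior log-probabilities stored as lp_{varname}, and samples are drawn from a Categorical with those logits. A minimal NumPy sketch of the normalization and sampling step, with illustrative numbers only:

    import numpy as np

    # Joint log-probabilities for one posterior draw, evaluated at each value of
    # the marginalized variable's domain (a Bernoulli here, domain {0, 1}).
    joint_logps = np.array([-3.2, -0.4])

    # Log-softmax over the domain axis gives normalized posterior
    # log-probabilities, analogous to log_softmax in the graph above.
    lp = joint_logps - np.logaddexp.reduce(joint_logps)

    # Drawing the variable from those probabilities recovers a sample of x,
    # analogous to pymc.Categorical.dist(logit_p=joint_logps).
    rng = np.random.default_rng(0)
    x_draw = rng.choice([0, 1], p=np.exp(lp))
    print(lp, x_draw)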
0 commit comments
