
Commit 971ed63

Support more kinds of marginalization via dim analysis
This commit lifts the restriction that only Elemwise operations may link marginalized RVs to dependent RVs. We map input dims to output dims to assess whether an operation mixes information across dims. Graphs where information is not mixed can be efficiently marginalized.
1 parent a78d378 commit 971ed63
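
For intuition, here is a minimal sketch (not part of this commit; it assumes the MarginalModel API from pymc_experimental, and shapes and parameters are illustrative) of a graph that this change makes marginalizable. The transpose and broadcast between idx and dep only map dims around without mixing information across them, so the dim analysis accepts them, whereas the previous Elemwise-only check would have refused to marginalize idx:

import pymc as pm
from pymc_experimental import MarginalModel

with MarginalModel() as m:
    idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
    # The DimShuffle (transpose + new axis) reorders dims but mixes no
    # information across them, so idx can still be marginalized efficiently.
    pm.Normal("dep", mu=idx.T[..., None] * 2.0, sigma=1.0, shape=(2, 3, 5))

m.marginalize(["idx"])  # idx is summed out of the logp over its domain {0, 1}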

File tree

6 files changed: +917 -211 lines changed


pymc_experimental/model/marginal/distributions.py

+126-45
@@ -1,30 +1,55 @@
 from collections.abc import Sequence
 
 import numpy as np
+import pytensor.tensor as pt
 
-from pymc import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable, logp
-from pymc.logprob import conditional_logp
-from pymc.logprob.abstract import _logprob
+from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+from pymc.logprob.abstract import MeasurableOp, _logprob
+from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
-from pytensor import Mode, clone_replace, graph_replace, scan
-from pytensor import map as scan_map
-from pytensor import tensor as pt
-from pytensor.graph import vectorize_graph
-from pytensor.tensor import TensorType, TensorVariable
+from pytensor import Variable
+from pytensor.compile.builders import OpFromGraph
+from pytensor.compile.mode import Mode
+from pytensor.graph import Op, vectorize_graph
+from pytensor.graph.replace import clone_replace, graph_replace
+from pytensor.scan import map as scan_map
+from pytensor.scan import scan
+from pytensor.tensor import TensorVariable
 
 from pymc_experimental.distributions import DiscreteMarkovChain
 
 
-class MarginalRV(SymbolicRandomVariable):
+class MarginalRV(OpFromGraph, MeasurableOp):
     """Base class for Marginalized RVs"""
 
+    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
+        self.dims_connections = dims_connections
+        super().__init__(*args, **kwargs)
 
-class FiniteDiscreteMarginalRV(MarginalRV):
-    """Base class for Finite Discrete Marginalized RVs"""
+    @property
+    def support_axes(self) -> tuple[tuple[int]]:
+        """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
+        marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
+        support_axes_vars = []
+        for dims_connection in self.dims_connections:
+            ndim = len(dims_connection)
+            marginalized_supp_axes = ndim - marginalized_ndim_supp
+            support_axes_vars.append(
+                tuple(
+                    -i
+                    for i, dim in enumerate(reversed(dims_connection), start=1)
+                    if (dim is None or dim > marginalized_supp_axes)
+                )
+            )
+        return tuple(support_axes_vars)
 
 
-class DiscreteMarginalMarkovChainRV(MarginalRV):
-    """Base class for Discrete Marginal Markov Chain RVs"""
+class MarginalFiniteDiscreteRV(MarginalRV):
+    """Base class for Marginalized Finite Discrete RVs"""
+
+
+class MarginalDiscreteMarkovChainRV(MarginalRV):
+    """Base class for Marginalized Discrete Markov Chain RVs"""
 
 
 def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
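
For intuition, the support_axes property added above is plain index bookkeeping over dims_connections; a small sketch with illustrative values (not part of the commit):

dims_connection = (1, 0, None)  # entry k: which marginalized dim feeds output axis k (None = unrelated)
marginalized_ndim_supp = 0      # e.g. the marginalized RV is univariate (scalar support)

ndim = len(dims_connection)
marginalized_supp_axes = ndim - marginalized_ndim_supp
support_axes = tuple(
    -i
    for i, dim in enumerate(reversed(dims_connection), start=1)
    if (dim is None or dim > marginalized_supp_axes)
)
print(support_axes)  # (-1,): only the axis whose dims_connection entry is None is selected here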
@@ -34,7 +59,8 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
         return (0, 1)
     elif isinstance(op, Categorical):
         [p_param] = dist_params
-        return tuple(range(pt.get_vector_length(p_param)))
+        [p_param_length] = constant_fold([p_param.shape[-1]])
+        return tuple(range(p_param_length))
     elif isinstance(op, DiscreteUniform):
         lower, upper = constant_fold(dist_params)
         return tuple(np.arange(lower, upper + 1))
@@ -45,31 +71,77 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
     raise NotImplementedError(f"Cannot compute domain for op {op}")
 
 
-def _add_reduce_batch_dependent_logps(
-    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
-):
-    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""
+def reduce_batch_dependent_logps(
+    dependent_dims_connections: Sequence[tuple[int | None, ...]],
+    dependent_ops: Sequence[Op],
+    dependent_logps: Sequence[TensorVariable],
+) -> TensorVariable:
+    """Combine the logps of dependent RVs and align them with the marginalized logp.
+
+    This requires reducing extra batch dims and transposing when they are not aligned.
+
+        idx = pm.Bernoulli(idx, shape=(3, 2))  # 0, 1
+        pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(3, 2, 5))
+        pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))
+
+        marginalize(idx)
+
+    The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)]
+    which tells us we need to reduce the last axis of dep1 logp and the first of dep2 logp,
+    as well as transpose the remaining axis of dep1 logp before adding the two element-wise.
+
+    """
+    from pymc_experimental.model.marginal.graph_analysis import get_support_axes
 
-    mbcast = marginalized_type.broadcastable
     reduced_logps = []
-    for dependent_logp in dependent_logps:
-        dbcast = dependent_logp.type.broadcastable
-        dim_diff = len(dbcast) - len(mbcast)
-        mbcast_aligned = (True,) * dim_diff + mbcast
-        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
-        reduced_logps.append(dependent_logp.sum(vbcast_axis))
-    return pt.add(*reduced_logps)
+    for dependent_op, dependent_logp, dependent_dims_connection in zip(
+        dependent_ops, dependent_logps, dependent_dims_connections
+    ):
+        if dependent_logp.type.ndim > 0:
+            # Find which support axis implied by the MarginalRV need to be reduced
+            # Some may have already been reduced by the logp expression of the dependent RV (e.g., multivariate RVs)
+            dep_supp_axes = get_support_axes(dependent_op)[0]
 
+            # Dependent RV support axes are already collapsed in the logp, so we ignore them
+            supp_axes = [
+                -i
+                for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
+                if (dim is None and -i not in dep_supp_axes)
+            ]
+            dependent_logp = dependent_logp.sum(supp_axes)
 
-@_logprob.register(FiniteDiscreteMarginalRV)
-def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
-    # Clone the inner RV graph of the Marginalized RV
+            # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
+            dims_alignment = [dim for dim in dependent_dims_connection if dim is not None]
+            dependent_logp = dependent_logp.transpose(*dims_alignment)
+
+        reduced_logps.append(dependent_logp)
+
+    reduced_logp = pt.add(*reduced_logps)
+    return reduced_logp
+
+
+def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable:
+    """Align the logp with the order specified in dims."""
+    dims_alignment = [dim for dim in dims if dim is not None]
+    return logp.transpose(*dims_alignment)
+
+
+def recreate_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
     marginalized_rvs_node = op.make_node(*inputs)
-    marginalized_rv, *inner_rvs = clone_replace(
+    return clone_replace(
         op.inner_outputs,
         replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
     )
 
+
+DUMMY_ZERO = pt.constant(0, name="dummy_zero")
+
+
+@_logprob.register(MarginalFiniteDiscreteRV)
+def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
+    # Clone the inner RV graph of the Marginalized RV
+    marginalized_rv, *inner_rvs = recreate_ofg_outputs(op, inputs)
+
     # Obtain the joint_logp graph of the inner RV graph
     inner_rv_values = dict(zip(inner_rvs, values))
     marginalized_vv = marginalized_rv.clone()
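
For intuition, a plain-numpy sketch of the reduce-then-align bookkeeping the docstring above describes, with self-consistent illustrative shapes (not part of the commit):

import numpy as np

# The marginalized idx has shape (3, 2): dim 0 has size 3, dim 1 has size 2.
dep1_logp = np.zeros((2, 3, 5))  # dims_connection (1, 0, None)
dep2_logp = np.zeros((7, 3, 2))  # dims_connection (None, 0, 1)

# Sum over the axes whose dims_connection entry is None ...
dep1_logp = dep1_logp.sum(axis=-1)     # -> (2, 3)
dep2_logp = dep2_logp.sum(axis=0)      # -> (3, 2)

# ... and transpose the surviving axes into the marginalized dim order (0, 1).
dep1_logp = dep1_logp.transpose(1, 0)  # -> (3, 2)

# Both now add element-wise with the marginalized logp of shape (3, 2).
print((dep1_logp + dep2_logp).shape)   # (3, 2)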
@@ -78,8 +150,10 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
 
     # Reduce logp dimensions corresponding to broadcasted variables
     marginalized_logp = logps_dict.pop(marginalized_vv)
-    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
-        marginalized_rv.type, logps_dict.values()
+    joint_logp = marginalized_logp + reduce_batch_dependent_logps(
+        dependent_dims_connections=op.dims_connections,
+        dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
+        dependent_logps=[logps_dict[value] for value in values],
     )
 
     # Compute the joint_logp for all possible n values of the marginalized RV. We assume
@@ -116,21 +190,20 @@ def logp_fn(marginalized_rv_const, *non_sequences):
         mode=Mode().including("local_remove_check_parameter"),
     )
 
-    joint_logps = pt.logsumexp(joint_logps, axis=0)
+    joint_logp = pt.logsumexp(joint_logps, axis=0)
+
+    # Align logp with non-collapsed batch dimensions of first RV
+    joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)
 
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
-    return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+    dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
+    return joint_logp, *dummy_logps
 
 
-@_logprob.register(DiscreteMarginalMarkovChainRV)
+@_logprob.register(MarginalDiscreteMarkovChainRV)
 def marginal_hmm_logp(op, values, *inputs, **kwargs):
-    marginalized_rvs_node = op.make_node(*inputs)
-    inner_rvs = clone_replace(
-        op.inner_outputs,
-        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
-    )
+    chain_rv, *dependent_rvs = recreate_ofg_outputs(op, inputs)
 
-    chain_rv, *dependent_rvs = inner_rvs
     P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
     domain = pt.arange(P.shape[-1], dtype="int32")
 
@@ -145,8 +218,10 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
     logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))
 
     # Reduce and add the batch dims beyond the chain dimension
-    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
-        chain_rv.type, logp_emissions_dict.values()
+    reduced_logp_emissions = reduce_batch_dependent_logps(
+        dependent_dims_connections=op.dims_connections,
+        dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs],
+        dependent_logps=[logp_emissions_dict[value] for value in values],
     )
 
     # Add a batch dimension for the domain of the chain
@@ -185,7 +260,13 @@ def step_alpha(logp_emission, log_alpha, log_P):
     # Final logp is just the sum of the last scan state
     joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
 
+    # Align logp with non-collapsed batch dimensions of first RV
+    remaining_dims_first_emission = list(op.dims_connections[0])
+    # The last dim of chain_rv was removed when computing the logp
+    remaining_dims_first_emission.remove(chain_rv.type.ndim - 1)
+    joint_logp = align_logp_dims(remaining_dims_first_emission, joint_logp)
+
     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
-    # return is the joint probability of everything together, but PyMC still expects one logp for each one.
-    dummy_logps = (pt.constant(0),) * (len(values) - 1)
+    # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+    dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
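
For reference, the core of finite_discrete_marginal_rv_logp is the same logsumexp reduction one would write by hand: evaluate the joint logp at every value of the finite domain, then reduce. A scalar scipy sketch with illustrative parameters (not part of the commit):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import bernoulli, norm

p = 0.3  # P(idx = 1), the marginalized Bernoulli
y = 1.8  # observed value of a dependent Normal with mu = 2 * idx

# Evaluate the joint logp at every point of the finite domain of idx ...
joint_logps = np.array(
    [bernoulli.logpmf(k, p) + norm.logpdf(y, loc=2 * k, scale=1.0) for k in (0, 1)]
)
# ... and reduce with logsumexp to obtain log p(y) with idx summed out.
marginal_logp = logsumexp(joint_logps, axis=0)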
