
Commit d965959

Reorganize marginalize codebase
1 parent c8a4d85 commit d965959

10 files changed: +1117 -1068 lines changed

pymc_experimental/__init__.py
+1 -1

@@ -16,7 +16,7 @@
 from pymc_experimental import gp, statespace, utils
 from pymc_experimental.distributions import *
 from pymc_experimental.inference.fit import fit
-from pymc_experimental.model.marginal_model import MarginalModel
+from pymc_experimental.model.marginal.marginal_model import MarginalModel
 from pymc_experimental.model.model_api import as_model
 from pymc_experimental.version import __version__
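
The top-level import is unaffected by the reorganization; only the internal module path changes. A minimal sketch of both import styles after this commit:

# public entry point, unchanged by this commit
from pymc_experimental import MarginalModel

# new internal location introduced by this commit
from pymc_experimental.model.marginal.marginal_model import MarginalModel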

pymc_experimental/model/marginal/__init__.py

Whitespace-only changes.
@@ -0,0 +1,191 @@

from collections.abc import Sequence

import numpy as np

from pymc import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable, logp
from pymc.logprob import conditional_logp
from pymc.logprob.abstract import _logprob
from pymc.pytensorf import constant_fold
from pytensor import Mode, clone_replace, graph_replace, scan
from pytensor import map as scan_map
from pytensor import tensor as pt
from pytensor.graph import vectorize_graph
from pytensor.tensor import TensorType, TensorVariable

from pymc_experimental.distributions import DiscreteMarkovChain


class MarginalRV(SymbolicRandomVariable):
    """Base class for Marginalized RVs"""


class FiniteDiscreteMarginalRV(MarginalRV):
    """Base class for Finite Discrete Marginalized RVs"""


class DiscreteMarginalMarkovChainRV(MarginalRV):
    """Base class for Discrete Marginal Markov Chain RVs"""


def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
    op = rv.owner.op
    dist_params = rv.owner.op.dist_params(rv.owner)
    if isinstance(op, Bernoulli):
        return (0, 1)
    elif isinstance(op, Categorical):
        [p_param] = dist_params
        return tuple(range(pt.get_vector_length(p_param)))
    elif isinstance(op, DiscreteUniform):
        lower, upper = constant_fold(dist_params)
        return tuple(np.arange(lower, upper + 1))
    elif isinstance(op, DiscreteMarkovChain):
        P, *_ = dist_params
        return tuple(range(pt.get_vector_length(P[-1])))

    raise NotImplementedError(f"Cannot compute domain for op {op}")


def _add_reduce_batch_dependent_logps(
    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
):
    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""
    mbcast = marginalized_type.broadcastable
    reduced_logps = []
    for dependent_logp in dependent_logps:
        dbcast = dependent_logp.type.broadcastable
        dim_diff = len(dbcast) - len(mbcast)
        mbcast_aligned = (True,) * dim_diff + mbcast
        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
        reduced_logps.append(dependent_logp.sum(vbcast_axis))
    return pt.add(*reduced_logps)


@_logprob.register(FiniteDiscreteMarginalRV)
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rvs_node = op.make_node(*inputs)
    marginalized_rv, *inner_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    # Obtain the joint_logp graph of the inner RV graph
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)

    # Reduce logp dimensions corresponding to broadcasted variables
    marginalized_logp = logps_dict.pop(marginalized_vv)
    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
        marginalized_rv.type, logps_dict.values()
    )

    # Compute the joint_logp for all possible n values of the marginalized RV. We assume
    # each original dimension is independent so that it suffices to evaluate the graph
    # n times, once with each possible value of the marginalized RV replicated across
    # batched dimensions of the marginalized RV

    # PyMC does not allow RVs in the logp graph, even if we are just using the shape
    marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
    marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
    marginalized_rv_domain_tensor = pt.moveaxis(
        pt.full(
            (*marginalized_rv_shape, len(marginalized_rv_domain)),
            marginalized_rv_domain,
            dtype=marginalized_rv.dtype,
        ),
        -1,
        0,
    )

    try:
        joint_logps = vectorize_graph(
            joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
        )
    except Exception:
        # Fallback to Scan
        def logp_fn(marginalized_rv_const, *non_sequences):
            return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

        joint_logps, _ = scan_map(
            fn=logp_fn,
            sequences=marginalized_rv_domain_tensor,
            non_sequences=[*values, *inputs],
            mode=Mode().including("local_remove_check_parameter"),
        )

    joint_logps = pt.logsumexp(joint_logps, axis=0)

    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
    return joint_logps, *(pt.constant(0),) * (len(values) - 1)

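For reference, the enumeration above is just p(y) = sum_x p(x) p(y | x), evaluated in log space with a logsumexp over the domain of the marginalized variable. A minimal NumPy/SciPy sketch for a single marginalized Bernoulli (illustrative only, not part of this commit; the variable names are made up):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import bernoulli, norm

# x ~ Bernoulli(p) is marginalized out of y ~ Normal(mu[x], sigma)
p, mu, sigma, y_obs = 0.3, np.array([0.0, 2.0]), 1.0, 1.5
# One joint logp per value in the domain {0, 1}, then logsumexp, as in the graph above
joint_logps = [bernoulli.logpmf(x, p) + norm.logpdf(y_obs, mu[x], sigma) for x in (0, 1)]
marginal_logp = logsumexp(joint_logps)
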
@_logprob.register(DiscreteMarginalMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):
    marginalized_rvs_node = op.make_node(*inputs)
    inner_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    chain_rv, *dependent_rvs = inner_rvs
    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
    domain = pt.arange(P.shape[-1], dtype="int32")

    # Construct logp in two steps
    # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)

    # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
    # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise,
    # PyMC will detect a random variable in the logp graph (init_dist_) that isn't relevant at this step.
    chain_value = chain_rv.clone()
    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

    # Reduce and add the batch dims beyond the chain dimension
    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
        chain_rv.type, logp_emissions_dict.values()
    )

    # Add a batch dimension for the domain of the chain
    chain_shape = constant_fold(tuple(chain_rv.shape))
    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
    batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})

    # Step 2: Compute the transition probabilities
    # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
    # We do it entirely in logs, though.

    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
    # under the initial distribution. This is robust to everything the user can throw at it.
    init_dist_value = init_dist_.type()
    logp_init_dist = logp(init_dist_, init_dist_value)
    # There is a degenerate batch dim for lags=1 (the only supported case),
    # which we have to work around by expanding the batch value and then squeezing it out of the logp
    batch_logp_init_dist = vectorize_graph(
        logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]}
    ).squeeze(1)
    log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]

    def step_alpha(logp_emission, log_alpha, log_P):
        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
        return logp_emission + step_log_prob

    P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
    log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
    log_alpha_seq, _ = scan(
        step_alpha,
        non_sequences=[log_P],
        outputs_info=[log_alpha_init],
        # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
        sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
    )
    # Final logp is just the sum of the last scan state
    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
    # return is the joint probability of everything together, but PyMC still expects one logp for each one.
    dummy_logps = (pt.constant(0),) * (len(values) - 1)
    return joint_logp, *dummy_logps
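
The scan above implements the standard log-space forward recursion. A compact NumPy sketch of the same recursion (illustrative only, not part of this commit; it ignores the batch handling done in the graph version):

import numpy as np
from scipy.special import logsumexp

def hmm_forward_logp(log_P, log_init, log_emission):
    # log_P[i, j] = log p(s_t = j | s_{t-1} = i)
    # log_init[i] = log p(s_0 = i); log_emission[t, i] = log p(y_t | s_t = i)
    log_alpha = log_init + log_emission[0]
    for t in range(1, log_emission.shape[0]):
        log_alpha = log_emission[t] + logsumexp(log_alpha[:, None] + log_P, axis=0)
    return logsumexp(log_alpha)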
@@ -0,0 +1,92 @@

from pytensor.compile import SharedVariable
from pytensor.graph import Constant, FunctionGraph, ancestors
from pytensor.tensor import TensorVariable
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.shape import Shape


def static_shape_ancestors(vars):
    """Identify ancestor Shape Ops of static shapes (therefore constant in a valid graph)."""
    return [
        var
        for var in ancestors(vars)
        if (
            var.owner
            and isinstance(var.owner.op, Shape)
            # All static dim lengths of the Shape input are known
            and None not in var.owner.inputs[0].type.shape
        )
    ]


def find_conditional_input_rvs(output_rvs, all_rvs):
    """Find conditionally independent input RVs."""
    blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
    blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
    return [
        var
        for var in ancestors(output_rvs, blockers=blockers)
        if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
    ]


def is_conditional_dependent(
    dependent_rv: TensorVariable, dependable_rv: TensorVariable, all_rvs
) -> bool:
    """Check if dependent_rv is conditionally dependent on dependable_rv,
    given all conditionally independent all_rvs"""

    return dependable_rv in find_conditional_input_rvs((dependent_rv,), all_rvs)


def find_conditional_dependent_rvs(dependable_rv, all_rvs):
    """Find RVs that depend on dependable_rv"""
    return [
        rv
        for rv in all_rvs
        if (rv is not dependable_rv and is_conditional_dependent(rv, dependable_rv, all_rvs))
    ]


def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
    # TODO: No need to consider apply nodes outside the subgraph...
    fg = FunctionGraph(outputs=output_rvs, clone=False)

    non_elemwise_blockers = [
        o
        for node in fg.apply_nodes
        if not (
            isinstance(node.op, Elemwise)
            # Allow expand_dims on the left
            or (
                isinstance(node.op, DimShuffle)
                and not node.op.drop
                and node.op.shuffle == sorted(node.op.shuffle)
            )
        )
        for o in node.outputs
    ]
    blocker_candidates = [rv_to_marginalize, *other_input_rvs, *non_elemwise_blockers]
    blockers = [var for var in blocker_candidates if var not in output_rvs]

    truncated_inputs = [
        var
        for var in ancestors(output_rvs, blockers=blockers)
        if (
            var in blockers
            or (var.owner is None and not isinstance(var, Constant | SharedVariable))
        )
    ]

    # Check that we reach the marginalized rv following a pure elemwise graph
    if rv_to_marginalize not in truncated_inputs:
        return False

    # Check that none of the truncated inputs depends on the marginalized_rv
    other_truncated_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize]
    # TODO: We don't need to go all the way to the root variables
    if rv_to_marginalize in ancestors(
        other_truncated_inputs, blockers=[rv_to_marginalize, *other_input_rvs]
    ):
        return False
    return True
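
For intuition, is_elemwise_subgraph only accepts dependent graphs that reach the marginalized variable through Elemwise Ops (plus left-side expand_dims), which is what keeps the enumeration independent across batch dimensions. A hypothetical PyTensor illustration (the variable names and the cumsum counter-example are assumptions for illustration, not part of this commit):

import pytensor.tensor as pt

x = pt.vector("x")        # stand-in for the variable to marginalize
eps = pt.vector("eps")    # stand-in for another input
elemwise_out = pt.exp(x) + eps    # only Elemwise ops: the check should pass
mixing_out = pt.cumsum(x) + eps   # cumsum couples entries along the batch dim: the check should fail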
