Commit 378dbe4

.WIP
1 parent 2ec0647 commit 378dbe4

File tree

9 files changed: +696 -665 lines

pymc_experimental/__init__.py

+1-1
@@ -15,7 +15,7 @@

 from pymc_experimental import distributions, gp, statespace, utils
 from pymc_experimental.inference.fit import fit
-from pymc_experimental.model.marginal_model import MarginalModel
+from pymc_experimental.model.marginal.marginal_model import MarginalModel
 from pymc_experimental.model.model_api import as_model
 from pymc_experimental.version import __version__
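
Only the implementation module moves into the new marginal subpackage; the top-level re-export is unchanged, so existing user code keeps working, e.g. (illustrative line, not part of the diff):

from pymc_experimental import MarginalModel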

pymc_experimental/model/marginal/__init__.py

Whitespace-only changes.

@@ -0,0 +1,196 @@
from typing import Sequence

import numpy as np
import pytensor.tensor as pt
from pymc.distributions import (
    Bernoulli,
    Categorical,
    DiscreteUniform,
    SymbolicRandomVariable,
)
from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.abstract import _logprob
from pymc.pytensorf import constant_fold
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.scan import scan, map as scan_map
from pytensor.compile.mode import Mode
from pytensor.graph import vectorize_graph
from pytensor.tensor import TensorVariable, TensorType

from pymc_experimental.distributions import DiscreteMarkovChain


class MarginalRV(SymbolicRandomVariable):
    """Base class for Marginalized RVs"""


class FiniteDiscreteMarginalRV(MarginalRV):
    """Base class for Finite Discrete Marginalized RVs"""


class DiscreteMarginalMarkovChainRV(MarginalRV):
    """Base class for Discrete Marginal Markov Chain RVs"""

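# Illustrative note (not part of the committed file): SymbolicRandomVariable is
# an OpFromGraph-like Op, so a MarginalRV instance carries the marginalized
# subgraph as `op.inner_inputs` / `op.inner_outputs`, which the logp
# implementations below clone and re-wire onto the outer node's inputs.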
def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
    op = rv.owner.op
    dist_params = rv.owner.op.dist_params(rv.owner)
    if isinstance(op, Bernoulli):
        return (0, 1)
    elif isinstance(op, Categorical):
        [p_param] = dist_params
        return tuple(range(pt.get_vector_length(p_param)))
    elif isinstance(op, DiscreteUniform):
        lower, upper = constant_fold(dist_params)
        return tuple(np.arange(lower, upper + 1))
    elif isinstance(op, DiscreteMarkovChain):
        P, *_ = dist_params
        return tuple(range(pt.get_vector_length(P[-1])))

    raise NotImplementedError(f"Cannot compute domain for op {op}")

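# Illustration (not part of the committed file): a Bernoulli has domain (0, 1);
# a Categorical with a length-3 probability vector has domain (0, 1, 2); a
# DiscreteUniform with lower=1, upper=4 has domain (1, 2, 3, 4).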

def _add_reduce_batch_dependent_logps(
    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
):
    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""

    mbcast = marginalized_type.broadcastable
    reduced_logps = []
    for dependent_logp in dependent_logps:
        dbcast = dependent_logp.type.broadcastable
        dim_diff = len(dbcast) - len(mbcast)
        mbcast_aligned = (True,) * dim_diff + mbcast
        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
        reduced_logps.append(dependent_logp.sum(vbcast_axis))
    return pt.add(*reduced_logps)

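# Illustration (not part of the committed file): if the marginalized RV has
# shape (3,) and a dependent logp has shape (5, 3), the helper above sums away
# the extra leading axis, returning a term of shape (3,) that is aligned with
# the marginalized RV's batch dimensions.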

@_logprob.register(FiniteDiscreteMarginalRV)
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rvs_node = op.make_node(*inputs)
    marginalized_rv, *inner_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    # Obtain the joint_logp graph of the inner RV graph
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
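    # Note (not part of the committed file): conditional_logp returns a dict
    # mapping each value variable to its logp graph, which is why the term of
    # the marginalized value variable can be popped out separately below.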

    # Reduce logp dimensions corresponding to broadcasted variables
    marginalized_logp = logps_dict.pop(marginalized_vv)
    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
        marginalized_rv.type, logps_dict.values()
    )

    # Compute the joint_logp for all possible n values of the marginalized RV. We assume
    # each original dimension is independent so that it suffices to evaluate the graph
    # n times, once with each possible value of the marginalized RV replicated across
    # batched dimensions of the marginalized RV
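    # Illustration (not part of the committed file): the computation below
    # implements logp(y) = logsumexp_k [ logp(x = k) + logp(y | x = k) ]
    # over the finite domain k of the marginalized variable x.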

    # PyMC does not allow RVs in the logp graph, even if we are just using the shape
    marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
    marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
    marginalized_rv_domain_tensor = pt.moveaxis(
        pt.full(
            (*marginalized_rv_shape, len(marginalized_rv_domain)),
            marginalized_rv_domain,
            dtype=marginalized_rv.dtype,
        ),
        -1,
        0,
    )

    try:
        joint_logps = vectorize_graph(
            joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
        )
    except Exception:
        # Fallback to Scan
        def logp_fn(marginalized_rv_const, *non_sequences):
            return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

        joint_logps, _ = scan_map(
            fn=logp_fn,
            sequences=marginalized_rv_domain_tensor,
            non_sequences=[*values, *inputs],
            mode=Mode().including("local_remove_check_parameter"),
        )

    joint_logps = pt.logsumexp(joint_logps, axis=0)

    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
    return joint_logps, *(pt.constant(0),) * (len(values) - 1)


@_logprob.register(DiscreteMarginalMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):
    marginalized_rvs_node = op.make_node(*inputs)
    inner_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    chain_rv, *dependent_rvs = inner_rvs
    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
    domain = pt.arange(P.shape[-1], dtype="int32")

    # Construct logp in two steps
    # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)

    # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
    # around. To do this, we need to break the dependency between the chain and the init_dist_ random variable.
    # Otherwise, PyMC will detect a random variable in the logp graph (init_dist_) that isn't relevant at this step.
    chain_value = chain_rv.clone()
    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

    # Reduce and add the batch dims beyond the chain dimension
    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
        chain_rv.type, logp_emissions_dict.values()
    )

    # Add a batch dimension for the domain of the chain
    chain_shape = constant_fold(tuple(chain_rv.shape))
    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
    batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})
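    # Note (not part of the committed file): batch_logp_emissions has a leading
    # axis of length domain.size, i.e. one emission-logp series per candidate
    # hidden state, with the time dimension kept last.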

    # Step 2: Compute the transition probabilities
    # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
    # We do it entirely in logs, though.
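    # Illustration (not part of the committed file): in log space the recursion is
    #   log_alpha_t(s) = logp(y_t | s) + logsumexp_{s'}(log_alpha_{t-1}(s') + log P[s', s])
    # which is what step_alpha below computes.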

    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
    # under the initial distribution. This is robust to everything the user can throw at it.
    init_dist_value = init_dist_.type()
    logp_init_dist = logp(init_dist_, init_dist_value)
    # There is a degenerate batch dim for lags=1 (the only supported case)
    # that we have to work around, by expanding the batch value and then squeezing it out of the logp
    batch_logp_init_dist = vectorize_graph(
        logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]}
    ).squeeze(1)
    log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]

    def step_alpha(logp_emission, log_alpha, log_P):
        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
        return logp_emission + step_log_prob

    P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
    log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
    log_alpha_seq, _ = scan(
        step_alpha,
        non_sequences=[log_P],
        outputs_info=[log_alpha_init],
        # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
        sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
    )
    # The final logp is the logsumexp of the last scan state
    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
    # return is the joint probability of everything together, but PyMC still expects one logp for each one.
    dummy_logps = (pt.constant(0),) * (len(values) - 1)
    return joint_logp, *dummy_logps
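
For orientation, a minimal sketch of how these logp registrations are typically reached through MarginalModel (the model and the marginalize call below are illustrative assumptions based on the existing pymc_experimental API, not part of this commit):

import numpy as np
import pymc as pm

from pymc_experimental import MarginalModel

with MarginalModel() as m:
    # Finite discrete RV that will be marginalized out
    idx = pm.Bernoulli("idx", p=0.7)
    mu = pm.math.switch(idx, 1.0, -1.0)
    y = pm.Normal("y", mu=mu, sigma=1.0, observed=np.random.normal(size=100))

# idx is rewritten into a FiniteDiscreteMarginalRV, whose logp is provided by
# finite_discrete_marginal_rv_logp above
m.marginalize([idx])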
