Commit 958ada4
Marginalize DiscreteMarkovChain
Co-authored-by: Jesse Grabowski <[email protected]>
1 parent c6cd151 commit 958ada4

File tree: 2 files changed (+185 −15)
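For orientation, here is a minimal usage sketch assembled from the tests added in this commit (see test_marginalized_hmm_categorical_emission below); the asserted value log(0.1344) is the one checked in that test, and the import paths match this repository.

    import numpy as np
    import pymc as pm
    import pytensor.tensor as pt
    from pymc_experimental.distributions import DiscreteMarkovChain
    from pymc_experimental.model.marginal_model import MarginalModel

    with MarginalModel() as m:
        P = np.array([[0.5, 0.5], [0.3, 0.7]])             # p(s_t | s_{t-1})
        init_dist = pm.Categorical.dist(p=[0.375, 0.625])  # p(s_0)
        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
        # Bernoulli emissions whose rate depends on the hidden state
        emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))

    # The discrete chain is summed out of the logp via the forward algorithm
    m.marginalize([chain])
    logp_fn = m.compile_logp()
    print(logp_fn({"emission": np.array([0, 0, 1])}))  # log(0.1344)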

pymc_experimental/model/marginal_model.py (+99 −14)
@@ -10,21 +10,16 @@
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
-from pymc.logprob.basic import conditional_logp
+from pymc.logprob.basic import conditional_logp, logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
 from pymc.pytensorf import compile_pymc, constant_fold, inputvars
 from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
-from pytensor import Mode
+from pytensor import Mode, scan
 from pytensor.compile import SharedVariable
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import (
-    Constant,
-    FunctionGraph,
-    ancestors,
-    clone_replace,
-    vectorize_graph,
-)
+from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+from pytensor.graph.replace import vectorize_graph
 from pytensor.scan import map as scan_map
 from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.elemwise import Elemwise
@@ -33,6 +28,8 @@
 
 __all__ = ["MarginalModel"]
 
+from pymc_experimental.distributions import DiscreteMarkovChain
+
 
 class MarginalModel(Model):
     """Subclass of PyMC Model that implements functionality for automatic
@@ -245,16 +242,25 @@ def marginalize(
             self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
         ]
 
-        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
         for rv_to_marginalize in rvs_to_marginalize:
             if rv_to_marginalize not in self.free_RVs:
                 raise ValueError(
                     f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
                 )
-            if not isinstance(rv_to_marginalize.owner.op, supported_dists):
+
+            rv_op = rv_to_marginalize.owner.op
+            if isinstance(rv_op, DiscreteMarkovChain):
+                if rv_op.n_lags > 1:
+                    raise NotImplementedError(
+                        "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
+                    )
+                if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
+                    raise NotImplementedError(
+                        "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
+                    )
+            elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
                 raise NotImplementedError(
-                    f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
-                    f"Supported distribution include {supported_dists}"
+                    f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
                 )
 
             if rv_to_marginalize.name in self.named_vars_to_dims:
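Only Bernoulli, Categorical, DiscreteUniform, and (now) single-lag DiscreteMarkovChain variables can be marginalized; anything else hits the consolidated error message above. A small sketch of the rejection path (the model variables are illustrative):

    import pymc as pm
    from pymc_experimental.model.marginal_model import MarginalModel

    with MarginalModel() as m:
        x = pm.Poisson("x", mu=3)          # infinite discrete support, not marginalizable
        y = pm.Normal("y", mu=x, sigma=1)

    try:
        m.marginalize([x])
    except NotImplementedError as err:
        print(err)  # Marginalization of RV with distribution ... is not supported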
@@ -490,6 +496,10 @@ class FiniteDiscreteMarginalRV(MarginalRV):
     """Base class for Finite Discrete Marginalized RVs"""
 
 
+class DiscreteMarginalMarkovChainRV(MarginalRV):
+    """Base class for Discrete Marginal Markov Chain RVs"""
+
+
 def static_shape_ancestors(vars):
     """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
     return [
@@ -618,11 +628,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
     cloned_outputs = clone_replace(outputs, replace=replace_inputs)
 
-    marginalization_op = FiniteDiscreteMarginalRV(
+    if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
+        marginalize_constructor = DiscreteMarginalMarkovChainRV
+    else:
+        marginalize_constructor = FiniteDiscreteMarginalRV
+
+    marginalization_op = marginalize_constructor(
         inputs=list(replace_inputs.values()),
         outputs=cloned_outputs,
         ndim_supp=ndim_supp,
     )
+
     marginalized_rvs = marginalization_op(*replace_inputs.keys())
     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
     return rvs_to_marginalize, marginalized_rvs
@@ -638,6 +654,9 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
     elif isinstance(op, DiscreteUniform):
         lower, upper = constant_fold(rv.owner.inputs[3:])
         return tuple(range(lower, upper + 1))
+    elif isinstance(op, DiscreteMarkovChain):
+        p = rv.owner.inputs[0]
+        return tuple(range(pt.get_vector_length(p[-1])))
 
     raise NotImplementedError(f"Cannot compute domain for op {op}")
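The new DiscreteMarkovChain branch infers the state space from the transition matrix: the number of states is the static length of the last row of P. A quick illustration (the matrix values are made up):

    import pytensor.tensor as pt

    P = pt.as_tensor([[0.9, 0.1], [0.2, 0.8]])  # 2x2 transition matrix
    n_states = pt.get_vector_length(P[-1])      # static length of the last row: 2
    print(tuple(range(n_states)))               # (0, 1), the domain of the chain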

@@ -728,3 +747,69 @@ def logp_fn(marginalized_rv_const, *non_sequences):
 
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
     return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+
+
+@_logprob.register(DiscreteMarginalMarkovChainRV)
+def marginal_hmm_logp(op, values, *inputs, **kwargs):
+
+    marginalized_rvs_node = op.make_node(*inputs)
+    inner_rvs = clone_replace(
+        op.inner_outputs,
+        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+    )
+
+    chain_rv, *dependent_rvs = inner_rvs
+    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
+    domain = pt.arange(P.shape[-1], dtype="int32")
+
+    # Construct logp in two steps
+    # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)
+
+    # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
+    # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise,
+    # PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step.
+    chain_value = chain_rv.clone()
+    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
+    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))
+
+    # Reduce and add the batch dims beyond the chain dimension
+    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
+        chain_rv.type, logp_emissions_dict.values()
+    )
+
+    # Add a batch dimension for the domain of the chain
+    chain_shape = constant_fold(tuple(chain_rv.shape))
+    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
+    batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})
+
+    # Step 2: Compute the transition probabilities
+    # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
+    # We do it entirely in logs, though.
+
+    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
+    # the initial distribution. This is robust to everything the user can throw at it.
+    batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")(
+        batch_chain_value[..., 0]
+    )
+    log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]
+
+    def step_alpha(logp_emission, log_alpha, log_P):
+        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
+        return logp_emission + step_log_prob
+
+    P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
+    log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
+    log_alpha_seq, _ = scan(
+        step_alpha,
+        non_sequences=[log_P],
+        outputs_info=[log_alpha_init],
+        # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
+        sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
+    )
+    # Final logp is just the sum of the last scan state
+    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
+
+    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
+    # return is the joint probability of everything together, but PyMC still expects one logp for each one.
+    dummy_logps = (pt.constant(0),) * (len(values) - 1)
+    return joint_logp, *dummy_logps
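For intuition, the recursion implemented above can be replayed in plain NumPy. The sketch below mirrors step_alpha in log space and reproduces the p = 0.1344 value asserted in test_marginalized_hmm_categorical_emission further down; it is an illustration of the forward algorithm, not the implementation itself.

    import numpy as np
    from scipy.special import logsumexp

    # 2 hidden states, Bernoulli emissions, observations [0, 0, 1]
    P = np.array([[0.5, 0.5], [0.3, 0.7]])  # p(s_t | s_{t-1})
    init = np.array([0.375, 0.625])         # p(s_0)
    emission_p = np.array([0.2, 0.6])       # p(y=1 | state)
    obs = np.array([0, 0, 1])

    log_P = np.log(P)
    # log p(y_t | s) for each observation (rows) and state (columns)
    logp_emission = np.log(np.where(obs[:, None] == 1, emission_p, 1 - emission_p))

    # Forward recursion in log space, mirroring step_alpha
    log_alpha = np.log(init) + logp_emission[0]
    for t in range(1, len(obs)):
        log_alpha = logp_emission[t] + logsumexp(log_alpha[:, None] + log_P, axis=0)

    print(np.exp(logsumexp(log_alpha)))  # 0.1344, the value asserted in the test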

pymc_experimental/tests/model/test_marginal_model.py (+86 −1)
@@ -14,6 +14,7 @@
 from scipy.special import log_softmax, logsumexp
 from scipy.stats import halfnorm, norm
 
+from pymc_experimental.distributions import DiscreteMarkovChain
 from pymc_experimental.model.marginal_model import (
     FiniteDiscreteMarginalRV,
     MarginalModel,
@@ -467,7 +468,7 @@ def test_not_supported_marginalized():
         y = pm.Dirichlet("y", a=pm.math.switch(x, [1, 1, 1], [10, 10, 10]))
         with pytest.raises(
             NotImplementedError,
-            match="Marginalization of withe dependent Multivariate RVs not implemented",
+            match="Marginalization with dependent Multivariate RVs not implemented",
         ):
             m.marginalize(x)
 
@@ -642,3 +643,87 @@ def dist(idx, size):
     ):
         pt = {"norm": test_value}
         np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
+
+
+@pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}")
+@pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}")
+def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
+    if batch_chain and not batch_emission:
+        pytest.skip("Redundant implicit combination")
+
+    with MarginalModel() as m:
+        P = [[0, 1], [1, 0]]
+        init_dist = pm.Categorical.dist(p=[1, 0])
+        chain = DiscreteMarkovChain(
+            "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None
+        )
+        emission = pm.Normal(
+            "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None
+        )
+
+    m.marginalize([chain])
+    logp_fn = m.compile_logp()
+
+    test_value = np.array([-1, 1, -1, 1])
+    expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
+    if batch_emission:
+        test_value = np.broadcast_to(test_value, (3, 4))
+        expected_logp *= 3
+    np.testing.assert_allclose(logp_fn({f"emission": test_value}), expected_logp)
+
+
+@pytest.mark.parametrize(
+    "categorical_emission",
+    [
+        False,
+        # Categorical has a core vector parameter,
+        # so it is not possible to build a graph that uses elemwise operations exclusively
+        pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError)),
+    ],
+)
+def test_marginalized_hmm_categorical_emission(categorical_emission):
+    """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
+    with MarginalModel() as m:
+        P = np.array([[0.5, 0.5], [0.3, 0.7]])
+        init_dist = pm.Categorical.dist(p=[0.375, 0.625])
+        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
+        if categorical_emission:
+            emission = pm.Categorical(
+                "emission", p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6])
+            )
+        else:
+            emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
+    m.marginalize([chain])
+
+    test_value = np.array([0, 0, 1])
+    expected_logp = np.log(0.1344)  # Shown at the 10m22s mark in the video
+    logp_fn = m.compile_logp()
+    np.testing.assert_allclose(logp_fn({f"emission": test_value}), expected_logp)
+
+
+@pytest.mark.parametrize("batch_emission1", (False, True))
+@pytest.mark.parametrize("batch_emission2", (False, True))
+def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2):
+    emission1_shape = (2, 4) if batch_emission1 else (4,)
+    emission2_shape = (2, 4) if batch_emission2 else (4,)
+    with MarginalModel() as m:
+        P = [[0, 1], [1, 0]]
+        init_dist = pm.Categorical.dist(p=[1, 0])
+        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3)
+        emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, shape=emission1_shape)
+        emission_2 = pm.Normal(
+            "emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1, shape=emission2_shape
+        )
+
+    with pytest.warns(UserWarning, match="multiple dependent variables"):
+        m.marginalize([chain])
+
+    logp_fn = m.compile_logp()
+
+    test_value = np.array([-1, 1, -1, 1])
+    multiplier = 2 + batch_emission1 + batch_emission2
+    expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier
+    test_value_emission1 = np.broadcast_to(test_value, emission1_shape)
+    test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
+    test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
+    np.testing.assert_allclose(logp_fn(test_point), expected_logp)