Commit 3ebdfb5

Marginalize DiscreteMarkovChain
Co-authored-by: Jesse Grabowski <[email protected]>
1 parent c61f2cb commit 3ebdfb5

2 files changed: +185 −15 lines

pymc_experimental/model/marginal_model.py

+100 −15
@@ -10,21 +10,16 @@
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
-from pymc.logprob.basic import conditional_logp
+from pymc.logprob.basic import conditional_logp, logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
 from pymc.pytensorf import compile_pymc, constant_fold, inputvars
 from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
-from pytensor import Mode
+from pytensor import Mode, scan
 from pytensor.compile import SharedVariable
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import (
-    Constant,
-    FunctionGraph,
-    ancestors,
-    clone_replace,
-    vectorize_graph,
-)
+from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+from pytensor.graph.replace import vectorize_graph
 from pytensor.scan import map as scan_map
 from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.elemwise import Elemwise
@@ -33,6 +28,8 @@
 
 __all__ = ["MarginalModel"]
 
+from pymc_experimental.distributions import DiscreteMarkovChain
+
 
 class MarginalModel(Model):
     """Subclass of PyMC Model that implements functionality for automatic
@@ -247,16 +244,25 @@ def marginalize(
             self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
         ]
 
-        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
         for rv_to_marginalize in rvs_to_marginalize:
             if rv_to_marginalize not in self.free_RVs:
                 raise ValueError(
                     f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
                 )
-            if not isinstance(rv_to_marginalize.owner.op, supported_dists):
+
+            rv_op = rv_to_marginalize.owner.op
+            if isinstance(rv_op, DiscreteMarkovChain):
+                if rv_op.n_lags > 1:
+                    raise NotImplementedError(
+                        "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
+                    )
+                if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
+                    raise NotImplementedError(
+                        "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
+                    )
+            elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
                 raise NotImplementedError(
-                    f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
-                    f"Supported distribution include {supported_dists}"
+                    f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
                 )
 
             if rv_to_marginalize.name in self.named_vars_to_dims:
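With this branch in place, a DiscreteMarkovChain can be passed to MarginalModel.marginalize like the other supported discrete distributions. A minimal usage sketch, assembled from the categorical-emission test added below (a standalone illustration, not code from this commit):

    import numpy as np
    import pymc as pm
    import pytensor.tensor as pt

    from pymc_experimental.distributions import DiscreteMarkovChain
    from pymc_experimental.model.marginal_model import MarginalModel

    with MarginalModel() as m:
        # 2-state chain; row k of P holds the transition probabilities out of state k
        P = np.array([[0.5, 0.5], [0.3, 0.7]])
        init_dist = pm.Categorical.dist(p=[0.375, 0.625])
        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
        emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))

    m.marginalize([chain])  # the chain is summed out of the model's logp
    logp_fn = m.compile_logp()

After marginalization the chain is replaced by a DiscreteMarginalMarkovChainRV, whose logp (registered further down in this file) sums over all state sequences with the forward algorithm.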
@@ -492,6 +498,10 @@ class FiniteDiscreteMarginalRV(MarginalRV):
     """Base class for Finite Discrete Marginalized RVs"""
 
 
+class DiscreteMarginalMarkovChainRV(MarginalRV):
+    """Base class for Discrete Marginal Markov Chain RVs"""
+
+
 def static_shape_ancestors(vars):
     """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
     return [
@@ -620,11 +630,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
     replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
     cloned_outputs = clone_replace(outputs, replace=replace_inputs)
 
-    marginalization_op = FiniteDiscreteMarginalRV(
+    if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
+        marginalize_constructor = DiscreteMarginalMarkovChainRV
+    else:
+        marginalize_constructor = FiniteDiscreteMarginalRV
+
+    marginalization_op = marginalize_constructor(
         inputs=list(replace_inputs.values()),
         outputs=cloned_outputs,
         ndim_supp=ndim_supp,
     )
+
     marginalized_rvs = marginalization_op(*replace_inputs.keys())
     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
     return rvs_to_marginalize, marginalized_rvs
@@ -640,14 +656,17 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
     elif isinstance(op, DiscreteUniform):
         lower, upper = constant_fold(rv.owner.inputs[3:])
         return tuple(range(lower, upper + 1))
+    elif isinstance(op, DiscreteMarkovChain):
+        P = rv.owner.inputs[0]
+        return tuple(range(pt.get_vector_length(P[-1])))
 
     raise NotImplementedError(f"Cannot compute domain for op {op}")
 
 
 def _add_reduce_batch_dependent_logps(
     marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
 ):
-    """Add the logps of dependent RVs while reducing extra batch dims as assessed from the `marginalized_type`."""
+    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""
 
     mbcast = marginalized_type.broadcastable
     reduced_logps = []
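The new DiscreteMarkovChain branch reads the domain (the set of state labels) off the length of the last row of the transition matrix. A hypothetical standalone check that mirrors pt.get_vector_length(P[-1]):

    import numpy as np

    P = np.array([[0.5, 0.5], [0.3, 0.7]])  # K x K transition matrix, K = 2 states
    domain = tuple(range(P[-1].shape[0]))   # -> (0, 1), the labels each state can take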
@@ -730,3 +749,69 @@ def logp_fn(marginalized_rv_const, *non_sequences):
 
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
     return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+
+
+@_logprob.register(DiscreteMarginalMarkovChainRV)
+def marginal_hmm_logp(op, values, *inputs, **kwargs):
+
+    marginalized_rvs_node = op.make_node(*inputs)
+    inner_rvs = clone_replace(
+        op.inner_outputs,
+        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+    )
+
+    chain_rv, *dependent_rvs = inner_rvs
+    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
+    domain = pt.arange(P.shape[-1], dtype="int32")
+
+    # Construct the logp in two steps
+    # Step 1: Compute the probability of the data ("emissions") under every possible state (batch_logp_emissions)
+
+    # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
+    # around. To do this, we need to break the dependency between the chain and the init_dist_ random variable.
+    # Otherwise, PyMC will detect a random variable in the logp graph (init_dist_) that isn't relevant at this step.
+    chain_value = chain_rv.clone()
+    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
+    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))
+
+    # Reduce and add the batch dims beyond the chain dimension
+    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
+        chain_rv.type, logp_emissions_dict.values()
+    )
+
+    # Add a batch dimension for the domain of the chain
+    chain_shape = constant_fold(tuple(chain_rv.shape))
+    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
+    batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})
+
+    # Step 2: Compute the transition probabilities
+    # This is the "forward algorithm": alpha_t = p(y_t | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
+    # We do it entirely in logs, though.
+
+    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
+    # under the initial distribution. This is robust to anything the user can throw at it.
+    batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")(
+        batch_chain_value[..., 0]
+    )
+    log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]
+
+    def step_alpha(logp_emission, log_alpha, log_P):
+        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
+        return logp_emission + step_log_prob
+
+    P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
+    log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
+    log_alpha_seq, _ = scan(
+        step_alpha,
+        non_sequences=[log_P],
+        outputs_info=[log_alpha_init],
+        # Scan needs the time dimension first, and we already consumed the first logp computing the initial value
+        sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
+    )
+    # The final logp is the logsumexp over the last scan state (i.e., marginalizing the final state)
+    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
+
+    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The
+    # first return is the joint probability of everything together, but PyMC still expects one logp per value.
+    dummy_logps = (pt.constant(0),) * (len(values) - 1)
+    return joint_logp, *dummy_logps
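The scan above is the textbook forward recursion carried out in log space. A self-contained NumPy sketch of the same computation for a single unbatched chain may help; the function name and arguments here are illustrative, not part of the commit:

    import numpy as np
    from scipy.special import logsumexp

    def forward_logp(log_init, log_P, log_emissions):
        """Log-marginal likelihood of an HMM via the log-space forward algorithm.

        log_init:      (K,) log prior over the K initial states
        log_P:         (K, K) log transition matrix, axis 0 = source state
        log_emissions: (T, K) log p(y_t | s_t = k) for each step and state
        """
        # alpha_0(s) = log pi(s) + log p(y_0 | s), mirroring log_alpha_init above
        log_alpha = log_init + log_emissions[0]
        for log_em in log_emissions[1:]:
            # log alpha_t(s') = log p(y_t | s') + logsumexp_s(log alpha_{t-1}(s) + log P[s, s'])
            log_alpha = log_em + logsumexp(log_alpha[:, None] + log_P, axis=0)
        return logsumexp(log_alpha)  # marginalize the final state

The inner loop is exactly step_alpha: broadcasting log_alpha against the columns of log_P and reducing over the source-state axis.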

pymc_experimental/tests/model/test_marginal_model.py

+85 −0
@@ -14,6 +14,7 @@
 from scipy.special import log_softmax, logsumexp
 from scipy.stats import halfnorm, norm
 
+from pymc_experimental.distributions import DiscreteMarkovChain
 from pymc_experimental.model.marginal_model import (
     FiniteDiscreteMarginalRV,
     MarginalModel,
@@ -673,3 +674,87 @@ def dist(idx, size):
     ):
         pt = {"norm": test_value}
         np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
+
+
+@pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}")
+@pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}")
+def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
+    if batch_chain and not batch_emission:
+        pytest.skip("Redundant implicit combination")
+
+    with MarginalModel() as m:
+        P = [[0, 1], [1, 0]]
+        init_dist = pm.Categorical.dist(p=[1, 0])
+        chain = DiscreteMarkovChain(
+            "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None
+        )
+        emission = pm.Normal(
+            "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None
+        )
+
+    m.marginalize([chain])
+    logp_fn = m.compile_logp()
+
+    test_value = np.array([-1, 1, -1, 1])
+    expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
+    if batch_emission:
+        test_value = np.broadcast_to(test_value, (3, 4))
+        expected_logp *= 3
+    np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
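An aside on the expected value (not part of the commit): with P = [[0, 1], [1, 0]] and a point-mass initial distribution on state 0, the chain is deterministic (0, 1, 0, 1), so mu = 2 * chain - 1 equals the test value exactly and each of the four emissions contributes the log-density of a Normal(0, 0.1) evaluated at zero:

    import numpy as np
    from scipy.stats import norm

    # Four emissions, each observed exactly at its mean, with sigma = 0.1
    expected = norm.logpdf(0.0, loc=0.0, scale=1e-1) * 4

In the batched case the (3, 4) block repeats the unbatched logp three times, hence expected_logp *= 3.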
+
+
+@pytest.mark.parametrize(
+    "categorical_emission",
+    [
+        False,
+        # Categorical has a core vector parameter,
+        # so it is not possible to build a graph that uses elemwise operations exclusively
+        pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError)),
+    ],
+)
+def test_marginalized_hmm_categorical_emission(categorical_emission):
+    """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
+    with MarginalModel() as m:
+        P = np.array([[0.5, 0.5], [0.3, 0.7]])
+        init_dist = pm.Categorical.dist(p=[0.375, 0.625])
+        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
+        if categorical_emission:
+            emission = pm.Categorical(
+                "emission", p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6])
+            )
+        else:
+            emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
+    m.marginalize([chain])
+
+    test_value = np.array([0, 0, 1])
+    expected_logp = np.log(0.1344)  # Shown at the 10m22s mark in the video
+    logp_fn = m.compile_logp()
+    np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
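The 0.1344 constant can be reproduced by running the forward algorithm by hand for this model. A standalone NumPy check (not part of the commit):

    import numpy as np

    P = np.array([[0.5, 0.5], [0.3, 0.7]])      # transition matrix
    pi = np.array([0.375, 0.625])               # initial state probabilities
    p_obs = np.array([[0.8, 0.2], [0.4, 0.6]])  # p(y | state); row = state, column = y
    y = [0, 0, 1]                               # the observed test_value

    alpha = pi * p_obs[:, y[0]]              # alpha_0(s) = pi(s) * p(y_0 | s)
    for obs in y[1:]:
        alpha = p_obs[:, obs] * (alpha @ P)  # forward recursion on the probability scale
    print(alpha.sum())                       # 0.1344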
+
+
+@pytest.mark.parametrize("batch_emission1", (False, True))
+@pytest.mark.parametrize("batch_emission2", (False, True))
+def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2):
+    emission1_shape = (2, 4) if batch_emission1 else (4,)
+    emission2_shape = (2, 4) if batch_emission2 else (4,)
+    with MarginalModel() as m:
+        P = [[0, 1], [1, 0]]
+        init_dist = pm.Categorical.dist(p=[1, 0])
+        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3)
+        emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, shape=emission1_shape)
+        emission_2 = pm.Normal(
+            "emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1, shape=emission2_shape
+        )
+
+    with pytest.warns(UserWarning, match="multiple dependent variables"):
+        m.marginalize([chain])
+
+    logp_fn = m.compile_logp()
+
+    test_value = np.array([-1, 1, -1, 1])
+    multiplier = 2 + batch_emission1 + batch_emission2
+    expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier
+    test_value_emission1 = np.broadcast_to(test_value, emission1_shape)
+    test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
+    test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
+    np.testing.assert_allclose(logp_fn(test_point), expected_logp)
