Commit f9d1c0e

.hmm
1 parent 49761d7 commit f9d1c0e

2 files changed: +96 -1 lines changed


pymc_experimental/marginal_model.py (+8 -1)
@@ -23,6 +23,8 @@
 
 from pytensor.tensor.shape import Shape
 
+from pymc_experimental.distributions import DiscreteMarkovChain
+
 
 class MarginalModel(Model):
     """Subclass of PyMC Model that implements functionality for automatic
@@ -226,7 +228,7 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
         if not isinstance(rvs_to_marginalize, Sequence):
             rvs_to_marginalize = (rvs_to_marginalize,)
 
-        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
+        supported_dists = (Bernoulli, Categorical, DiscreteUniform, DiscreteMarkovChain)
         for rv_to_marginalize in rvs_to_marginalize:
            if rv_to_marginalize not in self.free_RVs:
                raise ValueError(
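
The only change to marginalize itself is to whitelist DiscreteMarkovChain. Based on the commented-out draft at the bottom of the new test (see further down), the intended usage would look roughly like the sketch below; it is hypothetical and not guaranteed to work end-to-end at this WIP commit:

import numpy as np
import pymc as pm

from pymc_experimental.distributions import DiscreteMarkovChain
from pymc_experimental.marginal_model import MarginalModel

# Hypothetical sketch of the intended API, taken from the commented-out draft test.
with MarginalModel() as m:
    P = [[0, 1], [1, 0]]  # deterministic 2-state transition matrix (state flips every step)
    init_dist = pm.DiracDelta.dist(np.array(0, dtype="int64"))
    chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3)
    pm.Normal("emission", mu=chain * 2 - 1, sigma=1e-1)

# Now accepted because DiscreteMarkovChain is in supported_dists; producing the
# correct marginal logp is what the rest of this commit is still working towards.
m.marginalize([chain])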
@@ -342,6 +344,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs}
     if len(ndim_supp) != 1:
         raise NotImplementedError()
+    ndim_supp = tuple(ndim_supp)[0]
     # if max(ndim_supp) > 0:
     #     raise NotImplementedError(
     #         "Marginalization with dependent Multivariate RVs not implemented"
@@ -400,6 +403,9 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
     elif isinstance(op, DiscreteUniform):
         lower, upper = constant_fold(rv.owner.inputs[3:])
         return tuple(range(lower, upper + 1))
+    elif isinstance(op, DiscreteMarkovChain):
+        p = rv.owner.inputs[0]
+        return tuple(range(pt.get_vector_length(p[-1])))
 
     raise NotImplementedError(f"Cannot compute domain for op {op}")
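
The new branch reads the transition matrix from rv.owner.inputs[0] and takes the length of its last row as the number of states. A minimal standalone illustration of that length computation in plain PyTensor (the input position and the constant-matrix setup are assumptions carried over from the diff and the draft test):

import numpy as np
import pytensor.tensor as pt

# A 3-state transition matrix (rows sum to 1).
P = pt.as_tensor(np.array([
    [0.9, 0.1, 0.0],
    [0.0, 0.8, 0.2],
    [0.1, 0.0, 0.9],
]))

# The static length of the last row gives the number of states, so the
# marginalization domain is (0, 1, 2) -- mirroring the new elif branch above.
n_states = pt.get_vector_length(P[-1])
domain = tuple(range(n_states))
assert domain == (0, 1, 2)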

@@ -457,6 +463,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     )
 
     # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
+    # TODO: Try vectorize here
     if len(marginalized_rv_domain) <= 10:
         joint_logps = [
             joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
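
For reference, the surrounding function marginalizes a finite discrete RV by evaluating the joint logp at every point of its domain and combining the results over the domain (a log-sum-exp in log space), either unrolled as here or via Scan past the cutoff; the new TODO suggests trying a vectorized evaluation instead. A self-contained NumPy/SciPy sketch of that enumeration idea on a toy model (names and numbers are illustrative only, not the module's API):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

# Toy model: z ~ Bernoulli(0.3), y | z ~ Normal(2*z - 1, 0.1).
# The marginal logp of y is a logsumexp over the finite domain of z,
# which is what finite_discrete_marginal_rv_logp builds symbolically.
y = 0.9
domain = np.array([0, 1])
log_prior = np.log([0.7, 0.3])                           # logp(z) for each value in the domain
log_like = norm.logpdf(y, loc=2 * domain - 1, scale=0.1)  # logp(y | z) for each value
marginal_logp = logsumexp(log_prior + log_like)

# Reference: the same mixture density computed directly in probability space.
ref = np.log(0.7 * norm.pdf(y, -1, 0.1) + 0.3 * norm.pdf(y, 1, 0.1))
np.testing.assert_allclose(marginal_logp, ref)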

pymc_experimental/tests/test_marginal_model.py (+88 -0)
@@ -12,6 +12,7 @@
 from pymc.util import UNSET
 from scipy.special import logsumexp
 
+from pymc_experimental.distributions import DiscreteMarkovChain
 from pymc_experimental.marginal_model import (
     FiniteDiscreteMarginalRV,
     MarginalModel,
@@ -470,3 +471,90 @@ def dist(idx, size):
     ):
         pt = {"norm": test_value}
         np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
+
+
+def hmm_logp(values, P, steps, init_dist, state_rng):
+
+    [e_value] = values
+
+    # P = [[0, 1], [1, 0]]
+    domain = tuple(range(pt.get_vector_length(P[-1])))
+
+    # This should be done on log-scale
+    # Probability of states at t0
+    logprob_states = pm.math.stack([logp(init_dist, d) for d in domain])
+
+    logprob_emiss_ts = []
+    for e_value_t in e_value:
+        # Use vectorize
+        logprob_emiss_t = pt.sum(
+            [
+                logprob_state + logp(clone_replace(emission_rv, replace={state_rv: state_value}), e_value_t)
+                for (logprob_state, state_value) in zip(logprob_states, domain)
+            ]
+        )
+
+        # Probability next state
+        # prob_states = prob_states @ P
+        logprob_states = P[:, None]
+
+        logprob_emiss_ts.append(logprob_emiss_t)
+
+    return logprob_emiss_ts.sum()
+
+
+def test_hmm():
+
+    with MarginalModel() as m:
+        p = pt.as_tensor(np.array([1, 0]))
+
+        chain_0 = pm.Bernoulli("chain_0", p=0)
+        chain_1 = pm.Bernoulli("chain_1", p=p[chain_0])
+        chain_2 = pm.Bernoulli("chain_2", p=p[chain_1])
+        chain_3 = pm.Bernoulli("chain_3", p=p[chain_2])
+
+        pm.Normal("emission_0", chain_0 * 2 - 1, sigma=1e-1)
+        pm.Normal("emission_1", chain_1 * 2 - 1, sigma=1e-1)
+        pm.Normal("emission_2", chain_2 * 2 - 1, sigma=1e-1)
+        pm.Normal("emission_3", chain_3 * 2 - 1, sigma=1e-1)
+
+    with pytest.warns(UserWarning, match="multiple dependent variables"):
+        m.marginalize([chain_0, chain_1, chain_2, chain_3])
+    import pytensor
+    print()
+    pytensor.dprint(m.clone()._marginalize().free_RVs)
+
+    logp_fn = m.compile_logp()
+    test_value = [-1, 1, -1, 1]
+
+    expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
+
+    np.testing.assert_allclose(
+        logp_fn({f"emission_{i}": test_value_i for i, test_value_i in enumerate(test_value)}),
+        expected_logp,
+    )
+    return
+
+    # with MarginalModel() as m:
+    #     P = [[0, 1], [1, 0]]
+    #     zero = pm.DiracDelta.dist(np.array(0, dtype="int64"))
+    #     chain = DiscreteMarkovChain("chain", P=P, init_dist=zero, steps=3)
+    #     emission = pm.Normal("emission", mu=chain * 2 - 1, sigma=1e-1)
+    #     np.testing.assert_equal(pm.draw(chain), [0, 1, 0, 1])
+    #     m.marginalize(chain)
+
+    # test_value = [-1, 1, -1, 1]
+    # expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), [0, 0, 0, 0]).eval()
+
+    # np.testing.assert_allclose(
+    #     logp_fn({"emission": test_value}),
+    #     expected_logp,
+    # )
+
+    # np.testing.assert_allclose(
+    #     logp_fn({f"emission{i}": test_value_i for i, test_value_i in enumerate(test_value)}),
+    #     expected_logp,
+    # )

0 commit comments
