
Commit 178be94

Fix typos, add YouTube example as test
1 parent daf76ba

2 files changed: +31 -9 lines changed


pymc_experimental/marginal_model.py

+16 -8
@@ -502,9 +502,6 @@ def logp_fn(marginalized_rv_const, *non_sequences):
 
 @_logprob.register(DiscreteMarginalMarkovChainRV)
 def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
-    def step_alpha(log_alpha, log_P):
-        return pt.logsumexp(log_alpha[:, None] + log_P, 0)
-
     def eval_logp(x):
         return logp(init_dist_, x)
 
@@ -520,7 +517,9 @@ def eval_logp(x):
     domain = pt.arange(P_.shape[0], dtype="int32")
 
     vec_eval_logp = pt.vectorize(eval_logp, "()->()")
-    logp_init = vec_eval_logp(domain)
+
+    log_P_ = pt.log(P_)
+    log_alpha_init = vec_eval_logp(domain) + log_P_
 
     # Construct logp in two steps
     # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)
@@ -543,19 +542,28 @@ def eval_logp(x):
     # Step 2: Compute the transition probabilities
     # This is the "forward algorithm", alpha_t = sum(p(s_t | s_{t-1}) * alpha_{t-1})
     # We do it entirely in logs, though.
+    def step_alpha(logp_emission, log_alpha, log_P):
+        return pt.logsumexp(log_alpha[:, None] + log_P, 0)
+
     log_alpha_seq, _ = scan(
-        step_alpha, non_sequences=[pt.log(P_)], outputs_info=[logp_init], n_steps=n_steps_
+        step_alpha,
+        non_sequences=[log_P_],
+        outputs_info=[log_alpha_init],
+        sequences=pt.moveaxis(vec_logp_emission, -1, 0),
     )
 
     # Scan works over the T dimension, so output is (T, k). We need to swap to (k, T)
-    log_alpha_seq = pt.moveaxis(pt.concatenate([logp_init[None], log_alpha_seq], axis=0), -1, 0)
+    log_alpha_seq = pt.moveaxis(
+        pt.concatenate([log_alpha_init, log_alpha_seq[..., -1]], axis=0), -1, 0
+    )
 
     # Final logp is the sum of the emission probs and the transition probabilities
     # pt.add is used in case there are multiple emissions that depend on the same markov chain; in this case, we compute
     # the joint probability of seeing everything together.
-    joint_log_obs_given_states = pt.logsumexp(pt.add(*vec_logp_emission) + log_alpha_seq, axis=0)
+    joint_log_obs_given_states = pt.logsumexp(log_alpha_seq, axis=0)
 
     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
     # return is the joint probability of everything together, but PyMC still expects one logp for each one.
-    dummy_logps = (pt.constant(np.zeros((4,))),) * (len(values) - 1)
+    dummy_logps = (pt.constant(np.zeros(shape=())),) * (len(values) - 1)
     return joint_log_obs_given_states, *dummy_logps
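
For context, the comment in the hunk above refers to the standard HMM forward algorithm. Below is a minimal NumPy sketch of the same log-space recursion, separate from the PyTensor graph this commit touches; the names forward_logp, log_pi, log_P, and log_emission are illustrative and not part of marginal_model.py.

import numpy as np
from scipy.special import logsumexp

def forward_logp(log_pi, log_P, log_emission):
    """Log-space forward algorithm: alpha_t = sum(p(s_t | s_{t-1}) * alpha_{t-1})."""
    # log_pi: (k,) initial state log-probabilities
    # log_P: (k, k) with log_P[i, j] = log p(s_t = j | s_{t-1} = i)
    # log_emission: (T, k) log-likelihood of each observation under each state
    log_alpha = log_pi + log_emission[0]
    for t in range(1, log_emission.shape[0]):
        # Broadcast (k, 1) + (k, k) and reduce over the previous-state axis, in logs
        log_alpha = logsumexp(log_alpha[:, None] + log_P, axis=0) + log_emission[t]
    # Marginal log-likelihood of the full observed sequence
    return logsumexp(log_alpha)

Each loop iteration plays the role of one scan step in the diff, where the per-step emission log-probabilities are fed to step_alpha through the sequences argument.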

pymc_experimental/tests/test_marginal_model.py

+15 -1
@@ -473,7 +473,7 @@ def dist(idx, size):
     np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
 
 
-def test_marginalized_hmm_with_one_emission():
+def test_marginalized_hmm_with_one_normal_emission():
     with MarginalModel() as m:
         P = [[0, 1], [1, 0]]
         init_dist = pm.Categorical.dist(p=[1, 0])
@@ -489,6 +489,20 @@ def test_marginalized_hmm_with_one_emission():
     np.testing.assert_allclose(logp_fn({f"emission": test_value}), expected_logp)
 
 
+def test_marginalized_hmm_one_bernoulli_emission():
+    with MarginalModel() as m:
+        P = np.array([[0.5, 0.5], [0.3, 0.7]])
+        init_dist = pm.Categorical.dist(p=[0.375, 0.625])
+        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
+        emission = pm.Bernoulli("emission", p=pt.where(chain, 0.2, 0.6))
+    m.marginalize([chain])
+
+    test_value = np.array([0, 0, 1])
+    expected_logp = -5.8774646585368675
+    logp_fn = m.compile_logp()
+    np.testing.assert_allclose(logp_fn({f"emission": test_value}), expected_logp)
+
+
 def test_marginalized_hmm_with_many_emissions():
     with MarginalModel() as m:
         P = [[0, 1], [1, 0]]
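
For small chains like the ones in these tests, a reference log-likelihood can be cross-checked by brute-force enumeration over all hidden state paths. The sketch below is purely illustrative: the helper brute_force_hmm_logp is not part of the test suite, and the expected_logp constants in the diff are taken as given.

import itertools
import numpy as np

def brute_force_hmm_logp(pi, P, emission_logp, obs):
    """Marginal log-likelihood obtained by summing over every hidden state path."""
    # pi: (k,) initial state probabilities
    # P: (k, k) transition matrix, P[i, j] = p(s_t = j | s_{t-1} = i)
    # emission_logp(state, value): log-likelihood of one observation given one state
    # obs: observed sequence of length T
    k, T = len(pi), len(obs)
    total = 0.0
    for path in itertools.product(range(k), repeat=T):
        logp = np.log(pi[path[0]]) + emission_logp(path[0], obs[0])
        for t in range(1, T):
            logp += np.log(P[path[t - 1], path[t]]) + emission_logp(path[t], obs[t])
        total += np.exp(logp)
    return np.log(total)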
