Skip to content

Commit edfa63c

Browse files
Tests pass
1 parent 8f9ae40 commit edfa63c

File tree

2 files changed

+29
-29
lines changed

2 files changed

+29
-29
lines changed

pymc_experimental/marginal_model.py

+28-28
Original file line numberDiff line numberDiff line change
@@ -502,9 +502,6 @@ def logp_fn(marginalized_rv_const, *non_sequences):
502502

503503
@_logprob.register(DiscreteMarginalMarkovChainRV)
504504
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
505-
def eval_logp(x):
506-
return logp(init_dist_, x)
507-
508505
marginalized_rvs_node = op.make_node(*inputs)
509506
inner_rvs = clone_replace(
510507
op.inner_outputs,
@@ -513,19 +510,15 @@ def eval_logp(x):
513510

514511
chain_rv, *dependent_rvs = inner_rvs
515512
P_, n_steps_, init_dist_, rng = chain_rv.owner.inputs
516-
517-
domain = pt.arange(P_.shape[0], dtype="int32")
518-
519-
vec_eval_logp = pt.vectorize(eval_logp, "()->()")
520-
521513
log_P_ = pt.log(P_)
522-
log_alpha_init = vec_eval_logp(domain) + log_P_
514+
domain = pt.arange(P_.shape[0], dtype="int32")
523515

524516
# Construct logp in two steps
525517
# Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)
526518

527-
# This will break the dependency between chain and the init_dist_ random variable
528-
# TODO: Make this comment more robust after I understand better.
519+
# First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
520+
# around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise,
521+
# PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step.
529522
chain_dummy = chain_rv.clone()
530523
dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_dummy})
531524
input_dict = dict(zip(dependent_rvs, values))
@@ -536,34 +529,41 @@ def eval_logp(x):
536529
chain_dummy: pt.moveaxis(pt.broadcast_to(domain, (*values[0].shape, domain.size)), -1, 0)
537530
}
538531

539-
# This is a (k, T) matrix of logp terms, one for each state - emission pair
540-
vec_logp_emission = vectorize_graph(tuple(logp_value_dict.values()), sub_dict)
532+
# This is a list of (k, T) matrices of logp terms, one for each state - emission pair, for each RV that depends
533+
# on the Markov chain being marginalized. Since they all depend on the same Markov chain, it is safe to assume they
534+
# all share the same length. Finally, because we only consider the **joint** logp of all variables that depend on
535+
# the chain, we can sum all of these logp values now.
536+
vec_logp_emission = pt.stack(vectorize_graph(tuple(logp_value_dict.values()), sub_dict)).sum(
537+
axis=0
538+
)
541539

542540
# Step 2: Compute the transition probabilities
543-
# This is the "forward algorithm", alpha_t = sum(p(s_t | s_{t-1}) * alpha_{t-1})
541+
# This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
544542
# We do it entirely in logs, though.
545-
def step_alpha(logp_emission, log_alpha, log_P):
546543

547-
return pt.logsumexp(log_alpha[:, None] + log_P, 0)
544+
# To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
545+
# the initial distribution. This is robust to everything the user can throw at it.
546+
def eval_logp(x):
547+
return logp(init_dist_, x)
548+
549+
vec_eval_logp = pt.vectorize(eval_logp, "()->()")
550+
log_alpha_init = vec_eval_logp(domain) + vec_logp_emission[..., 0]
551+
552+
def step_alpha(logp_emission, log_alpha, log_P):
553+
step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, 0)
554+
return logp_emission + step_log_prob
548555

549556
log_alpha_seq, _ = scan(
550557
step_alpha,
551558
non_sequences=[log_P_],
552559
outputs_info=[log_alpha_init],
553-
sequences=pt.moveaxis(vec_logp_emission, -1, 0),
560+
# Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
561+
sequences=pt.moveaxis(vec_logp_emission[..., 1:], -1, 0),
554562
)
555-
556-
# Scan works over the T dimension, so output is (T, k). We need to swap to (k, T)
557-
log_alpha_seq = pt.moveaxis(
558-
pt.concatenate([log_alpha_init, log_alpha_seq[..., -1]], axis=0), -1, 0
559-
)
560-
561-
# Final logp is the sum of the sum of the emission probs and the transition probabilities
562-
# pt.add is used in case there are multiple emissions that depend on the same markov chain; in this case, we compute
563-
# the joint probability of seeing everything together.
564-
joint_log_obs_given_states = pt.logsumexp(log_alpha_seq, axis=0)
563+
# Final logp is just the sum of the last scan state
564+
joint_logp = pt.logsumexp(log_alpha_seq[-1])
565565

566566
# If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
567567
# return is the joint probability of everything together, but PyMC still expects one logp for each one.
568568
dummy_logps = (pt.constant(np.zeros(shape=())),) * (len(values) - 1)
569-
return joint_log_obs_given_states, *dummy_logps
569+
return joint_logp, *dummy_logps

pymc_experimental/tests/test_marginal_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def test_marginalized_hmm_one_bernoulli_emission():
494494
P = np.array([[0.5, 0.5], [0.3, 0.7]])
495495
init_dist = pm.Categorical.dist(p=[0.375, 0.625])
496496
chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
497-
emission = pm.Bernoulli("emission", p=pt.where(chain, 0.2, 0.6))
497+
emission = pm.Bernoulli("emission", p=pt.where(chain, 0.6, 0.2))
498498
m.marginalize([chain])
499499

500500
test_value = np.array([0, 0, 1])

0 commit comments

Comments
 (0)