Skip to content

Commit 4127be6

Browse files
Fix typos, remove unnecessary stack
1 parent 88c3b67 commit 4127be6

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

pymc_experimental/marginal_model.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,9 @@ def eval_logp(x):
522522
vec_eval_logp = pt.vectorize(eval_logp, "()->()")
523523
logp_init = vec_eval_logp(domain)
524524

525+
# Construct logp in two steps
526+
# Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)
527+
525528
# This will break the dependency between chain and the init_dist_ random variable
526529
# TODO: Make this comment more robust after I understand better.
527530
chain_dummy = chain_rv.clone()
@@ -534,19 +537,25 @@ def eval_logp(x):
534537
chain_dummy: pt.moveaxis(pt.broadcast_to(domain, (*values[0].shape, domain.size)), -1, 0)
535538
}
536539

537-
# TODO: @Ricardo: If you don't concatenate here, you get -inf in the logp (why?)
538-
# TODO: I'm stacking the results (adds a batch dim to the left) and summing away the batch dim == joint probability?
539-
vec_logp_emission = pt.stack(vectorize_graph(tuple(logp_value_dict.values()), sub_dict)).sum(
540-
axis=0
541-
)
540+
# This is a (k, T) matrix of logp terms, one for each state - observation pair
541+
vec_logp_emission = vectorize_graph(tuple(logp_value_dict.values()), sub_dict)
542542

543+
# Step 2: Compute the transition probabilities
544+
# This is the "forward algorithm", alpha_t = sum(p(s_t | s_{t-1}) * alpha_{t-1})
545+
# We do it entirely in logs, though.
543546
log_alpha_seq, _ = scan(
544547
step_alpha, non_sequences=[pt.log(P_)], outputs_info=[logp_init], n_steps=n_steps_
545548
)
546549

550+
# Scan works over the T dimension, so output is (T, k). We need to swap to (k, T)
547551
log_alpha_seq = pt.moveaxis(pt.concatenate([logp_init[None], log_alpha_seq], axis=0), -1, 0)
548-
joint_log_obs_given_states = pt.logsumexp(pt.add(vec_logp_emission) + log_alpha_seq, axis=0)
549552

550-
# We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
551-
dummy_logps = (pt.constant(0.0),) * (len(values) - 1)
552-
return joint_log_obs_given_states, dummy_logps
553+
# Final logp is the sum of the sum of the emission probs and the transition probabilities
554+
# pt.add is used in case there are multiple emissions that depend on the same markov chain; in this case, we compute
555+
# the joint probability of seeing everything together.
556+
joint_log_obs_given_states = pt.logsumexp(pt.add(*vec_logp_emission) + log_alpha_seq, axis=0)
557+
558+
# If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
559+
# return is the joint probability of everything together, but PyMC still expects one logp for each one.
560+
dummy_logps = (pt.constant(np.zeros((4,))),) * (len(values) - 1)
561+
return joint_log_obs_given_states, *dummy_logps

0 commit comments

Comments
 (0)