@@ -522,6 +522,9 @@ def eval_logp(x):
     vec_eval_logp = pt.vectorize(eval_logp, "()->()")
     logp_init = vec_eval_logp(domain)
 
+    # Construct logp in two steps
+    # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)
+
     # This will break the dependency between chain and the init_dist_ random variable
     # TODO: Make this comment more robust after I understand better.
     chain_dummy = chain_rv.clone()
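For intuition, here is a minimal NumPy sketch of what the `"()->()"` vectorization above does (the toy `eval_logp` is hypothetical, not the PR's, and `np.vectorize` stands in for `pt.vectorize`): it lifts a scalar-in, scalar-out logp evaluator so it broadcasts over the whole state `domain` at once.

```python
import numpy as np

# Hypothetical scalar logp: standard-normal log-density at a single point
def eval_logp(x):
    return -0.5 * x**2 - 0.5 * np.log(2 * np.pi)

# "()->()" declares a scalar-to-scalar core signature; the wrapper then
# broadcasts elementwise over any array input
vec_eval_logp = np.vectorize(eval_logp, signature="()->()")

domain = np.arange(3.0)        # toy stand-in for the chain's state domain
print(vec_eval_logp(domain))   # one logp per state, shape (3,)
```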
@@ -534,19 +537,25 @@ def eval_logp(x):
         chain_dummy: pt.moveaxis(pt.broadcast_to(domain, (*values[0].shape, domain.size)), -1, 0)
     }
 
-    # TODO: @Ricardo: If you don't concatenate here, you get -inf in the logp (why?)
-    # TODO: I'm stacking the results (adds a batch dim to the left) and summing away the batch dim == joint probability?
-    vec_logp_emission = pt.stack(vectorize_graph(tuple(logp_value_dict.values()), sub_dict)).sum(
-        axis=0
-    )
+    # This is a (k, T) matrix of logp terms, one for each state-observation pair
+    vec_logp_emission = vectorize_graph(tuple(logp_value_dict.values()), sub_dict)
 
+    # Step 2: Compute the transition probabilities
+    # This is the "forward algorithm", alpha_t = sum(p(s_t | s_{t-1}) * alpha_{t-1})
+    # We do it entirely in logs, though.
     log_alpha_seq, _ = scan(
         step_alpha, non_sequences=[pt.log(P_)], outputs_info=[logp_init], n_steps=n_steps_
     )
 
+    # Scan works over the T dimension, so the output is (T, k). We need to swap to (k, T)
     log_alpha_seq = pt.moveaxis(pt.concatenate([logp_init[None], log_alpha_seq], axis=0), -1, 0)
-    joint_log_obs_given_states = pt.logsumexp(pt.add(vec_logp_emission) + log_alpha_seq, axis=0)
 
-    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
-    dummy_logps = (pt.constant(0.0),) * (len(values) - 1)
-    return joint_log_obs_given_states, dummy_logps
+    # The final logp is the logsumexp over states of the emission logps plus the transition logps.
+    # pt.add is used in case there are multiple emissions that depend on the same markov chain; in that case, we
+    # compute the joint probability of seeing everything together.
+    joint_log_obs_given_states = pt.logsumexp(pt.add(*vec_logp_emission) + log_alpha_seq, axis=0)
+
+    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
+    # return is the joint probability of everything together, but PyMC still expects one logp for each one.
+    dummy_logps = (pt.constant(np.zeros((4,))),) * (len(values) - 1)
+    return joint_log_obs_given_states, *dummy_logps
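To make the two steps concrete, here is a self-contained NumPy sketch of the log-space forward recursion that the `scan` implements, followed by the final `logsumexp` marginalization (all numbers and names are toy stand-ins, not the PR's actual graph):

```python
import numpy as np
from scipy.special import logsumexp

# Toy HMM with k = 2 hidden states and T = 3 observations
logp_init = np.log(np.array([0.5, 0.5]))           # log prior over states
log_P = np.log(np.array([[0.9, 0.1],
                         [0.2, 0.8]]))             # log transition matrix P[s_{t-1}, s_t]
log_emission = np.log(np.array([[0.7, 0.6, 0.1],   # (k, T): log p(obs_t | state)
                                [0.3, 0.4, 0.9]]))

# Forward recursion, entirely in logs:
# log_alpha_t(j) = logsumexp_i(log_alpha_{t-1}(i) + log P[i, j])
T = log_emission.shape[1]
log_alphas = [logp_init]
for _ in range(T - 1):
    log_alphas.append(logsumexp(log_alphas[-1][:, None] + log_P, axis=0))
log_alpha_seq = np.stack(log_alphas, axis=1)       # (k, T), like the moveaxis above

# Marginalize the hidden state at each step: logsumexp over the k axis of
# emission logps plus state logps
joint_log_obs_given_states = logsumexp(log_emission + log_alpha_seq, axis=0)
print(joint_log_obs_given_states)                  # one logp term per observation
```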