@@ -502,9 +502,6 @@ def logp_fn(marginalized_rv_const, *non_sequences):
502
502
503
503
@_logprob .register (DiscreteMarginalMarkovChainRV )
504
504
def finite_discrete_marginal_rv_logp (op , values , * inputs , ** kwargs ):
505
- def eval_logp (x ):
506
- return logp (init_dist_ , x )
507
-
508
505
marginalized_rvs_node = op .make_node (* inputs )
509
506
inner_rvs = clone_replace (
510
507
op .inner_outputs ,
@@ -513,19 +510,15 @@ def eval_logp(x):
513
510
514
511
chain_rv , * dependent_rvs = inner_rvs
515
512
P_ , n_steps_ , init_dist_ , rng = chain_rv .owner .inputs
516
-
517
- domain = pt .arange (P_ .shape [0 ], dtype = "int32" )
518
-
519
- vec_eval_logp = pt .vectorize (eval_logp , "()->()" )
520
-
521
513
log_P_ = pt .log (P_ )
522
- log_alpha_init = vec_eval_logp ( domain ) + log_P_
514
+ domain = pt . arange ( P_ . shape [ 0 ], dtype = "int32" )
523
515
524
516
# Construct logp in two steps
525
517
# Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)
526
518
527
- # This will break the dependency between chain and the init_dist_ random variable
528
- # TODO: Make this comment more robust after I understand better.
519
+ # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
520
+ # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise,
521
+ # PyMC will detect a random variable in the logp graph (init_dist_) that isn't relevant at this step.
529
522
chain_dummy = chain_rv .clone ()
530
523
dependent_rvs = clone_replace (dependent_rvs , {chain_rv : chain_dummy })
531
524
input_dict = dict (zip (dependent_rvs , values ))
@@ -536,34 +529,41 @@ def eval_logp(x):
536
529
chain_dummy : pt .moveaxis (pt .broadcast_to (domain , (* values [0 ].shape , domain .size )), - 1 , 0 )
537
530
}
538
531
539
- # This is a (k, T) matrix of logp terms, one for each state - emission pair
540
- vec_logp_emission = vectorize_graph (tuple (logp_value_dict .values ()), sub_dict )
532
+ # This is a list of (k, T) matrices of logp terms, one for each state - emission pair, for each RV that depends
533
+ # on the Markov chain being marginalized. Since they all depend on the same Markov chain, it is safe to assume they
534
+ # all share the same length. Finally, because we only consider the **joint** logp of all variables that depend on
535
+ # the chain, we can sum all of these logp values now.
536
+ vec_logp_emission = pt .stack (vectorize_graph (tuple (logp_value_dict .values ()), sub_dict )).sum (
537
+ axis = 0
538
+ )
541
539
542
540
# Step 2: Compute the transition probabilities
543
- # This is the "forward algorithm", alpha_t = sum (p(s_t | s_{t-1}) * alpha_{t-1})
541
+ # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}} (p(s_t | s_{t-1}) * alpha_{t-1})
544
542
# We do it entirely in logs, though.
545
- def step_alpha (logp_emission , log_alpha , log_P ):
546
543
547
- return pt .logsumexp (log_alpha [:, None ] + log_P , 0 )
544
+ # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
545
+ # the initial distribution. This is robust to everything the user can throw at it.
546
+ def eval_logp (x ):
547
+ return logp (init_dist_ , x )
548
+
549
+ vec_eval_logp = pt .vectorize (eval_logp , "()->()" )
550
+ log_alpha_init = vec_eval_logp (domain ) + vec_logp_emission [..., 0 ]
551
+
552
+ def step_alpha (logp_emission , log_alpha , log_P ):
553
+ step_log_prob = pt .logsumexp (log_alpha [:, None ] + log_P , 0 )
554
+ return logp_emission + step_log_prob
548
555
549
556
log_alpha_seq , _ = scan (
550
557
step_alpha ,
551
558
non_sequences = [log_P_ ],
552
559
outputs_info = [log_alpha_init ],
553
- sequences = pt .moveaxis (vec_logp_emission , - 1 , 0 ),
560
+ # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
561
+ sequences = pt .moveaxis (vec_logp_emission [..., 1 :], - 1 , 0 ),
554
562
)
555
-
556
- # Scan works over the T dimension, so output is (T, k). We need to swap to (k, T)
557
- log_alpha_seq = pt .moveaxis (
558
- pt .concatenate ([log_alpha_init , log_alpha_seq [..., - 1 ]], axis = 0 ), - 1 , 0
559
- )
560
-
561
- # Final logp is the sum of the sum of the emission probs and the transition probabilities
562
- # pt.add is used in case there are multiple emissions that depend on the same markov chain; in this case, we compute
563
- # the joint probability of seeing everything together.
564
- joint_log_obs_given_states = pt .logsumexp (log_alpha_seq , axis = 0 )
563
+ # Final logp is just the sum of the last scan state
564
+ joint_logp = pt .logsumexp (log_alpha_seq [- 1 ])
565
565
566
566
# If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
567
567
# return is the joint probability of everything together, but PyMC still expects one logp for each one.
568
568
dummy_logps = (pt .constant (np .zeros (shape = ())),) * (len (values ) - 1 )
569
- return joint_log_obs_given_states , * dummy_logps
569
+ return joint_logp , * dummy_logps
0 commit comments