@@ -502,9 +502,6 @@ def logp_fn(marginalized_rv_const, *non_sequences):
 
 @_logprob.register(DiscreteMarginalMarkovChainRV)
 def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
-    def step_alpha(log_alpha, log_P):
-        return pt.logsumexp(log_alpha[:, None] + log_P, 0)
-
     def eval_logp(x):
         return logp(init_dist_, x)
 
@@ -520,7 +517,9 @@ def eval_logp(x):
     domain = pt.arange(P_.shape[0], dtype="int32")
 
     vec_eval_logp = pt.vectorize(eval_logp, "()->()")
-    logp_init = vec_eval_logp(domain)
+
+    log_P_ = pt.log(P_)
+    log_alpha_init = vec_eval_logp(domain) + log_P_
 
     # Construct logp in two steps
     # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)
@@ -543,19 +542,28 @@ def eval_logp(x):
     # Step 2: Compute the transition probabilities
     # This is the "forward algorithm", alpha_t = sum(p(s_t | s_{t-1}) * alpha_{t-1})
     # We do it entirely in logs, though.
+    def step_alpha(logp_emission, log_alpha, log_P):
+        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
+        return logp_emission + step_log_prob
+
     log_alpha_seq, _ = scan(
-        step_alpha, non_sequences=[pt.log(P_)], outputs_info=[logp_init], n_steps=n_steps_
+        step_alpha,
+        non_sequences=[log_P_],
+        outputs_info=[log_alpha_init],
+        sequences=pt.moveaxis(vec_logp_emission, -1, 0),
     )
 
     # Scan works over the T dimension, so output is (T, k). We need to swap to (k, T)
-    log_alpha_seq = pt.moveaxis(pt.concatenate([logp_init[None], log_alpha_seq], axis=0), -1, 0)
+    log_alpha_seq = pt.moveaxis(
+        pt.concatenate([log_alpha_init, log_alpha_seq[..., -1]], axis=0), -1, 0
+    )
 
     # Final logp is the logsumexp over states of the forward probabilities, which already accumulate the
     # emission and transition probabilities; if there are multiple emissions that depend on the same markov
     # chain, this gives the joint probability of seeing everything together.
-    joint_log_obs_given_states = pt.logsumexp(pt.add(*vec_logp_emission) + log_alpha_seq, axis=0)
+    joint_log_obs_given_states = pt.logsumexp(log_alpha_seq, axis=0)
 
     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
     # return is the joint probability of everything together, but PyMC still expects one logp for each one.
-    dummy_logps = (pt.constant(np.zeros((4,))),) * (len(values) - 1)
+    dummy_logps = (pt.constant(np.zeros(shape=())),) * (len(values) - 1)
     return joint_log_obs_given_states, *dummy_logps
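
To make the change easier to follow, here is a minimal, self-contained numpy sketch of the log-space forward algorithm that the scan-based code above implements. This is an illustration only: the names and toy shapes (P, p_init, logp_emission, k, T) are assumptions for the demo rather than objects from the PR, and scipy's logsumexp stands in for pt.logsumexp. alpha_t[s] accumulates log p(y_1, ..., y_t, s_t = s), so the step in the loop below is exactly what step_alpha computes at each scan iteration.

# Illustrative sketch only: names and shapes are demo assumptions, not code from the PR.
import numpy as np
from itertools import product
from scipy.special import logsumexp

rng = np.random.default_rng(0)
k, T = 3, 8  # number of hidden states, number of time steps (toy sizes)

# Row-stochastic transition matrix and initial state distribution
P = rng.dirichlet(np.ones(k), size=k)   # P[i, j] = p(s_t = j | s_{t-1} = i)
p_init = rng.dirichlet(np.ones(k))      # p(s_0)

# logp_emission[t, s] = log p(y_t | s_t = s); random placeholders for the demo
logp_emission = np.log(rng.dirichlet(np.ones(k), size=T))

log_P = np.log(P)
log_alpha = np.log(p_init) + logp_emission[0]  # alpha_0 folds in the first emission

for t in range(1, T):
    # In probability space: alpha_t[s] = p(y_t | s) * sum_{s'} p(s | s') * alpha_{t-1}[s']
    log_alpha = logp_emission[t] + logsumexp(log_alpha[:, None] + log_P, axis=0)

log_marginal = logsumexp(log_alpha)  # log p(y_1, ..., y_T), with the chain summed out

# Brute-force check: enumerate all k**T hidden paths (feasible only at toy sizes)
total = 0.0
for path in product(range(k), repeat=T):
    lp = np.log(p_init[path[0]]) + logp_emission[0, path[0]]
    for t in range(1, T):
        lp += log_P[path[t - 1], path[t]] + logp_emission[t, path[t]]
    total += np.exp(lp)

assert np.isclose(log_marginal, np.log(total))

The brute-force check sums over all k**T hidden paths, while the recursion reaches the same marginal in O(T * k**2); doing it in logs with logsumexp avoids the underflow that plain products of small probabilities would hit for long chains.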