from pymc.util import UNSET
from scipy.special import logsumexp

+from pymc_experimental.distributions import DiscreteMarkovChain
from pymc_experimental.marginal_model import (
    FiniteDiscreteMarginalRV,
    MarginalModel,
@@ -470,3 +471,90 @@ def dist(idx, size):
    ):
        pt = {"norm": test_value}
        np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
+
+
+def hmm_logp(values, P, steps, init_dist, state_rng):
+    # WIP sketch of a hand-rolled forward pass; emission_rv and state_rv are
+    # expected to be defined in the enclosing scope when this gets wired up.
+    [e_value] = values
+
+    # P = [[0, 1], [1, 0]]
+    domain = tuple(range(pt.get_vector_length(P[-1])))
+
+    # Log-probability of the states at t0
+    logprob_states = pm.math.stack([logp(init_dist, d) for d in domain])
+
+    logprob_emiss_ts = []
+    for e_value_t in e_value:
+        # TODO: Use vectorize instead of a Python loop
+        # Marginalize the hidden state on the log scale:
+        # log P(e_t) = logsumexp_s(log P(S_t = s) + log P(e_t | S_t = s))
+        logprob_emiss_t = pm.math.logsumexp(
+            pm.math.stack(
+                [
+                    logprob_state + logp(clone_replace(emission_rv, replace={state_rv: state_value}), e_value_t)
+                    for (logprob_state, state_value) in zip(logprob_states, domain)
+                ]
+            )
+        )
+
+        # Log-probability of the next states
+        # (log-scale equivalent of prob_states = prob_states @ P)
+        logprob_states = pm.math.logsumexp(logprob_states[:, None] + pt.log(P), axis=0)
+
+        logprob_emiss_ts.append(logprob_emiss_t)
+
+    return pt.sum(logprob_emiss_ts)
+
+
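+# Illustrative reference, not part of the original commit: a minimal NumPy sketch of the
+# exact log-scale forward recursion that hmm_logp mirrors symbolically. The helper name
+# and the emission_logp(e, s) callback are assumptions made for this example only.
+def _forward_logp_reference(emissions, log_P, log_init, emission_logp):
+    # log_alpha[s] = log P(S_t = s, e_{1:t-1})
+    log_alpha = np.asarray(log_init, dtype=float)
+    for e in emissions:
+        # Absorb the emission at time t: log P(S_t = s, e_{1:t})
+        log_obs = np.array([emission_logp(e, s) for s in range(len(log_alpha))])
+        log_alpha = log_alpha + log_obs
+        # Propagate through the transition matrix (log-scale prob_states @ P)
+        log_alpha = logsumexp(log_alpha[:, None] + log_P, axis=0)
+    # Marginalize the final state to obtain log P(e_{1:T})
+    return logsumexp(log_alpha)
+
+
+# For the deterministic chain used in test_hmm below, this reproduces expected_logp:
+#     _forward_logp_reference(
+#         [-1, 1, -1, 1],
+#         np.log([[0, 1], [1, 0]]),
+#         np.log([1, 0]),
+#         lambda e, s: pm.logp(pm.Normal.dist(s * 2 - 1, 1e-1), e).eval(),
+#     )
+
+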
+def test_hmm():
+
+    with MarginalModel() as m:
+        p = pt.as_tensor(np.array([1, 0]))
+
+        chain_0 = pm.Bernoulli("chain_0", p=0)
+        chain_1 = pm.Bernoulli("chain_1", p=p[chain_0])
+        chain_2 = pm.Bernoulli("chain_2", p=p[chain_1])
+        chain_3 = pm.Bernoulli("chain_3", p=p[chain_2])
+
+        pm.Normal("emission_0", chain_0 * 2 - 1, sigma=1e-1)
+        pm.Normal("emission_1", chain_1 * 2 - 1, sigma=1e-1)
+        pm.Normal("emission_2", chain_2 * 2 - 1, sigma=1e-1)
+        pm.Normal("emission_3", chain_3 * 2 - 1, sigma=1e-1)
+
+    with pytest.warns(UserWarning, match="multiple dependent variables"):
+        m.marginalize([chain_0, chain_1, chain_2, chain_3])
+
+    # Debugging aid: inspect the graph after marginalization
+    import pytensor
+    print()
+    pytensor.dprint(m.clone()._marginalize().free_RVs)
+
+    logp_fn = m.compile_logp()
+    test_value = [-1, 1, -1, 1]
+
+    # chain_0 is fixed at 0 and p = [1, 0] makes the chain deterministic ([0, 1, 0, 1]),
+    # so the emission means equal test_value and every residual is zero
+    expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
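+    # Worked value (illustrative note, not in the original commit): each of the four
+    # emissions contributes log N(0 | 0, 0.1) = -0.5 * log(2 * pi * 0.1 ** 2) ≈ 1.3836,
+    # so expected_logp ≈ 4 * 1.3836 ≈ 5.5346, i.e.
+    #     np.isclose(expected_logp, 4 * -0.5 * np.log(2 * np.pi * 0.1 ** 2))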
+
+    np.testing.assert_allclose(
+        logp_fn({f"emission_{i}": test_value_i for i, test_value_i in enumerate(test_value)}),
+        expected_logp,
+    )
+    return
+
+    # with MarginalModel() as m:
+    #     P = [[0, 1], [1, 0]]
+    #     zero = pm.DiracDelta.dist(np.array(0, dtype="int64"))
+    #     chain = DiscreteMarkovChain("chain", P=P, init_dist=zero, steps=3)
+    #     emission = pm.Normal("emission", mu=chain * 2 - 1, sigma=1e-1)
+    #     np.testing.assert_equal(pm.draw(chain), [0, 1, 0, 1])
+    #     m.marginalize(chain)
+
+    #     test_value = [-1, 1, -1, 1]
+    #     expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), [0, 0, 0, 0]).eval()
+
+    #     np.testing.assert_allclose(
+    #         logp_fn({"emission": test_value}),
+    #         expected_logp,
+    #     )
+    #
+    #     np.testing.assert_allclose(
+    #         logp_fn({f"emission{i}": test_value_i for i, test_value_i in enumerate(test_value)}),
+    #         expected_logp,
+    #     )