@@ -483,7 +483,7 @@ def test_marginalized_hmm_with_one_emission():
483
483
m .marginalize ([chain ])
484
484
485
485
logp_fn = m .compile_logp ()
486
- test_value = [- 1 , 1 , - 1 , 1 ]
486
+ test_value = np . array ( [- 1 , 1 , - 1 , 1 ])
487
487
488
488
expected_logp = pm .logp (pm .Normal .dist (0 , 1e-1 ), np .zeros_like (test_value )).sum ().eval ()
489
489
np .testing .assert_allclose (logp_fn ({f"emission" : test_value }), expected_logp )
@@ -501,10 +501,8 @@ def test_marginalized_hmm_with_many_emissions():
501
501
m .marginalize ([chain ])
502
502
503
503
logp_fn = m .compile_logp ()
504
- test_value = [- 1 , 1 , - 1 , 1 ]
504
+ test_value = np . array ( [- 1 , 1 , - 1 , 1 ])
505
505
506
- expected_logp = pm .logp (pm .Normal .dist (0 , 1e-1 ), np .zeros_like (test_value )).sum (). eval ()
506
+ e_logp = pm .logp (pm .Normal .dist (0 , 1e-1 ), np .zeros_like (test_value )).sum () * 2
507
507
test_point = {"emission_1" : test_value , "emission_2" : test_value * - 1 }
508
-
509
- assert False
510
- # np.testing.assert_allclose(logp_fn(test_point), expected_logp)
508
+ np .testing .assert_allclose (logp_fn (test_point ), e_logp .eval ())
0 commit comments