Skip to content

Commit daf76ba

Browse files
Make test_value a numpy array
1 parent 4127be6 commit daf76ba

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

pymc_experimental/tests/test_marginal_model.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def test_marginalized_hmm_with_one_emission():
483483
m.marginalize([chain])
484484

485485
logp_fn = m.compile_logp()
486-
test_value = [-1, 1, -1, 1]
486+
test_value = np.array([-1, 1, -1, 1])
487487

488488
expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
489489
np.testing.assert_allclose(logp_fn({f"emission": test_value}), expected_logp)
@@ -501,10 +501,8 @@ def test_marginalized_hmm_with_many_emissions():
501501
m.marginalize([chain])
502502

503503
logp_fn = m.compile_logp()
504-
test_value = [-1, 1, -1, 1]
504+
test_value = np.array([-1, 1, -1, 1])
505505

506-
expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
506+
e_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum() * 2
507507
test_point = {"emission_1": test_value, "emission_2": test_value * -1}
508-
509-
assert False
510-
# np.testing.assert_allclose(logp_fn(test_point), expected_logp)
508+
np.testing.assert_allclose(logp_fn(test_point), e_logp.eval())

0 commit comments

Comments
 (0)