Skip to content

Commit 9edbdb1

Browse files
committed
remove incorrect and add new test cases
1 parent 93e2c7f commit 9edbdb1

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

pymc/tests/test_distributions_moments.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -758,15 +758,34 @@ def test_categorical_moment(p, size, expected):
758758
(np.ones(1), np.identity(1), None, np.ones(1)),
759759
(np.ones(10), np.identity(10), None, np.ones(10)),
760760
(np.ones(2), np.identity(2), 4, np.ones((4, 2))),
761-
(np.ones(2), np.identity(2), (4, 2), np.ones((4, 2, 2))),
762-
(np.ones((2, 2)), np.identity(2), None, np.ones((2, 2))),
763-
(np.ones((2, 2)), np.identity(2), 4, np.ones((4, 2, 2))),
764-
(np.ones((2, 2)), np.identity(2), (4, 2), np.ones((4, 2, 2, 2))),
761+
(np.ones(2), np.identity(2), (4, 3), np.ones((4, 3, 2))),
762+
(np.array([1, 0, 3.0]), np.identity(3), None, np.array([1, 0, 3.0])),
763+
(np.array([1, 0, 3.0]), np.identity(3), 4, np.full((4, 3), [1, 0, 3.0])),
764+
(np.array([1, 0, 3.0]), np.identity(3), (4, 2), np.full((4, 2, 3), [1, 0, 3.0])),
765+
(
766+
np.array([1, 3.0]),
767+
np.identity(2),
768+
(4, 5),
769+
np.full((4, 5, 2), [1, 3.0]),
770+
),
771+
(
772+
np.array([1, 3.0]),
773+
np.array([[1.0, 0.5], [0.5, 2]]),
774+
(4, 5),
775+
np.full((4, 5, 2), [1, 3.0]),
776+
),
777+
(
778+
np.array([1, 3, 0.0]),
779+
np.array([[1.0, 0.5, 0.1], [0.5, 2, 0.5], [0.1, 0.5, 5]]),
780+
(4, 5),
781+
np.full((4, 5, 3), [1, 3, 0.0]),
782+
),
765783
],
766784
)
767785
def test_mv_normal_moment(mu, cov, size, expected):
768786
with Model() as model:
769787
MvNormal("x", mu=mu, cov=cov, size=size)
788+
assert_moment_is_expected(model, expected)
770789

771790

772791
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)