|
46 | 46 | ZeroInflatedBinomial,
|
47 | 47 | ZeroInflatedPoisson,
|
48 | 48 | )
|
| 49 | +from pymc.distributions.multivariate import MvNormal |
49 | 50 | from pymc.distributions.shape_utils import rv_size_is_none
|
50 | 51 | from pymc.initial_point import make_initial_point_fn
|
51 | 52 | from pymc.model import Model
|
@@ -753,6 +754,40 @@ def test_categorical_moment(p, size, expected):
|
753 | 754 | assert_moment_is_expected(model, expected)
|
754 | 755 |
|
755 | 756 |
|
| 757 | +@pytest.mark.parametrize( |
| 758 | + "mu, cov, size, expected", |
| 759 | + [ |
| 760 | + (np.ones(1), np.identity(1), None, np.ones(1)), |
| 761 | + (np.ones(3), np.identity(3), None, np.ones(3)), |
| 762 | + (np.ones((2, 2)), np.identity(2), None, np.ones((2, 2))), |
| 763 | + (np.array([1, 0, 3.0]), np.identity(3), None, np.array([1, 0, 3.0])), |
| 764 | + (np.array([1, 0, 3.0]), np.identity(3), (4, 2), np.full((4, 2, 3), [1, 0, 3.0])), |
| 765 | + ( |
| 766 | + np.array([1, 3.0]), |
| 767 | + np.identity(2), |
| 768 | + 5, |
| 769 | + np.full((5, 2), [1, 3.0]), |
| 770 | + ), |
| 771 | + ( |
| 772 | + np.array([1, 3.0]), |
| 773 | + np.array([[1.0, 0.5], [0.5, 2]]), |
| 774 | + (4, 5), |
| 775 | + np.full((4, 5, 2), [1, 3.0]), |
| 776 | + ), |
| 777 | + ( |
| 778 | + np.array([[3.0, 5], [1, 4]]), |
| 779 | + np.identity(2), |
| 780 | + (4, 5), |
| 781 | + np.full((4, 5, 2, 2), [[3.0, 5], [1, 4]]), |
| 782 | + ), |
| 783 | + ], |
| 784 | +) |
| 785 | +def test_mv_normal_moment(mu, cov, size, expected): |
| 786 | + with Model() as model: |
| 787 | + MvNormal("x", mu=mu, cov=cov, size=size) |
| 788 | + assert_moment_is_expected(model, expected) |
| 789 | + |
| 790 | + |
756 | 791 | @pytest.mark.parametrize(
|
757 | 792 | "mu, sigma, size, expected",
|
758 | 793 | [
|
|
0 commit comments