Skip to content

Commit 9e78ad2

Browse files
yadav-sachinricardoV94
authored andcommitted
Add MvNormal moment (pymc-devs#5171)
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent e39bb0d commit 9e78ad2

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

pymc/distributions/multivariate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,13 @@ def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
228228
cov = quaddist_matrix(cov, chol, tau, lower)
229229
return super().dist([mu, cov], **kwargs)
230230

231+
def get_moment(rv, size, mu, cov):
232+
moment = mu
233+
if not rv_size_is_none(size):
234+
moment_size = at.concatenate([size, mu.shape])
235+
moment = at.full(moment_size, mu)
236+
return moment
237+
231238
def logp(value, mu, cov):
232239
"""
233240
Calculate log-probability of Multivariate Normal distribution

pymc/tests/test_distributions_moments.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
ZeroInflatedBinomial,
4848
ZeroInflatedPoisson,
4949
)
50+
from pymc.distributions.multivariate import MvNormal
5051
from pymc.distributions.shape_utils import rv_size_is_none
5152
from pymc.initial_point import make_initial_point_fn
5253
from pymc.model import Model
@@ -774,6 +775,40 @@ def test_categorical_moment(p, size, expected):
774775
assert_moment_is_expected(model, expected)
775776

776777

778+
@pytest.mark.parametrize(
779+
"mu, cov, size, expected",
780+
[
781+
(np.ones(1), np.identity(1), None, np.ones(1)),
782+
(np.ones(3), np.identity(3), None, np.ones(3)),
783+
(np.ones((2, 2)), np.identity(2), None, np.ones((2, 2))),
784+
(np.array([1, 0, 3.0]), np.identity(3), None, np.array([1, 0, 3.0])),
785+
(np.array([1, 0, 3.0]), np.identity(3), (4, 2), np.full((4, 2, 3), [1, 0, 3.0])),
786+
(
787+
np.array([1, 3.0]),
788+
np.identity(2),
789+
5,
790+
np.full((5, 2), [1, 3.0]),
791+
),
792+
(
793+
np.array([1, 3.0]),
794+
np.array([[1.0, 0.5], [0.5, 2]]),
795+
(4, 5),
796+
np.full((4, 5, 2), [1, 3.0]),
797+
),
798+
(
799+
np.array([[3.0, 5], [1, 4]]),
800+
np.identity(2),
801+
(4, 5),
802+
np.full((4, 5, 2, 2), [[3.0, 5], [1, 4]]),
803+
),
804+
],
805+
)
806+
def test_mv_normal_moment(mu, cov, size, expected):
807+
with Model() as model:
808+
MvNormal("x", mu=mu, cov=cov, size=size)
809+
assert_moment_is_expected(model, expected)
810+
811+
777812
@pytest.mark.parametrize(
778813
"mu, sigma, size, expected",
779814
[

0 commit comments

Comments
 (0)