Skip to content

Commit e257fe0

Browse files
Add MvNormal moment (#5171)
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent aae9dfb commit e257fe0

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

pymc/distributions/multivariate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,13 @@ def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
228228
cov = quaddist_matrix(cov, chol, tau, lower)
229229
return super().dist([mu, cov], **kwargs)
230230

231+
def get_moment(rv, size, mu, cov):
232+
moment = mu
233+
if not rv_size_is_none(size):
234+
moment_size = at.concatenate([size, mu.shape])
235+
moment = at.full(moment_size, mu)
236+
return moment
237+
231238
def logp(value, mu, cov):
232239
"""
233240
Calculate log-probability of Multivariate Normal distribution

pymc/tests/test_distributions_moments.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
ZeroInflatedBinomial,
4747
ZeroInflatedPoisson,
4848
)
49+
from pymc.distributions.multivariate import MvNormal
4950
from pymc.distributions.shape_utils import rv_size_is_none
5051
from pymc.initial_point import make_initial_point_fn
5152
from pymc.model import Model
@@ -753,6 +754,40 @@ def test_categorical_moment(p, size, expected):
753754
assert_moment_is_expected(model, expected)
754755

755756

757+
@pytest.mark.parametrize(
758+
"mu, cov, size, expected",
759+
[
760+
(np.ones(1), np.identity(1), None, np.ones(1)),
761+
(np.ones(3), np.identity(3), None, np.ones(3)),
762+
(np.ones((2, 2)), np.identity(2), None, np.ones((2, 2))),
763+
(np.array([1, 0, 3.0]), np.identity(3), None, np.array([1, 0, 3.0])),
764+
(np.array([1, 0, 3.0]), np.identity(3), (4, 2), np.full((4, 2, 3), [1, 0, 3.0])),
765+
(
766+
np.array([1, 3.0]),
767+
np.identity(2),
768+
5,
769+
np.full((5, 2), [1, 3.0]),
770+
),
771+
(
772+
np.array([1, 3.0]),
773+
np.array([[1.0, 0.5], [0.5, 2]]),
774+
(4, 5),
775+
np.full((4, 5, 2), [1, 3.0]),
776+
),
777+
(
778+
np.array([[3.0, 5], [1, 4]]),
779+
np.identity(2),
780+
(4, 5),
781+
np.full((4, 5, 2, 2), [[3.0, 5], [1, 4]]),
782+
),
783+
],
784+
)
785+
def test_mv_normal_moment(mu, cov, size, expected):
786+
with Model() as model:
787+
MvNormal("x", mu=mu, cov=cov, size=size)
788+
assert_moment_is_expected(model, expected)
789+
790+
756791
@pytest.mark.parametrize(
757792
"mu, sigma, size, expected",
758793
[

0 commit comments

Comments
 (0)