Skip to content

Commit 6f13f7e

Browse files
committed
Add LKJCholeskyCov moment
1 parent 1a35a3d commit 6f13f7e

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

pymc/distributions/multivariate.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,6 +1202,12 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
12021202

12031203
return super().dist([n, eta, sd_dist], size=size, **kwargs)
12041204

1205+
def get_moment(rv, size, n, eta, sd_dists):
1206+
diag_idxs = (at.cumsum(at.arange(1, n + 1)) - 1).astype("int32")
1207+
moment = at.zeros_like(rv)
1208+
moment = at.set_subtensor(moment[..., diag_idxs], 1)
1209+
return moment
1210+
12051211
def logp(value, n, eta, sd_dist):
12061212
"""
12071213
Calculate log-probability of Covariance matrix with LKJ

pymc/tests/test_distributions_moments.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
KroneckerNormal,
4040
Kumaraswamy,
4141
Laplace,
42+
LKJCholeskyCov,
4243
LKJCorr,
4344
Logistic,
4445
LogitNormal,
@@ -1439,3 +1440,24 @@ def test_lkjcorr_moment(n, eta, size, expected):
14391440
with Model() as model:
14401441
LKJCorr("x", n=n, eta=eta, size=size)
14411442
assert_moment_is_expected(model, expected)
1443+
1444+
1445+
@pytest.mark.parametrize(
1446+
"n, eta, size, expected",
1447+
[
1448+
(3, 1, None, np.array([1, 0, 1, 0, 0, 1])),
1449+
(4, 1, None, np.array([1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0])),
1450+
(3, 1, 1, np.array([[1, 0, 1, 0, 0, 1]])),
1451+
(
1452+
4,
1453+
1,
1454+
(2, 3),
1455+
np.full((2, 3, 10), np.array([1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0])),
1456+
),
1457+
],
1458+
)
1459+
def test_lkjcholeskycov_moment(n, eta, size, expected):
1460+
with Model() as model:
1461+
sd_dist = pm.Exponential.dist(1, size=(*to_tuple(size), n))
1462+
LKJCholeskyCov("x", n=n, eta=eta, sd_dist=sd_dist, size=size, compute_corr=False)
1463+
assert_moment_is_expected(model, expected, check_finite_logp=size is None)

0 commit comments

Comments
 (0)