
Commit 63aa44c

Move _LKJCholeskyCov sd_dist resizing to Op.make_node
1 parent 68eb273

2 files changed: +28 -17 lines

pymc/distributions/multivariate.py

Lines changed: 16 additions & 15 deletions

@@ -32,7 +32,7 @@
 from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
 from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
 from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
-from aesara.tensor.random.utils import broadcast_params
+from aesara.tensor.random.utils import broadcast_params, normalize_size_param
 from aesara.tensor.slinalg import Cholesky
 from aesara.tensor.slinalg import solve_lower_triangular as solve_lower
 from aesara.tensor.slinalg import solve_upper_triangular as solve_upper

@@ -1134,6 +1134,19 @@ def make_node(self, rng, size, dtype, n, eta, D):
 
         D = at.as_tensor_variable(D)
 
+        # We resize the sd_dist `D` automatically so that it has (size x n) independent
+        # draws which is what the `_LKJCholeskyCovRV.rng_fn` expects. This makes the
+        # random and logp methods equivalent, as the latter also assumes a unique value
+        # for each diagonal element.
+        # Since `eta` and `n` are forced to be scalars we don't need to worry about
+        # implied batched dimensions for the time being.
+        size = normalize_size_param(size)
+        if D.owner.op.ndim_supp == 0:
+            D = change_rv_size(D, at.concatenate((size, (n,))))
+        else:
+            # The support shape must be `n` but we have no way of controlling it
+            D = change_rv_size(D, size)
+
         return super().make_node(rng, size, dtype, n, eta, D)
 
     def _infer_shape(self, size, dist_params, param_shapes=None):

@@ -1179,7 +1192,7 @@ def __new__(cls, name, eta, n, sd_dist, **kwargs):
         return super().__new__(cls, name, eta, n, sd_dist, **kwargs)
 
     @classmethod
-    def dist(cls, eta, n, sd_dist, size=None, **kwargs):
+    def dist(cls, eta, n, sd_dist, **kwargs):
         eta = at.as_tensor_variable(floatX(eta))
         n = at.as_tensor_variable(intX(n))
 
@@ -1191,18 +1204,6 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
         ):
             raise TypeError("sd_dist must be a scalar or vector distribution variable")
 
-        # We resize the sd_dist automatically so that it has (size x n) independent draws
-        # which is what the `_LKJCholeskyCovRV.rng_fn` expects. This makes the random
-        # and logp methods equivalent, as the latter also assumes a unique value for each
-        # diagonal element.
-        # Since `eta` and `n` are forced to be scalars we don't need to worry about
-        # implied batched dimensions for the time being.
-        if sd_dist.owner.op.ndim_supp == 0:
-            sd_dist = change_rv_size(sd_dist, to_tuple(size) + (n,))
-        else:
-            # The support shape must be `n` but we have no way of controlling it
-            sd_dist = change_rv_size(sd_dist, to_tuple(size))
-
         # sd_dist is part of the generative graph, but should be completely ignored
         # by the logp graph, since the LKJ logp explicitly includes these terms.
         # Setting sd_dist.tag.ignore_logprob to True, will prevent Aeppl warning about

@@ -1211,7 +1212,7 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
         # sd_dist prior components from the logp expression.
         sd_dist.tag.ignore_logprob = True
 
-        return super().dist([n, eta, sd_dist], size=size, **kwargs)
+        return super().dist([n, eta, sd_dist], **kwargs)
 
     def moment(rv, size, n, eta, sd_dists):
         diag_idxs = (at.cumsum(at.arange(1, n + 1)) - 1).astype("int32")
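
For context, the resizing that now happens inside `_LKJCholeskyCovRV.make_node` can be reproduced at the graph level. The following is a minimal sketch, assuming a PyMC v4 development environment where `change_rv_size` is importable from `pymc.aesaraf` (the helper this module already uses) and where Aesara provides `normalize_size_param`:

    import aesara.tensor as at
    from aesara.tensor.random.utils import normalize_size_param

    import pymc as pm
    from pymc.aesaraf import change_rv_size  # assumed import path

    sd_dist = pm.Exponential.dist(1.0)  # scalar distribution, ndim_supp == 0
    size = normalize_size_param((10,))  # symbolic equivalent of size=(10,)
    n = at.as_tensor_variable(3)

    # A scalar sd_dist is resized to (size, n) draws, one per diagonal element,
    # which is what `_LKJCholeskyCovRV.rng_fn` and the logp both expect.
    resized = change_rv_size(sd_dist, at.concatenate((size, (n,))))
    print(resized.eval().shape)  # -> (10, 3)

Moving this logic into `make_node` means the resizing is also applied when the batch size is only determined later (for example via a `shape` keyword), which is what the new test cases below exercise.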

pymc/tests/test_distributions.py

Lines changed: 12 additions & 2 deletions

@@ -3437,8 +3437,18 @@ def test_no_warning_logp(self):
             pm.MvNormal.dist(np.ones(3), np.eye(3)),
         ],
     )
-    def test_sd_dist_automatically_resized(self, sd_dist):
-        x = pm.LKJCholeskyCov.dist(n=3, eta=1, sd_dist=sd_dist, size=10, compute_corr=False)
+    @pytest.mark.parametrize(
+        "size, shape",
+        [
+            ((10,), None),
+            (None, (10, 6)),
+            (None, (10, ...)),
+        ],
+    )
+    def test_sd_dist_automatically_resized(self, sd_dist, size, shape):
+        x = pm.LKJCholeskyCov.dist(
+            n=3, eta=1, sd_dist=sd_dist, size=size, shape=shape, compute_corr=False
+        )
         resized_sd_dist = x.owner.inputs[-1]
         assert resized_sd_dist.eval().shape == (10, 3)
         # LKJCov has support shape `(n * (n+1)) // 2`
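
As a usage-level illustration of what these tests check, here is a rough sketch (assuming a PyMC v4 environment where `compute_corr=False` returns the packed Cholesky vector and the resized sd_dist is the last input of its owner node, as the test relies on):

    import pymc as pm

    sd_dist = pm.Exponential.dist(1.0)

    # Whether the batch dimension comes from `size` or from `shape`
    # (including the Ellipsis form), the sd_dist input ends up resized
    # inside the Op's make_node rather than in LKJCholeskyCov.dist.
    x = pm.LKJCholeskyCov.dist(n=3, eta=1, sd_dist=sd_dist, size=(10,), compute_corr=False)

    resized_sd_dist = x.owner.inputs[-1]
    assert resized_sd_dist.eval().shape == (10, 3)

    # The packed Cholesky has support shape (n * (n + 1)) // 2 == 6 for n=3,
    # so a batch of 10 draws has shape (10, 6).
    assert x.eval().shape == (10, 6)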
