@@ -32,7 +32,7 @@
 from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
 from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
 from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
-from aesara.tensor.random.utils import broadcast_params
+from aesara.tensor.random.utils import broadcast_params, normalize_size_param
 from aesara.tensor.slinalg import Cholesky
 from aesara.tensor.slinalg import solve_lower_triangular as solve_lower
 from aesara.tensor.slinalg import solve_upper_triangular as solve_upper
@@ -1134,6 +1134,19 @@ def make_node(self, rng, size, dtype, n, eta, D):
 
         D = at.as_tensor_variable(D)
 
+        # We resize the sd_dist `D` automatically so that it has (size x n) independent
+        # draws which is what the `_LKJCholeskyCovRV.rng_fn` expects. This makes the
+        # random and logp methods equivalent, as the latter also assumes a unique value
+        # for each diagonal element.
+        # Since `eta` and `n` are forced to be scalars we don't need to worry about
+        # implied batched dimensions for the time being.
+        size = normalize_size_param(size)
+        if D.owner.op.ndim_supp == 0:
+            D = change_rv_size(D, at.concatenate((size, (n,))))
+        else:
+            # The support shape must be `n` but we have no way of controlling it
+            D = change_rv_size(D, size)
+
         return super().make_node(rng, size, dtype, n, eta, D)
 
     def _infer_shape(self, size, dist_params, param_shapes=None):
@@ -1179,7 +1192,7 @@ def __new__(cls, name, eta, n, sd_dist, **kwargs):
         return super().__new__(cls, name, eta, n, sd_dist, **kwargs)
 
     @classmethod
-    def dist(cls, eta, n, sd_dist, size=None, **kwargs):
+    def dist(cls, eta, n, sd_dist, **kwargs):
         eta = at.as_tensor_variable(floatX(eta))
         n = at.as_tensor_variable(intX(n))
@@ -1191,18 +1204,6 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
         ):
             raise TypeError("sd_dist must be a scalar or vector distribution variable")
 
-        # We resize the sd_dist automatically so that it has (size x n) independent draws
-        # which is what the `_LKJCholeskyCovRV.rng_fn` expects. This makes the random
-        # and logp methods equivalent, as the latter also assumes a unique value for each
-        # diagonal element.
-        # Since `eta` and `n` are forced to be scalars we don't need to worry about
-        # implied batched dimensions for the time being.
-        if sd_dist.owner.op.ndim_supp == 0:
-            sd_dist = change_rv_size(sd_dist, to_tuple(size) + (n,))
-        else:
-            # The support shape must be `n` but we have no way of controlling it
-            sd_dist = change_rv_size(sd_dist, to_tuple(size))
-
         # sd_dist is part of the generative graph, but should be completely ignored
         # by the logp graph, since the LKJ logp explicitly includes these terms.
         # Setting sd_dist.tag.ignore_logprob to True, will prevent Aeppl warning about
@@ -1211,7 +1212,7 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
         # sd_dist prior components from the logp expression.
         sd_dist.tag.ignore_logprob = True
 
-        return super().dist([n, eta, sd_dist], size=size, **kwargs)
+        return super().dist([n, eta, sd_dist], **kwargs)
 
     def moment(rv, size, n, eta, sd_dists):
         diag_idxs = (at.cumsum(at.arange(1, n + 1)) - 1).astype("int32")
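
A minimal sketch (not part of the commit) of what the resizing now done in make_node amounts to, assuming the v4-era helper `change_rv_size` is importable from `pymc.aesaraf` and using the `normalize_size_param` utility added to the imports above. It shows a scalar sd_dist being given one independent draw per diagonal element and per batch, which is the shape `_LKJCholeskyCovRV.rng_fn` expects:

import aesara.tensor as at
from aesara.tensor.random.utils import normalize_size_param
from pymc.aesaraf import change_rv_size  # assumed location of the helper used in the diff

n = 3                                  # number of diagonal elements (a scalar, as enforced by dist)
size = normalize_size_param((2,))      # requested batch size, normalized to a symbolic vector
sd_dist = at.random.exponential(1.0)   # scalar prior: ndim_supp == 0

# Scalar priors are resized to (size x n), i.e. one draw per diagonal element and batch.
sd_dist = change_rv_size(sd_dist, at.concatenate((size, (n,))))
print(sd_dist.eval().shape)            # expected: (2, 3)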