@@ -115,40 +115,37 @@ def simplex_cont_transform(op, rv):
 def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
-    if chol is not None and not lower:
-        chol = chol.T
-
     if len([i for i in [tau, cov, chol] if i is not None]) != 1:
         raise ValueError("Incompatible parameterization. Specify exactly one of tau, cov, or chol.")
 
     if cov is not None:
         cov = pt.as_tensor_variable(cov)
-        if cov.ndim != 2:
-            raise ValueError("cov must be two dimensional.")
+        if cov.ndim < 2:
+            raise ValueError("cov must be at least two dimensional.")
     elif tau is not None:
         tau = pt.as_tensor_variable(tau)
-        if tau.ndim != 2:
-            raise ValueError("tau must be two dimensional.")
-        # TODO: What's the correct order/approach (in the non-square case)?
-        # `pytensor.tensor.nlinalg.tensorinv`?
+        if tau.ndim < 2:
+            raise ValueError("tau must be at least two dimensional.")
         cov = matrix_inverse(tau)
     else:
-        # TODO: What's the correct order/approach (in the non-square case)?
         chol = pt.as_tensor_variable(chol)
-        if chol.ndim != 2:
-            raise ValueError("chol must be two dimensional.")
+        if chol.ndim < 2:
+            raise ValueError("chol must be at least two dimensional.")
+
+        if not lower:
+            chol = pt.swapaxes(chol, -1, -2)
 
         # tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l
         chol.tag.lower_triangular = True
-        cov = chol.dot(chol.T)
+        cov = pt.matmul(chol, pt.swapaxes(chol, -1, -2))
 
     return cov
 
 
-def quaddist_parse(value, mu, cov, mat_type="cov"):
+def quaddist_chol(value, mu, cov):
     """Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
-    if value.ndim > 2 or value.ndim == 0:
-        raise ValueError("Invalid dimension for value: %s" % value.ndim)
+    if value.ndim == 0:
+        raise ValueError("Value can't be a scalar")
     if value.ndim == 1:
         onedim = True
         value = value[None, :]
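For reviewers, a short sketch of what the relaxed `ndim` checks buy us: `quaddist_matrix` now accepts stacked (batched) scale matrices, and the `chol` branch forms the covariance with a batched matmul instead of `chol.dot(chol.T)`. Illustrative values, assuming the same `pt` alias used in the module:

```python
import numpy as np
import pytensor.tensor as pt

# A batch of two 3x3 lower-triangular factors with positive diagonals.
# Under the old code a 3-d chol raised "chol must be two dimensional.";
# now leading batch dimensions pass straight through.
chol_vals = np.tril(np.random.default_rng(0).uniform(0.5, 1.5, size=(2, 3, 3)))

chol = pt.as_tensor_variable(chol_vals)
cov = pt.matmul(chol, pt.swapaxes(chol, -1, -2))  # chol @ chol.T per batch member
print(cov.eval().shape)  # (2, 3, 3)
```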
@@ -157,42 +157,21 @@ def quaddist_parse(value, mu, cov, mat_type="cov"):
 
     delta = value - mu
     chol_cov = nan_lower_cholesky(cov)
-    if mat_type != "tau":
-        dist, logdet, ok = quaddist_chol(delta, chol_cov)
-    else:
-        dist, logdet, ok = quaddist_tau(delta, chol_cov)
-    if onedim:
-        return dist[0], logdet, ok
-
-    return dist, logdet, ok
-
-
-def quaddist_chol(delta, chol_mat):
-    diag = pt.diag(chol_mat)
+    diag = pt.diagonal(chol_cov, axis1=-2, axis2=-1)
     # Check if the covariance matrix is positive definite.
-    ok = pt.all(diag > 0)
+    ok = pt.all(diag > 0, axis=-1)
     # If not, replace the diagonal. We return -inf later, but
     # need to prevent solve_lower from throwing an exception.
-    chol_cov = pt.switch(ok, chol_mat, 1)
-
-    delta_trans = solve_lower(chol_cov, delta.T).T
+    chol_cov = pt.switch(ok[..., None, None], chol_cov, 1)
+    delta_trans = solve_lower(chol_cov, delta, b_ndim=1)
     quaddist = (delta_trans**2).sum(axis=-1)
-    logdet = pt.sum(pt.log(diag))
-    return quaddist, logdet, ok
-
-
-def quaddist_tau(delta, chol_mat):
-    diag = pt.nlinalg.diag(chol_mat)
-    # Check if the precision matrix is positive definite.
-    ok = pt.all(diag > 0)
-    # If not, replace the diagonal. We return -inf later, but
-    # need to prevent solve_lower from throwing an exception.
-    chol_tau = pt.switch(ok, chol_mat, 1)
+    logdet = pt.log(diag).sum(axis=-1)
 
-    delta_trans = pt.dot(delta, chol_tau)
-    quaddist = (delta_trans**2).sum(axis=-1)
-    logdet = -pt.sum(pt.log(diag))
-    return quaddist, logdet, ok
+    if onedim:
+        return quaddist[0], logdet, ok
+    else:
+        return quaddist, logdet, ok
 
 
 class MvNormal(Continuous):
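To double-check the consolidated math: `quaddist_chol` returns the quadratic form via a triangular solve, plus the *half* log-determinant (the sum of the log-diagonal of the Cholesky factor). A NumPy/SciPy sketch of the same identities, with made-up values:

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
sigma = A @ A.T + 3 * np.eye(3)  # a positive definite covariance
x, mu = rng.normal(size=3), np.zeros(3)

L = np.linalg.cholesky(sigma)
delta_trans = solve_triangular(L, x - mu, lower=True)  # L^-1 (x - mu)
quaddist = (delta_trans**2).sum()            # (x - mu)^T Sigma^-1 (x - mu)
half_logdet = np.log(np.diagonal(L)).sum()   # 0.5 * log|Sigma|

# Cross-check against the direct (less stable) computations.
np.testing.assert_allclose(quaddist, (x - mu) @ np.linalg.solve(sigma, x - mu))
np.testing.assert_allclose(2 * half_logdet, np.linalg.slogdet(sigma)[1])
```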
@@ -266,10 +242,11 @@ def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
         mu = pt.as_tensor_variable(mu)
         cov = quaddist_matrix(cov, chol, tau, lower)
         # PyTensor is stricter about the shape of mu, than PyMC used to be
-        mu = pt.broadcast_arrays(mu, cov[..., -1])[0]
+        mu, _ = pt.broadcast_arrays(mu, cov[..., -1])
 
         return super().dist([mu, cov], **kwargs)
 
     def moment(rv, size, mu, cov):
+        # mu is broadcasted to the potential length of cov in `dist`
         moment = mu
         if not rv_size_is_none(size):
             moment_size = pt.concatenate([size, [mu.shape[-1]]])
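The tuple unpacking is equivalent to the old `[0]` indexing but fails loudly if `broadcast_arrays` ever changes arity. The practical effect of broadcasting `mu` against `cov[..., -1]` (the last row of `cov`) is that a 1-d `mu` can pair with a batched `cov`. Illustrative usage on this branch:

```python
import numpy as np
import pymc as pm

# Two covariance matrices stacked along a leading batch dimension.
cov = np.stack([np.eye(3), 2 * np.eye(3)])  # shape (2, 3, 3)

# mu of shape (3,) is broadcast against cov[..., -1] (shape (2, 3)),
# so the RV gets batch shape (2,) and support shape (3,).
x = pm.MvNormal.dist(mu=np.zeros(3), cov=cov)
print(pm.draw(x).shape)  # (2, 3)
```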
@@ -290,7 +267,7 @@ def logp(value, mu, cov):
         -------
         TensorVariable
         """
-        quaddist, logdet, ok = quaddist_parse(value, mu, cov)
+        quaddist, logdet, ok = quaddist_chol(value, mu, cov)
         k = floatX(value.shape[-1])
         norm = -0.5 * k * pm.floatX(np.log(2 * np.pi))
         return check_parameters(
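Because `quaddist_chol` hands back the half log-determinant, the Gaussian density needs exactly `norm - 0.5 * quaddist - logdet`. A quick external sanity check against SciPy (illustrative values, not part of the PR's test suite):

```python
import numpy as np
import pymc as pm
from scipy.stats import multivariate_normal

mu = np.array([1.0, -1.0])
cov = np.array([[2.0, 0.5], [0.5, 1.0]])
value = np.array([0.3, 0.7])

# The logp graph built on quaddist_chol should match scipy's logpdf.
logp_pymc = pm.logp(pm.MvNormal.dist(mu=mu, cov=cov), value).eval()
logp_scipy = multivariate_normal(mean=mu, cov=cov).logpdf(value)
np.testing.assert_allclose(logp_pymc, logp_scipy)
```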
@@ -307,22 +284,6 @@ class MvStudentTRV(RandomVariable):
     dtype = "floatX"
     _print_name = ("MvStudentT", "\\operatorname{MvStudentT}")
 
-    def make_node(self, rng, size, dtype, nu, mu, cov):
-        nu = pt.as_tensor_variable(nu)
-        if not nu.ndim == 0:
-            raise ValueError("nu must be a scalar (ndim=0).")
-
-        return super().make_node(rng, size, dtype, nu, mu, cov)
-
-    def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
-        dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype
-
-        if mu is None:
-            mu = np.array([0.0], dtype=dtype)
-        if cov is None:
-            cov = np.array([[1.0]], dtype=dtype)
-        return super().__call__(nu, mu, cov, size=size, **kwargs)
-
     def _supp_shape_from_params(self, dist_params, param_shapes=None):
         return supp_shape_from_ref_param_shape(
             ndim_supp=self.ndim_supp,
@@ -333,14 +294,21 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
 
     @classmethod
     def rng_fn(cls, rng, nu, mu, cov, size):
+        if size is None:
+            # When size is implicit, we need to broadcast parameters correctly,
+            # so that the MvNormal draws and the chisquare draws have the same number of batch dimensions.
+            # nu broadcasts mu and cov
+            if np.ndim(nu) > max(mu.ndim - 1, cov.ndim - 2):
+                _, mu, cov = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
+            # nu is broadcasted by either mu or cov
+            elif np.ndim(nu) < max(mu.ndim - 1, cov.ndim - 2):
+                nu, _, _ = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
+
         mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov, size=size)
 
         # Take chi2 draws and add an axis of length 1 to the right for correct broadcasting below
         chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)[..., None]
 
-        if size:
-            mu = np.broadcast_to(mu, size + (mu.shape[-1],))
-
         return (mv_samples / chi2_samples) + mu
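The broadcasting fix matters because the MvNormal draws and the chi-square draws must share batch dimensions before the division. The underlying scale-mixture identity, in plain NumPy with made-up parameters:

```python
import numpy as np

rng = np.random.default_rng(1)
nu = 5.0
mu = np.array([0.0, 1.0])
cov = np.array([[1.0, 0.3], [0.3, 1.0]])

# t = z / sqrt(chi2(nu) / nu) + mu, with z ~ MvNormal(0, cov)
z = rng.multivariate_normal(np.zeros(2), cov, size=100_000)
chi2 = np.sqrt(rng.chisquare(nu, size=100_000) / nu)[..., None]
draws = z / chi2 + mu

# For nu > 2 the covariance of the draws approaches nu / (nu - 2) * cov.
print(np.cov(draws.T) / cov)  # entries near 5 / 3 ~ 1.67
```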
@@ -390,7 +358,7 @@ class MvStudentT(Continuous):
     rv_op = mv_studentt
 
     @classmethod
-    def dist(cls, nu, Sigma=None, mu=None, scale=None, tau=None, chol=None, lower=True, **kwargs):
+    def dist(cls, nu, *, Sigma=None, mu, scale=None, tau=None, chol=None, lower=True, **kwargs):
         cov = kwargs.pop("cov", None)
         if cov is not None:
             warnings.warn(
@@ -407,11 +375,13 @@ def dist(cls, nu, Sigma=None, mu=None, scale=None, tau=None, chol=None, lower=Tr
         mu = pt.as_tensor_variable(floatX(mu))
         scale = quaddist_matrix(scale, chol, tau, lower)
         # PyTensor is stricter about the shape of mu, than PyMC used to be
-        mu = pt.broadcast_arrays(mu, scale[..., -1])[0]
+        mu, _ = pt.broadcast_arrays(mu, scale[..., -1])
 
         return super().dist([nu, mu, scale], **kwargs)
 
     def moment(rv, size, nu, mu, scale):
+        # mu is broadcasted to the potential length of scale in `dist`
+        mu, _ = pt.random.utils.broadcast_params([mu, nu], ndims_params=[1, 0])
         moment = mu
         if not rv_size_is_none(size):
             moment_size = pt.concatenate([size, [mu.shape[-1]]])
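`broadcast_params` with `ndims_params=[1, 0]` treats `mu` as having one core dimension and `nu` as a scalar parameter, so a batched `nu` now expands `mu`'s batch shape before the moment is computed. Illustrative on this branch:

```python
import numpy as np
import pymc as pm

# A vector nu contributes a batch dimension that mu/scale alone would not have.
x = pm.MvStudentT.dist(nu=np.array([3.0, 7.0]), mu=np.zeros(2), scale=np.eye(2))
print(pm.draw(x).shape)  # (2, 2): batch dim from nu, support dim from mu/scale
```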
@@ -432,7 +402,7 @@ def logp(value, nu, mu, scale):
         -------
         TensorVariable
         """
-        quaddist, logdet, ok = quaddist_parse(value, mu, scale)
+        quaddist, logdet, ok = quaddist_chol(value, mu, scale)
         k = floatX(value.shape[-1])
 
         norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * pt.log(nu * np.pi)
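As with MvNormal, an external cross-check of the Student-T normalizer against SciPy's `multivariate_t` (illustrative values):

```python
import numpy as np
import pymc as pm
from scipy.stats import multivariate_t

nu = 4.0
mu = np.array([1.0, -1.0])
scale = np.array([[2.0, 0.5], [0.5, 1.0]])
value = np.array([0.3, 0.7])

# The gammaln-based normalizer plus the quaddist_chol terms should
# reproduce scipy's multivariate-t log density.
logp_pymc = pm.logp(pm.MvStudentT.dist(nu=nu, mu=mu, scale=scale), value).eval()
logp_scipy = multivariate_t(loc=mu, shape=scale, df=nu).logpdf(value)
np.testing.assert_allclose(logp_pymc, logp_scipy)
```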