@@ -249,19 +249,6 @@ def _distr_parameters_for_repr(self):
249
249
return ["mu" , "cov" ]
250
250
251
251
252
- def safe_multivariate_t (nu , mu , cov , size = None , rng = None ):
253
- res = np .atleast_1d (
254
- stats .multivariate_t (loc = mu , shape = cov , df = nu , allow_singular = True ).rvs (
255
- size = size , random_state = rng
256
- )
257
- )
258
-
259
- if size is not None :
260
- res = res .reshape (list (size ) + [- 1 ])
261
-
262
- return res
263
-
264
-
265
252
class MvStudentTRV (RandomVariable ):
266
253
name = "multivariate_studentt"
267
254
ndim_supp = 1
@@ -285,25 +272,22 @@ def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
285
272
@classmethod
286
273
def rng_fn (cls , rng , nu , mu , cov , size ):
287
274
288
- if mu .ndim > 1 or cov .ndim > 2 :
289
- # Neither SciPy nor NumPy implement parameter broadcasting for
290
- # multivariate normals (or many other multivariate distributions),
291
- # so we have implement a quick and dirty one here
292
- mu , cov = broadcast_params ([mu , cov ], cls .ndims_params [1 :])
293
- size = tuple (size or ())
275
+ # Don't reassign broadcasted cov, since MvNormal expects two dimensional cov only.
276
+ mu , _ = broadcast_params ([mu , cov ], cls .ndims_params [1 :])
294
277
295
- if size :
296
- mu = np .broadcast_to (mu , size + mu .shape )
297
- cov = np .broadcast_to (cov , size + cov .shape )
298
-
299
- res = np .empty (mu .shape )
300
- for idx in np .ndindex (mu .shape [:- 1 ]):
301
- m = mu [idx ]
302
- c = cov [idx ]
303
- res [idx ] = safe_multivariate_t (nu , m , c , rng = rng )
304
- return res
305
- else :
306
- return safe_multivariate_t (nu , mu , cov , size = size , rng = rng )
278
+ chi2_samples = rng .chisquare (nu , size = size )
279
+ # Add distribution shape to chi2 samples
280
+ chi2_samples = chi2_samples .reshape (chi2_samples .shape + (1 ,) * len (mu .shape ))
281
+
282
+ mv_samples = pm .MvNormal .dist (
283
+ mu = np .zeros_like (mu ), cov = cov , size = size , rng = aesara .shared (rng )
284
+ ).eval ()
285
+
286
+ size = tuple (size or ())
287
+ if size :
288
+ mu = np .broadcast_to (mu , size + mu .shape )
289
+
290
+ return (mv_samples / np .sqrt (chi2_samples / nu )) + mu
307
291
308
292
309
293
mv_studentt = MvStudentTRV ()
0 commit comments