27
27
from aesara .tensor import gammaln
28
28
from aesara .tensor .nlinalg import det , eigh , matrix_inverse , trace
29
29
from aesara .tensor .random .basic import MultinomialRV , dirichlet , multivariate_normal
30
+ from aesara .tensor .random .op import RandomVariable , default_shape_from_params
30
31
from aesara .tensor .random .utils import broadcast_params
31
32
from aesara .tensor .slinalg import (
32
33
Cholesky ,
@@ -248,6 +249,66 @@ def _distr_parameters_for_repr(self):
248
249
return ["mu" , "cov" ]
249
250
250
251
252
+ def safe_multivariate_t (nu , mu , cov , size = None , rng = None ):
253
+ res = np .atleast_1d (
254
+ stats .multivariate_t (loc = mu , shape = cov , df = nu , allow_singular = True ).rvs (
255
+ size = size , random_state = rng
256
+ )
257
+ )
258
+
259
+ if size is not None :
260
+ res = res .reshape (list (size ) + [- 1 ])
261
+
262
+ return res
263
+
264
+
265
+ class MvStudentTRV (RandomVariable ):
266
+ name = "multivariate_studentt"
267
+ ndim_supp = 1
268
+ ndims_params = [0 , 1 , 2 ]
269
+ dtype = "floatX"
270
+ _print_name = ("MvStudentT" , "\\ operatorname{MvStudentT}" )
271
+
272
+ def __call__ (self , nu , mu = None , cov = None , size = None , ** kwargs ):
273
+
274
+ dtype = aesara .config .floatX if self .dtype == "floatX" else self .dtype
275
+
276
+ if mu is None :
277
+ mu = np .array ([0.0 ], dtype = dtype )
278
+ if cov is None :
279
+ cov = np .array ([[1.0 ]], dtype = dtype )
280
+ return super ().__call__ (nu , mu , cov , size = size , ** kwargs )
281
+
282
+ def _shape_from_params (self , dist_params , rep_param_idx = 1 , param_shapes = None ):
283
+ return default_shape_from_params (self .ndim_supp , dist_params , rep_param_idx , param_shapes )
284
+
285
+ @classmethod
286
+ def rng_fn (cls , rng , nu , mu , cov , size ):
287
+
288
+ if mu .ndim > 1 or cov .ndim > 2 :
289
+ # Neither SciPy nor NumPy implement parameter broadcasting for
290
+ # multivariate normals (or many other multivariate distributions),
291
+ # so we have implement a quick and dirty one here
292
+ mu , cov = broadcast_params ([mu , cov ], cls .ndims_params [1 :])
293
+ size = tuple (size or ())
294
+
295
+ if size :
296
+ mu = np .broadcast_to (mu , size + mu .shape )
297
+ cov = np .broadcast_to (cov , size + cov .shape )
298
+
299
+ res = np .empty (mu .shape )
300
+ for idx in np .ndindex (mu .shape [:- 1 ]):
301
+ m = mu [idx ]
302
+ c = cov [idx ]
303
+ res [idx ] = safe_multivariate_t (nu , m , c , rng = rng )
304
+ return res
305
+ else :
306
+ return safe_multivariate_t (nu , mu , cov , size = size , rng = rng )
307
+
308
+
309
+ mv_studentt = MvStudentTRV ()
310
+
311
+
251
312
class MvStudentT (Continuous ):
252
313
r"""
253
314
Multivariate Student-T log-likelihood.
@@ -288,55 +349,20 @@ class MvStudentT(Continuous):
288
349
lower: bool, default=True
289
350
Whether the cholesky fatcor is given as a lower triangular matrix.
290
351
"""
352
+ rv_op = mv_studentt
291
353
292
- def __init__ (
293
- self , nu , Sigma = None , mu = None , cov = None , tau = None , chol = None , lower = True , * args , ** kwargs
294
- ):
354
+ @classmethod
355
+ def dist (cls , nu , Sigma = None , mu = None , cov = None , tau = None , chol = None , lower = True , ** kwargs ):
295
356
if Sigma is not None :
296
357
if cov is not None :
297
358
raise ValueError ("Specify only one of cov and Sigma" )
298
359
cov = Sigma
299
- super ().__init__ (mu = mu , cov = cov , tau = tau , chol = chol , lower = lower , * args , ** kwargs )
300
- self .nu = nu = at .as_tensor_variable (nu )
301
- self .mean = self .median = self .mode = self .mu = self .mu
302
-
303
- def random (self , point = None , size = None ):
304
- """
305
- Draw random values from Multivariate Student's T distribution.
306
-
307
- Parameters
308
- ----------
309
- point: dict, optional
310
- Dict of variable values on which random values are to be
311
- conditioned (uses default point if not specified).
312
- size: int, optional
313
- Desired size of random sample (returns one sample if not
314
- specified).
315
-
316
- Returns
317
- -------
318
- array
319
- """
320
- # with _DrawValuesContext():
321
- # nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
322
- # if self._cov_type == "cov":
323
- # (cov,) = draw_values([self.cov], point=point, size=size)
324
- # dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov, shape=self.shape)
325
- # elif self._cov_type == "tau":
326
- # (tau,) = draw_values([self.tau], point=point, size=size)
327
- # dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau, shape=self.shape)
328
- # else:
329
- # (chol,) = draw_values([self.chol_cov], point=point, size=size)
330
- # dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol, shape=self.shape)
331
- #
332
- # samples = dist.random(point, size)
333
- #
334
- # chi2_samples = np.random.chisquare(nu, size)
335
- # # Add distribution shape to chi2 samples
336
- # chi2_samples = chi2_samples.reshape(chi2_samples.shape + (1,) * len(self.shape))
337
- # return (samples / np.sqrt(chi2_samples / nu)) + mu
360
+ nu = at .as_tensor_variable (nu )
361
+ mu = at .as_tensor_variable (mu )
362
+ cov = quaddist_matrix (cov , chol , tau , lower )
363
+ return super ().dist ([nu , mu , cov ], ** kwargs )
338
364
339
- def logp (value , nu , cov ):
365
+ def logp (value , nu , mu , cov ):
340
366
"""
341
367
Calculate log-probability of Multivariate Student's T distribution
342
368
at specified value.
@@ -350,15 +376,15 @@ def logp(value, nu, cov):
350
376
-------
351
377
TensorVariable
352
378
"""
353
- quaddist , logdet , ok = quaddist_parse (value , nu , cov )
379
+ quaddist , logdet , ok = quaddist_parse (value , mu , cov )
354
380
k = floatX (value .shape [- 1 ])
355
381
356
382
norm = gammaln ((nu + k ) / 2.0 ) - gammaln (nu / 2.0 ) - 0.5 * k * floatX (np .log (nu * np .pi ))
357
383
inner = - (nu + k ) / 2.0 * at .log1p (quaddist / nu )
358
384
return bound (norm + inner - logdet , ok )
359
385
360
386
def _distr_parameters_for_repr (self ):
361
- return ["mu " , "nu " , "cov" ]
387
+ return ["nu " , "mu " , "cov" ]
362
388
363
389
364
390
class Dirichlet (Continuous ):
0 commit comments