Skip to content

Commit e1d9c69

Browse files
committed
Make MvStudentT distribution v4 compatible
1 parent 75c2e1e commit e1d9c69

File tree

3 files changed

+72
-47
lines changed

3 files changed

+72
-47
lines changed

pymc3/distributions/multivariate.py

Lines changed: 71 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from aesara.tensor import gammaln
2828
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
2929
from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
30+
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
3031
from aesara.tensor.random.utils import broadcast_params
3132
from aesara.tensor.slinalg import (
3233
Cholesky,
@@ -248,6 +249,66 @@ def _distr_parameters_for_repr(self):
248249
return ["mu", "cov"]
249250

250251

252+
def safe_multivariate_t(nu, mu, cov, size=None, rng=None):
253+
res = np.atleast_1d(
254+
stats.multivariate_t(loc=mu, shape=cov, df=nu, allow_singular=True).rvs(
255+
size=size, random_state=rng
256+
)
257+
)
258+
259+
if size is not None:
260+
res = res.reshape(list(size) + [-1])
261+
262+
return res
263+
264+
265+
class MvStudentTRV(RandomVariable):
266+
name = "multivariate_studentt"
267+
ndim_supp = 1
268+
ndims_params = [0, 1, 2]
269+
dtype = "floatX"
270+
_print_name = ("MvStudentT", "\\operatorname{MvStudentT}")
271+
272+
def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
273+
274+
dtype = aesara.config.floatX if self.dtype == "floatX" else self.dtype
275+
276+
if mu is None:
277+
mu = np.array([0.0], dtype=dtype)
278+
if cov is None:
279+
cov = np.array([[1.0]], dtype=dtype)
280+
return super().__call__(nu, mu, cov, size=size, **kwargs)
281+
282+
def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
283+
return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)
284+
285+
@classmethod
286+
def rng_fn(cls, rng, nu, mu, cov, size):
287+
288+
if mu.ndim > 1 or cov.ndim > 2:
289+
# Neither SciPy nor NumPy implement parameter broadcasting for
290+
# multivariate normals (or many other multivariate distributions),
291+
# so we have implement a quick and dirty one here
292+
mu, cov = broadcast_params([mu, cov], cls.ndims_params[1:])
293+
size = tuple(size or ())
294+
295+
if size:
296+
mu = np.broadcast_to(mu, size + mu.shape)
297+
cov = np.broadcast_to(cov, size + cov.shape)
298+
299+
res = np.empty(mu.shape)
300+
for idx in np.ndindex(mu.shape[:-1]):
301+
m = mu[idx]
302+
c = cov[idx]
303+
res[idx] = safe_multivariate_t(nu, m, c, rng=rng)
304+
return res
305+
else:
306+
return safe_multivariate_t(nu, mu, cov, size=size, rng=rng)
307+
308+
309+
mv_studentt = MvStudentTRV()
310+
311+
251312
class MvStudentT(Continuous):
252313
r"""
253314
Multivariate Student-T log-likelihood.
@@ -288,55 +349,20 @@ class MvStudentT(Continuous):
288349
lower: bool, default=True
289350
Whether the cholesky fatcor is given as a lower triangular matrix.
290351
"""
352+
rv_op = mv_studentt
291353

292-
def __init__(
293-
self, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True, *args, **kwargs
294-
):
354+
@classmethod
355+
def dist(cls, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True, **kwargs):
295356
if Sigma is not None:
296357
if cov is not None:
297358
raise ValueError("Specify only one of cov and Sigma")
298359
cov = Sigma
299-
super().__init__(mu=mu, cov=cov, tau=tau, chol=chol, lower=lower, *args, **kwargs)
300-
self.nu = nu = at.as_tensor_variable(nu)
301-
self.mean = self.median = self.mode = self.mu = self.mu
302-
303-
def random(self, point=None, size=None):
304-
"""
305-
Draw random values from Multivariate Student's T distribution.
306-
307-
Parameters
308-
----------
309-
point: dict, optional
310-
Dict of variable values on which random values are to be
311-
conditioned (uses default point if not specified).
312-
size: int, optional
313-
Desired size of random sample (returns one sample if not
314-
specified).
315-
316-
Returns
317-
-------
318-
array
319-
"""
320-
# with _DrawValuesContext():
321-
# nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
322-
# if self._cov_type == "cov":
323-
# (cov,) = draw_values([self.cov], point=point, size=size)
324-
# dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov, shape=self.shape)
325-
# elif self._cov_type == "tau":
326-
# (tau,) = draw_values([self.tau], point=point, size=size)
327-
# dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau, shape=self.shape)
328-
# else:
329-
# (chol,) = draw_values([self.chol_cov], point=point, size=size)
330-
# dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol, shape=self.shape)
331-
#
332-
# samples = dist.random(point, size)
333-
#
334-
# chi2_samples = np.random.chisquare(nu, size)
335-
# # Add distribution shape to chi2 samples
336-
# chi2_samples = chi2_samples.reshape(chi2_samples.shape + (1,) * len(self.shape))
337-
# return (samples / np.sqrt(chi2_samples / nu)) + mu
360+
nu = at.as_tensor_variable(nu)
361+
mu = at.as_tensor_variable(mu)
362+
cov = quaddist_matrix(cov, chol, tau, lower)
363+
return super().dist([nu, mu, cov], **kwargs)
338364

339-
def logp(value, nu, cov):
365+
def logp(value, nu, mu, cov):
340366
"""
341367
Calculate log-probability of Multivariate Student's T distribution
342368
at specified value.
@@ -350,15 +376,15 @@ def logp(value, nu, cov):
350376
-------
351377
TensorVariable
352378
"""
353-
quaddist, logdet, ok = quaddist_parse(value, nu, cov)
379+
quaddist, logdet, ok = quaddist_parse(value, mu, cov)
354380
k = floatX(value.shape[-1])
355381

356382
norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * floatX(np.log(nu * np.pi))
357383
inner = -(nu + k) / 2.0 * at.log1p(quaddist / nu)
358384
return bound(norm + inner - logdet, ok)
359385

360386
def _distr_parameters_for_repr(self):
361-
return ["mu", "nu", "cov"]
387+
return ["nu", "mu", "cov"]
362388

363389

364390
class Dirichlet(Continuous):

pymc3/tests/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2019,7 +2019,6 @@ def test_kroneckernormal(self, n, m, sigma):
20192019
)
20202020

20212021
@pytest.mark.parametrize("n", [1, 2])
2022-
@pytest.mark.xfail(reason="Distribution not refactored yet")
20232022
def test_mvt(self, n):
20242023
self.check_logp(
20252024
MvStudentT,
@@ -2032,6 +2031,7 @@ def test_mvt(self, n):
20322031
RealMatrix(2, n),
20332032
{"nu": Rplus, "Sigma": PdMatrix(n), "mu": Vector(R, n)},
20342033
mvt_logpdf,
2034+
extra_args={"size": 2},
20352035
)
20362036

20372037
@pytest.mark.parametrize("n", [2, 3, 4])

pymc3/tests/test_distributions_random.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1402,7 +1402,6 @@ def ref_rand_evd(size, mu, evds, sigma):
14021402
model_args=evd_args,
14031403
)
14041404

1405-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
14061405
def test_mv_t(self):
14071406
def ref_rand(size, nu, Sigma, mu):
14081407
normal = st.multivariate_normal.rvs(cov=Sigma, size=size)

0 commit comments

Comments
 (0)