Skip to content

Commit f8a5b91

Browse files
Sayam753 and ricardoV94 authored
Make MvStudentT distribution v4 compatible (#4731)
* Make MvStudentT distribution v4 compatible
* Enable pymc3/tests/test_distributions.py::TestBugfixes::test_issue_3051
* refactor tests
* simplify rng_fn method and refactor tests
* Refactor rng method to compute sqrt over chi2_samples, one step earlier
* Modify rng method to use aesara multivariate_normal.rng_fn
* Add suggestions

Co-authored-by: Ricardo <[email protected]>

* Add a check against negative support of degrees of freedom

Co-authored-by: Ricardo <[email protected]>
Co-authored-by: Ricardo <[email protected]>
1 parent 02a973b commit f8a5b91

File tree

3 files changed

+127
-72
lines changed

3 files changed

+127
-72
lines changed

pymc3/distributions/multivariate.py

Lines changed: 58 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from aesara.tensor import gammaln
2828
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
2929
from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
30+
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
3031
from aesara.tensor.random.utils import broadcast_params
3132
from aesara.tensor.slinalg import (
3233
Cholesky,
@@ -41,7 +42,7 @@
4142

4243
from pymc3.aesaraf import floatX, intX
4344
from pymc3.distributions import transforms
44-
from pymc3.distributions.continuous import ChiSquared, Normal
45+
from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support
4546
from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln
4647
from pymc3.distributions.distribution import Continuous, Discrete
4748
from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker
@@ -248,6 +249,48 @@ def _distr_parameters_for_repr(self):
248249
return ["mu", "cov"]
249250

250251

252+
class MvStudentTRV(RandomVariable):
    """Random-variable Op that draws multivariate Student-T samples.

    Parameters are passed in the order ``(nu, mu, cov)``; ``ndims_params``
    declares them as a scalar, a vector and a matrix respectively, and the
    support is one-dimensional (``ndim_supp = 1``).
    """

    name = "multivariate_studentt"
    ndim_supp = 1
    ndims_params = [0, 1, 2]
    dtype = "floatX"
    _print_name = ("MvStudentT", "\\operatorname{MvStudentT}")

    def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
        """Build the random variable, filling in defaults for ``mu`` and ``cov``.

        When omitted, ``mu`` defaults to ``[0.0]`` and ``cov`` to ``[[1.0]]``,
        both created with the Op's dtype (``"floatX"`` is resolved through
        ``aesara.config.floatX``).
        """
        if self.dtype == "floatX":
            default_dtype = aesara.config.floatX
        else:
            default_dtype = self.dtype

        mu = np.array([0.0], dtype=default_dtype) if mu is None else mu
        cov = np.array([[1.0]], dtype=default_dtype) if cov is None else cov
        return super().__call__(nu, mu, cov, size=size, **kwargs)

    def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
        """Delegate to ``default_shape_from_params``, using ``mu`` (index 1) by default."""
        return default_shape_from_params(
            self.ndim_supp, dist_params, rep_param_idx, param_shapes
        )

    @classmethod
    def rng_fn(cls, rng, nu, mu, cov, size):
        """Draw samples as ``normal / sqrt(chi2 / nu) + mu``.

        Parameters
        ----------
        rng
            NumPy random generator/state providing ``chisquare``.
        nu
            Degrees of freedom.
        mu
            Location vector.
        cov
            Covariance matrix handed to the multivariate normal sampler.
        size
            Requested batch shape, or ``None``.
        """
        # Keep only the broadcast `mu`; `cov` is deliberately left as passed in,
        # since the multivariate normal sampler expects a two dimensional cov.
        mu = broadcast_params([mu, cov], cls.ndims_params[1:])[0]

        # The chi-square draw happens before the normal draw; both consume the
        # same `rng`, so this order determines the resulting stream.
        scale = np.sqrt(rng.chisquare(nu, size=size) / nu)
        # Append one singleton axis per `mu` dimension so `scale` broadcasts
        # against the normal samples.
        trailing_axes = (1,) * mu.ndim
        scale = scale.reshape(scale.shape + trailing_axes)

        centered = multivariate_normal.rng_fn(
            rng=rng, mean=np.zeros_like(mu), cov=cov, size=size
        )

        batch_shape = tuple(size or ())
        if batch_shape:
            mu = np.broadcast_to(mu, batch_shape + mu.shape)

        return centered / scale + mu
289+
290+
291+
mv_studentt = MvStudentTRV()
292+
293+
251294
class MvStudentT(Continuous):
252295
r"""
253296
Multivariate Student-T log-likelihood.
@@ -273,8 +316,8 @@ class MvStudentT(Continuous):
273316
274317
Parameters
275318
----------
276-
nu: int
277-
Degrees of freedom.
319+
nu: float
320+
Degrees of freedom, should be a positive scalar.
278321
Sigma: matrix
279322
Covariance matrix. Use `cov` in new code.
280323
mu: array
@@ -288,55 +331,21 @@ class MvStudentT(Continuous):
288331
lower: bool, default=True
289332
Whether the cholesky factor is given as a lower triangular matrix.
290333
"""
334+
rv_op = mv_studentt
291335

292-
def __init__(
293-
self, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True, *args, **kwargs
294-
):
336+
@classmethod
337+
def dist(cls, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True, **kwargs):
295338
if Sigma is not None:
296339
if cov is not None:
297340
raise ValueError("Specify only one of cov and Sigma")
298341
cov = Sigma
299-
super().__init__(mu=mu, cov=cov, tau=tau, chol=chol, lower=lower, *args, **kwargs)
300-
self.nu = nu = at.as_tensor_variable(nu)
301-
self.mean = self.median = self.mode = self.mu = self.mu
302-
303-
def random(self, point=None, size=None):
304-
"""
305-
Draw random values from Multivariate Student's T distribution.
306-
307-
Parameters
308-
----------
309-
point: dict, optional
310-
Dict of variable values on which random values are to be
311-
conditioned (uses default point if not specified).
312-
size: int, optional
313-
Desired size of random sample (returns one sample if not
314-
specified).
315-
316-
Returns
317-
-------
318-
array
319-
"""
320-
# with _DrawValuesContext():
321-
# nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
322-
# if self._cov_type == "cov":
323-
# (cov,) = draw_values([self.cov], point=point, size=size)
324-
# dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov, shape=self.shape)
325-
# elif self._cov_type == "tau":
326-
# (tau,) = draw_values([self.tau], point=point, size=size)
327-
# dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau, shape=self.shape)
328-
# else:
329-
# (chol,) = draw_values([self.chol_cov], point=point, size=size)
330-
# dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol, shape=self.shape)
331-
#
332-
# samples = dist.random(point, size)
333-
#
334-
# chi2_samples = np.random.chisquare(nu, size)
335-
# # Add distribution shape to chi2 samples
336-
# chi2_samples = chi2_samples.reshape(chi2_samples.shape + (1,) * len(self.shape))
337-
# return (samples / np.sqrt(chi2_samples / nu)) + mu
342+
nu = at.as_tensor_variable(floatX(nu))
343+
mu = at.as_tensor_variable(floatX(mu))
344+
cov = quaddist_matrix(cov, chol, tau, lower)
345+
assert_negative_support(nu, "nu", "MvStudentT")
346+
return super().dist([nu, mu, cov], **kwargs)
338347

339-
def logp(value, nu, cov):
348+
def logp(value, nu, mu, cov):
340349
"""
341350
Calculate log-probability of Multivariate Student's T distribution
342351
at specified value.
@@ -350,15 +359,15 @@ def logp(value, nu, cov):
350359
-------
351360
TensorVariable
352361
"""
353-
quaddist, logdet, ok = quaddist_parse(value, nu, cov)
362+
quaddist, logdet, ok = quaddist_parse(value, mu, cov)
354363
k = floatX(value.shape[-1])
355364

356-
norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * floatX(np.log(nu * np.pi))
365+
norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * at.log(nu * np.pi)
357366
inner = -(nu + k) / 2.0 * at.log1p(quaddist / nu)
358367
return bound(norm + inner - logdet, ok)
359368

360369
def _distr_parameters_for_repr(self):
361-
return ["mu", "nu", "cov"]
370+
return ["nu", "mu", "cov"]
362371

363372

364373
class Dirichlet(Continuous):

pymc3/tests/test_distributions.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,7 +2017,6 @@ def test_kroneckernormal(self, n, m, sigma):
20172017
)
20182018

20192019
@pytest.mark.parametrize("n", [1, 2])
2020-
@pytest.mark.xfail(reason="Distribution not refactored yet")
20212020
def test_mvt(self, n):
20222021
self.check_logp(
20232022
MvStudentT,
@@ -2030,6 +2029,7 @@ def test_mvt(self, n):
20302029
RealMatrix(2, n),
20312030
{"nu": Rplus, "Sigma": PdMatrix(n), "mu": Vector(R, n)},
20322031
mvt_logpdf,
2032+
extra_args={"size": 2},
20332033
)
20342034

20352035
@pytest.mark.parametrize("n", [2, 3, 4])
@@ -2936,13 +2936,11 @@ def test_car_logp(size):
29362936

29372937

29382938
class TestBugfixes:
2939-
@pytest.mark.parametrize(
2940-
"dist_cls,kwargs", [(MvNormal, dict(mu=0)), (MvStudentT, dict(mu=0, nu=2))]
2941-
)
2939+
@pytest.mark.parametrize("dist_cls,kwargs", [(MvNormal, dict()), (MvStudentT, dict(nu=2))])
29422940
@pytest.mark.parametrize("dims", [1, 2, 4])
2943-
@pytest.mark.xfail(reason="Distribution not refactored yet")
29442941
def test_issue_3051(self, dims, dist_cls, kwargs):
2945-
d = dist_cls.dist(**kwargs, cov=np.eye(dims), size=(dims,))
2942+
mu = np.repeat(0, dims)
2943+
d = dist_cls.dist(mu=mu, cov=np.eye(dims), **kwargs, size=(20))
29462944

29472945
X = np.random.normal(size=(20, dims))
29482946
actual_t = logpt(d, X)

pymc3/tests/test_distributions_random.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,7 @@ class TestPoisson(BaseTestDistribution):
889889
tests_to_run = ["check_pymc_params_match_rv_op"]
890890

891891

892-
class TestMvNormal(BaseTestDistribution):
892+
class TestMvNormalCov(BaseTestDistribution):
893893
pymc_dist = pm.MvNormal
894894
pymc_dist_params = {
895895
"mu": np.array([1.0, 2.0]),
@@ -939,6 +939,70 @@ class TestMvNormalTau(BaseTestDistribution):
939939
tests_to_run = ["check_pymc_params_match_rv_op"]
940940

941941

942+
class TestMvStudentTCov(BaseTestDistribution):
    """Check ``pm.MvStudentT`` parametrized by ``cov`` against a NumPy reference."""

    def mvstudentt_rng_fn(self, size, nu, mu, cov, rng):
        """Reference sampler: ``normal / sqrt(chi2 / nu) + mu``.

        Parameters
        ----------
        size
            ``None``, an int, or a tuple batch shape.
        nu
            Degrees of freedom.
        mu
            Location vector.
        cov
            Covariance matrix.
        rng
            NumPy random state/generator; the chi-square draw is taken before
            the multivariate normal draw.

        Returns
        -------
        numpy.ndarray
            Samples of shape ``size + mu.shape``.
        """
        # `asarray` covers `size=None`, where the draw may come back as a bare
        # scalar that cannot be indexed.
        chi2_samples = np.asarray(rng.chisquare(nu, size=size))
        mv_samples = rng.multivariate_normal(np.zeros_like(mu), cov, size=size)
        # `[..., None]` appends the support axis after *all* batch axes, so the
        # division broadcasts for scalar, 1-D and multi-dimensional `size`.
        return (mv_samples / np.sqrt(chi2_samples[..., None] / nu)) + mu

    pymc_dist = pm.MvStudentT
    pymc_dist_params = {
        "nu": 5,
        "mu": np.array([1.0, 2.0]),
        "cov": np.array([[2.0, 0.0], [0.0, 3.5]]),
    }
    # With the `cov` parametrization the RV op is expected to receive the
    # parameters unchanged.
    expected_rv_op_params = {
        "nu": 5,
        "mu": np.array([1.0, 2.0]),
        "cov": np.array([[2.0, 0.0], [0.0, 3.5]]),
    }
    # Note: `(1)` is the plain integer 1, not a one-element tuple.
    sizes_to_check = [None, (1), (2, 3)]
    sizes_expected = [(2,), (1, 2), (2, 3, 2)]
    reference_dist_params = {
        "nu": 5,
        "mu": np.array([1.0, 2.0]),
        "cov": np.array([[2.0, 0.0], [0.0, 3.5]]),
    }
    reference_dist = lambda self: functools.partial(
        self.mvstudentt_rng_fn, rng=self.get_random_state()
    )
    tests_to_run = [
        "check_pymc_params_match_rv_op",
        "check_pymc_draws_match_reference",
        "check_rv_size",
    ]
974+
975+
976+
class TestMvStudentTChol(BaseTestDistribution):
    """Check that a ``chol`` parametrization reaches the RV op as ``cov``."""

    pymc_dist = pm.MvStudentT
    pymc_dist_params = dict(
        nu=5,
        mu=np.array([1.0, 2.0]),
        chol=np.array([[2.0, 0.0], [0.0, 3.5]]),
    )
    # The expected covariance is whatever `quaddist_matrix` derives from the
    # Cholesky factor given above.
    expected_rv_op_params = dict(
        nu=5,
        mu=np.array([1.0, 2.0]),
        cov=quaddist_matrix(chol=pymc_dist_params["chol"]).eval(),
    )
    tests_to_run = ["check_pymc_params_match_rv_op"]
989+
990+
991+
class TestMvStudentTTau(BaseTestDistribution):
    """Check that a ``tau`` parametrization reaches the RV op as ``cov``."""

    pymc_dist = pm.MvStudentT
    pymc_dist_params = dict(
        nu=5,
        mu=np.array([1.0, 2.0]),
        tau=np.array([[2.0, 0.0], [0.0, 3.5]]),
    )
    # The expected covariance is whatever `quaddist_matrix` derives from the
    # `tau` matrix given above.
    expected_rv_op_params = dict(
        nu=5,
        mu=np.array([1.0, 2.0]),
        cov=quaddist_matrix(tau=pymc_dist_params["tau"]).eval(),
    )
    tests_to_run = ["check_pymc_params_match_rv_op"]
1004+
1005+
9421006
class TestDirichlet(BaseTestDistribution):
9431007
pymc_dist = pm.Dirichlet
9441008
pymc_dist_params = {"a": np.array([1.0, 2.0])}
@@ -1471,22 +1535,6 @@ def ref_rand_evd(size, mu, evds, sigma):
14711535
model_args=evd_args,
14721536
)
14731537

1474-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
1475-
def test_mv_t(self):
1476-
def ref_rand(size, nu, Sigma, mu):
1477-
normal = st.multivariate_normal.rvs(cov=Sigma, size=size)
1478-
chi2 = st.chi2.rvs(df=nu, size=size)[..., None]
1479-
return mu + (normal / np.sqrt(chi2 / nu))
1480-
1481-
for n in [2, 3]:
1482-
pymc3_random(
1483-
pm.MvStudentT,
1484-
{"nu": Domain([5, 10, 25, 50]), "Sigma": PdMatrix(n), "mu": Vector(R, n)},
1485-
size=100,
1486-
valuedomain=Vector(R, n),
1487-
ref_rand=ref_rand,
1488-
)
1489-
14901538
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
14911539
def test_dirichlet_multinomial(self):
14921540
def ref_rand(size, a, n):

0 commit comments

Comments
 (0)