Skip to content

Commit a553c55

Browse files
committed
simplify rng_fn method and refactor tests
1 parent 8eb5623 commit a553c55

File tree

2 files changed

+26
-35
lines changed

2 files changed

+26
-35
lines changed

pymc3/distributions/multivariate.py

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -249,19 +249,6 @@ def _distr_parameters_for_repr(self):
249249
return ["mu", "cov"]
250250

251251

252-
def safe_multivariate_t(nu, mu, cov, size=None, rng=None):
253-
res = np.atleast_1d(
254-
stats.multivariate_t(loc=mu, shape=cov, df=nu, allow_singular=True).rvs(
255-
size=size, random_state=rng
256-
)
257-
)
258-
259-
if size is not None:
260-
res = res.reshape(list(size) + [-1])
261-
262-
return res
263-
264-
265252
class MvStudentTRV(RandomVariable):
266253
name = "multivariate_studentt"
267254
ndim_supp = 1
@@ -285,25 +272,22 @@ def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
285272
@classmethod
286273
def rng_fn(cls, rng, nu, mu, cov, size):
287274

288-
if mu.ndim > 1 or cov.ndim > 2:
289-
# Neither SciPy nor NumPy implement parameter broadcasting for
290-
# multivariate normals (or many other multivariate distributions),
291-
# so we have implement a quick and dirty one here
292-
mu, cov = broadcast_params([mu, cov], cls.ndims_params[1:])
293-
size = tuple(size or ())
275+
# Don't reassign broadcasted cov, since MvNormal expects two dimensional cov only.
276+
mu, _ = broadcast_params([mu, cov], cls.ndims_params[1:])
294277

295-
if size:
296-
mu = np.broadcast_to(mu, size + mu.shape)
297-
cov = np.broadcast_to(cov, size + cov.shape)
298-
299-
res = np.empty(mu.shape)
300-
for idx in np.ndindex(mu.shape[:-1]):
301-
m = mu[idx]
302-
c = cov[idx]
303-
res[idx] = safe_multivariate_t(nu, m, c, rng=rng)
304-
return res
305-
else:
306-
return safe_multivariate_t(nu, mu, cov, size=size, rng=rng)
278+
chi2_samples = rng.chisquare(nu, size=size)
279+
# Add distribution shape to chi2 samples
280+
chi2_samples = chi2_samples.reshape(chi2_samples.shape + (1,) * len(mu.shape))
281+
282+
mv_samples = pm.MvNormal.dist(
283+
mu=np.zeros_like(mu), cov=cov, size=size, rng=aesara.shared(rng)
284+
).eval()
285+
286+
size = tuple(size or ())
287+
if size:
288+
mu = np.broadcast_to(mu, size + mu.shape)
289+
290+
return (mv_samples / np.sqrt(chi2_samples / nu)) + mu
307291

308292

309293
mv_studentt = MvStudentTRV()

pymc3/tests/test_distributions_random.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,11 @@ class TestMvNormalTau(BaseTestDistribution):
894894

895895

896896
class TestMvStudentTCov(BaseTestDistribution):
897+
def mvstudentt_rng_fn(self, size, nu, mu, cov, rng):
898+
chi2_samples = rng.chisquare(nu, size=size)
899+
mv_samples = rng.multivariate_normal(np.zeros_like(mu), cov, size=size)
900+
return (mv_samples / np.sqrt(chi2_samples[:, None] / nu)) + mu
901+
897902
pymc_dist = pm.MvStudentT
898903
pymc_dist_params = {
899904
"nu": 5,
@@ -908,11 +913,13 @@ class TestMvStudentTCov(BaseTestDistribution):
908913
sizes_to_check = [None, (1), (2, 3)]
909914
sizes_expected = [(2,), (1, 2), (2, 3, 2)]
910915
reference_dist_params = {
911-
"df": 5,
912-
"loc": np.array([1.0, 2.0]),
913-
"shape": np.array([[2.0, 0.0], [0.0, 3.5]]),
916+
"nu": 5,
917+
"mu": np.array([1.0, 2.0]),
918+
"cov": np.array([[2.0, 0.0], [0.0, 3.5]]),
914919
}
915-
reference_dist = seeded_scipy_distribution_builder("multivariate_t")
920+
reference_dist = lambda self: functools.partial(
921+
self.mvstudentt_rng_fn, rng=self.get_random_state()
922+
)
916923
tests_to_run = [
917924
"check_pymc_params_match_rv_op",
918925
"check_pymc_draws_match_reference",

0 commit comments

Comments
 (0)