Skip to content

Commit 8eb5623

Browse files
committed
refactor tests
1 parent 1425508 commit 8eb5623

File tree

2 files changed

+61
-24
lines changed

2 files changed

+61
-24
lines changed

pymc3/tests/test_distributions.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2937,16 +2937,11 @@ def test_car_logp(size):
29372937

29382938

29392939
class TestBugfixes:
2940-
@pytest.mark.parametrize(
2941-
"dist_cls,kwargs", [(MvNormal, dict(mu=0)), (MvStudentT, dict(mu=0, nu=2))]
2942-
)
2940+
@pytest.mark.parametrize("dist_cls,kwargs", [(MvNormal, dict()), (MvStudentT, dict(nu=2))])
29432941
@pytest.mark.parametrize("dims", [1, 2, 4])
29442942
def test_issue_3051(self, dims, dist_cls, kwargs):
2945-
mu = np.repeat(kwargs["mu"], dims)
2946-
if "nu" in kwargs:
2947-
d = dist_cls.dist(nu=kwargs["nu"], mu=mu, cov=np.eye(dims), size=(20))
2948-
else:
2949-
d = dist_cls.dist(mu=mu, cov=np.eye(dims), size=(20))
2943+
mu = np.repeat(0, dims)
2944+
d = dist_cls.dist(mu=mu, cov=np.eye(dims), **kwargs, size=(20))
29502945

29512946
X = np.random.normal(size=(20, dims))
29522947
actual_t = logpt(d, X)

pymc3/tests/test_distributions_random.py

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,7 @@ class TestPoisson(BaseTestDistribution):
843843
tests_to_run = ["check_pymc_params_match_rv_op"]
844844

845845

846-
class TestMvNormal(BaseTestDistribution):
846+
class TestMvNormalCov(BaseTestDistribution):
847847
pymc_dist = pm.MvNormal
848848
pymc_dist_params = {
849849
"mu": np.array([1.0, 2.0]),
@@ -893,6 +893,63 @@ class TestMvNormalTau(BaseTestDistribution):
893893
tests_to_run = ["check_pymc_params_match_rv_op"]
894894

895895

896+
class TestMvStudentTCov(BaseTestDistribution):
897+
pymc_dist = pm.MvStudentT
898+
pymc_dist_params = {
899+
"nu": 5,
900+
"mu": np.array([1.0, 2.0]),
901+
"cov": np.array([[2.0, 0.0], [0.0, 3.5]]),
902+
}
903+
expected_rv_op_params = {
904+
"nu": 5,
905+
"mu": np.array([1.0, 2.0]),
906+
"cov": np.array([[2.0, 0.0], [0.0, 3.5]]),
907+
}
908+
sizes_to_check = [None, (1), (2, 3)]
909+
sizes_expected = [(2,), (1, 2), (2, 3, 2)]
910+
reference_dist_params = {
911+
"df": 5,
912+
"loc": np.array([1.0, 2.0]),
913+
"shape": np.array([[2.0, 0.0], [0.0, 3.5]]),
914+
}
915+
reference_dist = seeded_scipy_distribution_builder("multivariate_t")
916+
tests_to_run = [
917+
"check_pymc_params_match_rv_op",
918+
"check_pymc_draws_match_reference",
919+
"check_rv_size",
920+
]
921+
922+
923+
class TestMvStudentTChol(BaseTestDistribution):
924+
pymc_dist = pm.MvStudentT
925+
pymc_dist_params = {
926+
"nu": 5,
927+
"mu": np.array([1.0, 2.0]),
928+
"chol": np.array([[2.0, 0.0], [0.0, 3.5]]),
929+
}
930+
expected_rv_op_params = {
931+
"nu": 5,
932+
"mu": np.array([1.0, 2.0]),
933+
"cov": quaddist_matrix(chol=pymc_dist_params["chol"]).eval(),
934+
}
935+
tests_to_run = ["check_pymc_params_match_rv_op"]
936+
937+
938+
class TestMvStudentTTau(BaseTestDistribution):
939+
pymc_dist = pm.MvStudentT
940+
pymc_dist_params = {
941+
"nu": 5,
942+
"mu": np.array([1.0, 2.0]),
943+
"tau": np.array([[2.0, 0.0], [0.0, 3.5]]),
944+
}
945+
expected_rv_op_params = {
946+
"nu": 5,
947+
"mu": np.array([1.0, 2.0]),
948+
"cov": quaddist_matrix(tau=pymc_dist_params["tau"]).eval(),
949+
}
950+
tests_to_run = ["check_pymc_params_match_rv_op"]
951+
952+
896953
class TestDirichlet(BaseTestDistribution):
897954
pymc_dist = pm.Dirichlet
898955
pymc_dist_params = {"a": np.array([1.0, 2.0])}
@@ -1402,21 +1459,6 @@ def ref_rand_evd(size, mu, evds, sigma):
14021459
model_args=evd_args,
14031460
)
14041461

1405-
def test_mv_t(self):
1406-
def ref_rand(size, nu, Sigma, mu):
1407-
normal = st.multivariate_normal.rvs(cov=Sigma, size=size)
1408-
chi2 = st.chi2.rvs(df=nu, size=size)[..., None]
1409-
return mu + (normal / np.sqrt(chi2 / nu))
1410-
1411-
for n in [2, 3]:
1412-
pymc3_random(
1413-
pm.MvStudentT,
1414-
{"nu": Domain([5, 10, 25, 50]), "Sigma": PdMatrix(n), "mu": Vector(R, n)},
1415-
size=100,
1416-
valuedomain=Vector(R, n),
1417-
ref_rand=ref_rand,
1418-
)
1419-
14201462
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
14211463
def test_dirichlet_multinomial(self):
14221464
def ref_rand(size, a, n):

0 commit comments

Comments
 (0)