Skip to content

Commit 069533f

Browse files
Change tests for refactored distributions
More details can be found on issue #4554 #4554
1 parent 7bf2004 commit 069533f

File tree

1 file changed

+29
-83
lines changed

1 file changed

+29
-83
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 29 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
import scipy.stats as st
2525

2626
from numpy.testing import assert_almost_equal
27-
from scipy import linalg
2827
from scipy.special import expit
2928

3029
import pymc3 as pm
3130

3231
from pymc3.aesaraf import change_rv_size, floatX, intX
32+
from pymc3.distributions.multivariate import quaddist_matrix
3333
from pymc3.distributions.shape_utils import to_tuple
3434
from pymc3.exceptions import ShapeError
3535
from pymc3.tests.helpers import SeededTest
@@ -40,7 +40,6 @@
4040
NatSmall,
4141
PdMatrix,
4242
PdMatrixChol,
43-
PdMatrixCholUpper,
4443
R,
4544
RandomPdMatrix,
4645
RealMatrix,
@@ -639,6 +638,34 @@ def test_poisson(self):
639638
params = [("mu", 4)]
640639
self._pymc_params_match_rv_ones(params, params, pm.Poisson)
641640

641+
def test_mv_distribution(self):
642+
params = [("mu", np.array([1.0, 2.0])), ("cov", np.array([[2.0, 0.0], [0.0, 3.5]]))]
643+
self._pymc_params_match_rv_ones(params, params, pm.MvNormal)
644+
645+
def test_mv_distribution_chol(self):
646+
params = [("mu", np.array([1.0, 2.0])), ("chol", np.array([[2.0, 0.0], [0.0, 3.5]]))]
647+
expected_cov = quaddist_matrix(chol=params[1][1])
648+
expected_params = [("mu", np.array([1.0, 2.0])), ("cov", expected_cov.eval())]
649+
self._pymc_params_match_rv_ones(params, expected_params, pm.MvNormal)
650+
651+
def test_mv_distribution_tau(self):
652+
params = [("mu", np.array([1.0, 2.0])), ("tau", np.array([[2.0, 0.0], [0.0, 3.5]]))]
653+
expected_cov = quaddist_matrix(tau=params[1][1])
654+
expected_params = [("mu", np.array([1.0, 2.0])), ("cov", expected_cov.eval())]
655+
self._pymc_params_match_rv_ones(params, expected_params, pm.MvNormal)
656+
657+
def test_dirichlet(self):
658+
params = [("a", np.array([1.0, 2.0]))]
659+
self._pymc_params_match_rv_ones(params, params, pm.Dirichlet)
660+
661+
def test_multinomial(self):
662+
params = [("n", 85), ("p", np.array([0.28, 0.62, 0.10]))]
663+
self._pymc_params_match_rv_ones(params, params, pm.Multinomial)
664+
665+
def test_categorical(self):
666+
params = [("p", np.array([0.28, 0.62, 0.10]))]
667+
self._pymc_params_match_rv_ones(params, params, pm.Categorical)
668+
642669

643670
class TestScalarParameterSamples(SeededTest):
644671
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
@@ -835,66 +862,13 @@ def ref_rand(size, q, beta):
835862
pm.DiscreteWeibull, {"q": Unit, "beta": Rplusdunif}, ref_rand=ref_rand
836863
)
837864

838-
@pytest.mark.skip(reason="This test is covered by Aesara")
839-
@pytest.mark.parametrize("s", [2, 3, 4])
840-
def test_categorical_random(self, s):
841-
def ref_rand(size, p):
842-
return nr.choice(np.arange(p.shape[0]), p=p, size=size)
843-
844-
pymc3_random_discrete(pm.Categorical, {"p": Simplex(s)}, ref_rand=ref_rand)
845-
846865
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
847866
def test_constant_dist(self):
848867
def ref_rand(size, c):
849868
return c * np.ones(size, dtype=int)
850869

851870
pymc3_random_discrete(pm.Constant, {"c": I}, ref_rand=ref_rand)
852871

853-
@pytest.mark.skip(reason="This test is covered by Aesara")
854-
def test_mv_normal(self):
855-
def ref_rand(size, mu, cov):
856-
return st.multivariate_normal.rvs(mean=mu, cov=cov, size=size)
857-
858-
def ref_rand_tau(size, mu, tau):
859-
return ref_rand(size, mu, linalg.inv(tau))
860-
861-
def ref_rand_chol(size, mu, chol):
862-
return ref_rand(size, mu, np.dot(chol, chol.T))
863-
864-
def ref_rand_uchol(size, mu, chol):
865-
return ref_rand(size, mu, np.dot(chol.T, chol))
866-
867-
for n in [2, 3]:
868-
pymc3_random(
869-
pm.MvNormal,
870-
{"mu": Vector(R, n), "cov": PdMatrix(n)},
871-
size=100,
872-
valuedomain=Vector(R, n),
873-
ref_rand=ref_rand,
874-
)
875-
pymc3_random(
876-
pm.MvNormal,
877-
{"mu": Vector(R, n), "tau": PdMatrix(n)},
878-
size=100,
879-
valuedomain=Vector(R, n),
880-
ref_rand=ref_rand_tau,
881-
)
882-
pymc3_random(
883-
pm.MvNormal,
884-
{"mu": Vector(R, n), "chol": PdMatrixChol(n)},
885-
size=100,
886-
valuedomain=Vector(R, n),
887-
ref_rand=ref_rand_chol,
888-
)
889-
pymc3_random(
890-
pm.MvNormal,
891-
{"mu": Vector(R, n), "chol": PdMatrixCholUpper(n)},
892-
size=100,
893-
valuedomain=Vector(R, n),
894-
ref_rand=ref_rand_uchol,
895-
extra_args={"lower": False},
896-
)
897-
898872
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
899873
def test_matrix_normal(self):
900874
def ref_rand(size, mu, rowcov, colcov):
@@ -1037,20 +1011,6 @@ def ref_rand(size, nu, Sigma, mu):
10371011
ref_rand=ref_rand,
10381012
)
10391013

1040-
@pytest.mark.skip(reason="This test is covered by Aesara")
1041-
def test_dirichlet(self):
1042-
def ref_rand(size, a):
1043-
return st.dirichlet.rvs(a, size=size)
1044-
1045-
for n in [2, 3]:
1046-
pymc3_random(
1047-
pm.Dirichlet,
1048-
{"a": Vector(Rplus, n)},
1049-
valuedomain=Simplex(n),
1050-
size=100,
1051-
ref_rand=ref_rand,
1052-
)
1053-
10541014
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
10551015
def test_dirichlet_multinomial(self):
10561016
def ref_rand(size, a, n):
@@ -1118,20 +1078,6 @@ def test_dirichlet_multinomial_dist_ShapeError(self, n, a, shape, expectation):
11181078
with expectation:
11191079
m.random()
11201080

1121-
@pytest.mark.skip(reason="This test is covered by Aesara")
1122-
def test_multinomial(self):
1123-
def ref_rand(size, p, n):
1124-
return nr.multinomial(pvals=p, n=n, size=size)
1125-
1126-
for n in [2, 3]:
1127-
pymc3_random_discrete(
1128-
pm.Multinomial,
1129-
{"p": Simplex(n), "n": Nat},
1130-
valuedomain=Vector(Nat, n),
1131-
size=100,
1132-
ref_rand=ref_rand,
1133-
)
1134-
11351081
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
11361082
def test_gumbel(self):
11371083
def ref_rand(size, mu, beta):

0 commit comments

Comments
 (0)