Skip to content

Commit 7fed128

Browse files
Change tests for refactored distributions
More details can be found on issue pymc-devs#4554 pymc-devs#4554
1 parent 45180b1 commit 7fed128

File tree

1 file changed

+29
-83
lines changed

1 file changed

+29
-83
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 29 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525
import scipy.stats as st
2626

2727
from numpy.testing import assert_almost_equal
28-
from scipy import linalg
2928
from scipy.special import expit
3029

3130
import pymc3 as pm
3231

3332
from pymc3.aesaraf import change_rv_size, floatX, intX
33+
from pymc3.distributions.multivariate import quaddist_matrix
3434
from pymc3.distributions.shape_utils import to_tuple
3535
from pymc3.exceptions import ShapeError
3636
from pymc3.tests.helpers import SeededTest, select_by_precision
@@ -41,7 +41,6 @@
4141
NatSmall,
4242
PdMatrix,
4343
PdMatrixChol,
44-
PdMatrixCholUpper,
4544
R,
4645
RandomPdMatrix,
4746
RealMatrix,
@@ -627,6 +626,34 @@ def test_poisson(self):
627626
params = [("mu", 4)]
628627
self._pymc_params_match_rv_ones(params, params, pm.Poisson)
629628

629+
def test_mv_distribution(self):
630+
params = [("mu", np.array([1.0, 2.0])), ("cov", np.array([[2.0, 0.0], [0.0, 3.5]]))]
631+
self._pymc_params_match_rv_ones(params, params, pm.MvNormal)
632+
633+
def test_mv_distribution_chol(self):
634+
params = [("mu", np.array([1.0, 2.0])), ("chol", np.array([[2.0, 0.0], [0.0, 3.5]]))]
635+
expected_cov = quaddist_matrix(chol=params[1][1])
636+
expected_params = [("mu", np.array([1.0, 2.0])), ("cov", expected_cov.eval())]
637+
self._pymc_params_match_rv_ones(params, expected_params, pm.MvNormal)
638+
639+
def test_mv_distribution_tau(self):
640+
params = [("mu", np.array([1.0, 2.0])), ("tau", np.array([[2.0, 0.0], [0.0, 3.5]]))]
641+
expected_cov = quaddist_matrix(tau=params[1][1])
642+
expected_params = [("mu", np.array([1.0, 2.0])), ("cov", expected_cov.eval())]
643+
self._pymc_params_match_rv_ones(params, expected_params, pm.MvNormal)
644+
645+
def test_dirichlet(self):
646+
params = [("a", np.array([1.0, 2.0]))]
647+
self._pymc_params_match_rv_ones(params, params, pm.Dirichlet)
648+
649+
def test_multinomial(self):
650+
params = [("n", 85), ("p", np.array([0.28, 0.62, 0.10]))]
651+
self._pymc_params_match_rv_ones(params, params, pm.Multinomial)
652+
653+
def test_categorical(self):
654+
params = [("p", np.array([0.28, 0.62, 0.10]))]
655+
self._pymc_params_match_rv_ones(params, params, pm.Categorical)
656+
630657

631658
class TestScalarParameterSamples(SeededTest):
632659
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
@@ -815,66 +842,13 @@ def ref_rand(size, q, beta):
815842
pm.DiscreteWeibull, {"q": Unit, "beta": Rplusdunif}, ref_rand=ref_rand
816843
)
817844

818-
@pytest.mark.skip(reason="This test is covered by Aesara")
819-
@pytest.mark.parametrize("s", [2, 3, 4])
820-
def test_categorical_random(self, s):
821-
def ref_rand(size, p):
822-
return nr.choice(np.arange(p.shape[0]), p=p, size=size)
823-
824-
pymc3_random_discrete(pm.Categorical, {"p": Simplex(s)}, ref_rand=ref_rand)
825-
826845
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
827846
def test_constant_dist(self):
828847
def ref_rand(size, c):
829848
return c * np.ones(size, dtype=int)
830849

831850
pymc3_random_discrete(pm.Constant, {"c": I}, ref_rand=ref_rand)
832851

833-
@pytest.mark.skip(reason="This test is covered by Aesara")
834-
def test_mv_normal(self):
835-
def ref_rand(size, mu, cov):
836-
return st.multivariate_normal.rvs(mean=mu, cov=cov, size=size)
837-
838-
def ref_rand_tau(size, mu, tau):
839-
return ref_rand(size, mu, linalg.inv(tau))
840-
841-
def ref_rand_chol(size, mu, chol):
842-
return ref_rand(size, mu, np.dot(chol, chol.T))
843-
844-
def ref_rand_uchol(size, mu, chol):
845-
return ref_rand(size, mu, np.dot(chol.T, chol))
846-
847-
for n in [2, 3]:
848-
pymc3_random(
849-
pm.MvNormal,
850-
{"mu": Vector(R, n), "cov": PdMatrix(n)},
851-
size=100,
852-
valuedomain=Vector(R, n),
853-
ref_rand=ref_rand,
854-
)
855-
pymc3_random(
856-
pm.MvNormal,
857-
{"mu": Vector(R, n), "tau": PdMatrix(n)},
858-
size=100,
859-
valuedomain=Vector(R, n),
860-
ref_rand=ref_rand_tau,
861-
)
862-
pymc3_random(
863-
pm.MvNormal,
864-
{"mu": Vector(R, n), "chol": PdMatrixChol(n)},
865-
size=100,
866-
valuedomain=Vector(R, n),
867-
ref_rand=ref_rand_chol,
868-
)
869-
pymc3_random(
870-
pm.MvNormal,
871-
{"mu": Vector(R, n), "chol": PdMatrixCholUpper(n)},
872-
size=100,
873-
valuedomain=Vector(R, n),
874-
ref_rand=ref_rand_uchol,
875-
extra_args={"lower": False},
876-
)
877-
878852
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
879853
def test_matrix_normal(self):
880854
def ref_rand(size, mu, rowcov, colcov):
@@ -1017,20 +991,6 @@ def ref_rand(size, nu, Sigma, mu):
1017991
ref_rand=ref_rand,
1018992
)
1019993

1020-
@pytest.mark.skip(reason="This test is covered by Aesara")
1021-
def test_dirichlet(self):
1022-
def ref_rand(size, a):
1023-
return st.dirichlet.rvs(a, size=size)
1024-
1025-
for n in [2, 3]:
1026-
pymc3_random(
1027-
pm.Dirichlet,
1028-
{"a": Vector(Rplus, n)},
1029-
valuedomain=Simplex(n),
1030-
size=100,
1031-
ref_rand=ref_rand,
1032-
)
1033-
1034994
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
1035995
def test_dirichlet_multinomial(self):
1036996
def ref_rand(size, a, n):
@@ -1098,20 +1058,6 @@ def test_dirichlet_multinomial_dist_ShapeError(self, n, a, shape, expectation):
10981058
with expectation:
10991059
m.random()
11001060

1101-
@pytest.mark.skip(reason="This test is covered by Aesara")
1102-
def test_multinomial(self):
1103-
def ref_rand(size, p, n):
1104-
return nr.multinomial(pvals=p, n=n, size=size)
1105-
1106-
for n in [2, 3]:
1107-
pymc3_random_discrete(
1108-
pm.Multinomial,
1109-
{"p": Simplex(n), "n": Nat},
1110-
valuedomain=Vector(Nat, n),
1111-
size=100,
1112-
ref_rand=ref_rand,
1113-
)
1114-
11151061
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
11161062
def test_gumbel(self):
11171063
def ref_rand(size, mu, beta):

0 commit comments

Comments
 (0)