|
24 | 24 | import scipy.stats as st
|
25 | 25 |
|
26 | 26 | from numpy.testing import assert_almost_equal
|
27 |
| -from scipy import linalg |
28 | 27 | from scipy.special import expit
|
29 | 28 |
|
30 | 29 | import pymc3 as pm
|
31 | 30 |
|
32 | 31 | from pymc3.aesaraf import change_rv_size, floatX, intX
|
| 32 | +from pymc3.distributions.multivariate import quaddist_matrix |
33 | 33 | from pymc3.distributions.shape_utils import to_tuple
|
34 | 34 | from pymc3.exceptions import ShapeError
|
35 | 35 | from pymc3.tests.helpers import SeededTest
|
|
40 | 40 | NatSmall,
|
41 | 41 | PdMatrix,
|
42 | 42 | PdMatrixChol,
|
43 |
| - PdMatrixCholUpper, |
44 | 43 | R,
|
45 | 44 | RandomPdMatrix,
|
46 | 45 | RealMatrix,
|
@@ -639,6 +638,34 @@ def test_poisson(self):
|
639 | 638 | params = [("mu", 4)]
|
640 | 639 | self._pymc_params_match_rv_ones(params, params, pm.Poisson)
|
641 | 640 |
|
| 641 | + def test_mv_distribution(self): |
| 642 | + params = [("mu", np.array([1.0, 2.0])), ("cov", np.array([[2.0, 0.0], [0.0, 3.5]]))] |
| 643 | + self._pymc_params_match_rv_ones(params, params, pm.MvNormal) |
| 644 | + |
| 645 | + def test_mv_distribution_chol(self): |
| 646 | + params = [("mu", np.array([1.0, 2.0])), ("chol", np.array([[2.0, 0.0], [0.0, 3.5]]))] |
| 647 | + expected_cov = quaddist_matrix(chol=params[1][1]) |
| 648 | + expected_params = [("mu", np.array([1.0, 2.0])), ("cov", expected_cov.eval())] |
| 649 | + self._pymc_params_match_rv_ones(params, expected_params, pm.MvNormal) |
| 650 | + |
| 651 | + def test_mv_distribution_tau(self): |
| 652 | + params = [("mu", np.array([1.0, 2.0])), ("tau", np.array([[2.0, 0.0], [0.0, 3.5]]))] |
| 653 | + expected_cov = quaddist_matrix(tau=params[1][1]) |
| 654 | + expected_params = [("mu", np.array([1.0, 2.0])), ("cov", expected_cov.eval())] |
| 655 | + self._pymc_params_match_rv_ones(params, expected_params, pm.MvNormal) |
| 656 | + |
| 657 | + def test_dirichlet(self): |
| 658 | + params = [("a", np.array([1.0, 2.0]))] |
| 659 | + self._pymc_params_match_rv_ones(params, params, pm.Dirichlet) |
| 660 | + |
| 661 | + def test_multinomial(self): |
| 662 | + params = [("n", 85), ("p", np.array([0.28, 0.62, 0.10]))] |
| 663 | + self._pymc_params_match_rv_ones(params, params, pm.Multinomial) |
| 664 | + |
| 665 | + def test_categorical(self): |
| 666 | + params = [("p", np.array([0.28, 0.62, 0.10]))] |
| 667 | + self._pymc_params_match_rv_ones(params, params, pm.Categorical) |
| 668 | + |
642 | 669 |
|
643 | 670 | class TestScalarParameterSamples(SeededTest):
|
644 | 671 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
@@ -835,66 +862,13 @@ def ref_rand(size, q, beta):
|
835 | 862 | pm.DiscreteWeibull, {"q": Unit, "beta": Rplusdunif}, ref_rand=ref_rand
|
836 | 863 | )
|
837 | 864 |
|
838 |
| - @pytest.mark.skip(reason="This test is covered by Aesara") |
839 |
| - @pytest.mark.parametrize("s", [2, 3, 4]) |
840 |
| - def test_categorical_random(self, s): |
841 |
| - def ref_rand(size, p): |
842 |
| - return nr.choice(np.arange(p.shape[0]), p=p, size=size) |
843 |
| - |
844 |
| - pymc3_random_discrete(pm.Categorical, {"p": Simplex(s)}, ref_rand=ref_rand) |
845 |
| - |
846 | 865 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
847 | 866 | def test_constant_dist(self):
|
848 | 867 | def ref_rand(size, c):
|
849 | 868 | return c * np.ones(size, dtype=int)
|
850 | 869 |
|
851 | 870 | pymc3_random_discrete(pm.Constant, {"c": I}, ref_rand=ref_rand)
|
852 | 871 |
|
853 |
| - @pytest.mark.skip(reason="This test is covered by Aesara") |
854 |
| - def test_mv_normal(self): |
855 |
| - def ref_rand(size, mu, cov): |
856 |
| - return st.multivariate_normal.rvs(mean=mu, cov=cov, size=size) |
857 |
| - |
858 |
| - def ref_rand_tau(size, mu, tau): |
859 |
| - return ref_rand(size, mu, linalg.inv(tau)) |
860 |
| - |
861 |
| - def ref_rand_chol(size, mu, chol): |
862 |
| - return ref_rand(size, mu, np.dot(chol, chol.T)) |
863 |
| - |
864 |
| - def ref_rand_uchol(size, mu, chol): |
865 |
| - return ref_rand(size, mu, np.dot(chol.T, chol)) |
866 |
| - |
867 |
| - for n in [2, 3]: |
868 |
| - pymc3_random( |
869 |
| - pm.MvNormal, |
870 |
| - {"mu": Vector(R, n), "cov": PdMatrix(n)}, |
871 |
| - size=100, |
872 |
| - valuedomain=Vector(R, n), |
873 |
| - ref_rand=ref_rand, |
874 |
| - ) |
875 |
| - pymc3_random( |
876 |
| - pm.MvNormal, |
877 |
| - {"mu": Vector(R, n), "tau": PdMatrix(n)}, |
878 |
| - size=100, |
879 |
| - valuedomain=Vector(R, n), |
880 |
| - ref_rand=ref_rand_tau, |
881 |
| - ) |
882 |
| - pymc3_random( |
883 |
| - pm.MvNormal, |
884 |
| - {"mu": Vector(R, n), "chol": PdMatrixChol(n)}, |
885 |
| - size=100, |
886 |
| - valuedomain=Vector(R, n), |
887 |
| - ref_rand=ref_rand_chol, |
888 |
| - ) |
889 |
| - pymc3_random( |
890 |
| - pm.MvNormal, |
891 |
| - {"mu": Vector(R, n), "chol": PdMatrixCholUpper(n)}, |
892 |
| - size=100, |
893 |
| - valuedomain=Vector(R, n), |
894 |
| - ref_rand=ref_rand_uchol, |
895 |
| - extra_args={"lower": False}, |
896 |
| - ) |
897 |
| - |
898 | 872 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
899 | 873 | def test_matrix_normal(self):
|
900 | 874 | def ref_rand(size, mu, rowcov, colcov):
|
@@ -1037,20 +1011,6 @@ def ref_rand(size, nu, Sigma, mu):
|
1037 | 1011 | ref_rand=ref_rand,
|
1038 | 1012 | )
|
1039 | 1013 |
|
1040 |
| - @pytest.mark.skip(reason="This test is covered by Aesara") |
1041 |
| - def test_dirichlet(self): |
1042 |
| - def ref_rand(size, a): |
1043 |
| - return st.dirichlet.rvs(a, size=size) |
1044 |
| - |
1045 |
| - for n in [2, 3]: |
1046 |
| - pymc3_random( |
1047 |
| - pm.Dirichlet, |
1048 |
| - {"a": Vector(Rplus, n)}, |
1049 |
| - valuedomain=Simplex(n), |
1050 |
| - size=100, |
1051 |
| - ref_rand=ref_rand, |
1052 |
| - ) |
1053 |
| - |
1054 | 1014 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
1055 | 1015 | def test_dirichlet_multinomial(self):
|
1056 | 1016 | def ref_rand(size, a, n):
|
@@ -1118,20 +1078,6 @@ def test_dirichlet_multinomial_dist_ShapeError(self, n, a, shape, expectation):
|
1118 | 1078 | with expectation:
|
1119 | 1079 | m.random()
|
1120 | 1080 |
|
1121 |
| - @pytest.mark.skip(reason="This test is covered by Aesara") |
1122 |
| - def test_multinomial(self): |
1123 |
| - def ref_rand(size, p, n): |
1124 |
| - return nr.multinomial(pvals=p, n=n, size=size) |
1125 |
| - |
1126 |
| - for n in [2, 3]: |
1127 |
| - pymc3_random_discrete( |
1128 |
| - pm.Multinomial, |
1129 |
| - {"p": Simplex(n), "n": Nat}, |
1130 |
| - valuedomain=Vector(Nat, n), |
1131 |
| - size=100, |
1132 |
| - ref_rand=ref_rand, |
1133 |
| - ) |
1134 |
| - |
1135 | 1081 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
1136 | 1082 | def test_gumbel(self):
|
1137 | 1083 | def ref_rand(size, mu, beta):
|
|
0 commit comments