|
25 | 25 | import scipy.stats as st
|
26 | 26 |
|
27 | 27 | from numpy.testing import assert_almost_equal
|
28 |
| -from scipy import linalg |
29 | 28 | from scipy.special import expit
|
30 | 29 |
|
31 | 30 | import pymc3 as pm
|
32 | 31 |
|
33 | 32 | from pymc3.aesaraf import change_rv_size, floatX, intX
|
| 33 | +from pymc3.distributions.multivariate import quaddist_matrix |
34 | 34 | from pymc3.distributions.shape_utils import to_tuple
|
35 | 35 | from pymc3.exceptions import ShapeError
|
36 | 36 | from pymc3.tests.helpers import SeededTest, select_by_precision
|
|
41 | 41 | NatSmall,
|
42 | 42 | PdMatrix,
|
43 | 43 | PdMatrixChol,
|
44 |
| - PdMatrixCholUpper, |
45 | 44 | R,
|
46 | 45 | RandomPdMatrix,
|
47 | 46 | RealMatrix,
|
@@ -627,6 +626,34 @@ def test_poisson(self):
|
627 | 626 | params = [("mu", 4)]
|
628 | 627 | self._pymc_params_match_rv_ones(params, params, pm.Poisson)
|
629 | 628 |
|
| 629 | + def test_mv_distribution(self): |
| 630 | + params = [("mu", np.array([1.0, 2.0])), ("cov", np.array([[2.0, 0.0], [0.0, 3.5]]))] |
| 631 | + self._pymc_params_match_rv_ones(params, params, pm.MvNormal) |
| 632 | + |
| 633 | + def test_mv_distribution_chol(self): |
| 634 | + params = [("mu", np.array([1.0, 2.0])), ("chol", np.array([[2.0, 0.0], [0.0, 3.5]]))] |
| 635 | + expected_cov = quaddist_matrix(chol=params[1][1]) |
| 636 | + expected_params = [("mu", np.array([1.0, 2.0])), ("cov", expected_cov.eval())] |
| 637 | + self._pymc_params_match_rv_ones(params, expected_params, pm.MvNormal) |
| 638 | + |
| 639 | + def test_mv_distribution_tau(self): |
| 640 | + params = [("mu", np.array([1.0, 2.0])), ("tau", np.array([[2.0, 0.0], [0.0, 3.5]]))] |
| 641 | + expected_cov = quaddist_matrix(tau=params[1][1]) |
| 642 | + expected_params = [("mu", np.array([1.0, 2.0])), ("cov", expected_cov.eval())] |
| 643 | + self._pymc_params_match_rv_ones(params, expected_params, pm.MvNormal) |
| 644 | + |
| 645 | + def test_dirichlet(self): |
| 646 | + params = [("a", np.array([1.0, 2.0]))] |
| 647 | + self._pymc_params_match_rv_ones(params, params, pm.Dirichlet) |
| 648 | + |
| 649 | + def test_multinomial(self): |
| 650 | + params = [("n", 85), ("p", np.array([0.28, 0.62, 0.10]))] |
| 651 | + self._pymc_params_match_rv_ones(params, params, pm.Multinomial) |
| 652 | + |
| 653 | + def test_categorical(self): |
| 654 | + params = [("p", np.array([0.28, 0.62, 0.10]))] |
| 655 | + self._pymc_params_match_rv_ones(params, params, pm.Categorical) |
| 656 | + |
630 | 657 |
|
631 | 658 | class TestScalarParameterSamples(SeededTest):
|
632 | 659 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
@@ -815,66 +842,13 @@ def ref_rand(size, q, beta):
|
815 | 842 | pm.DiscreteWeibull, {"q": Unit, "beta": Rplusdunif}, ref_rand=ref_rand
|
816 | 843 | )
|
817 | 844 |
|
818 |
| - @pytest.mark.skip(reason="This test is covered by Aesara") |
819 |
| - @pytest.mark.parametrize("s", [2, 3, 4]) |
820 |
| - def test_categorical_random(self, s): |
821 |
| - def ref_rand(size, p): |
822 |
| - return nr.choice(np.arange(p.shape[0]), p=p, size=size) |
823 |
| - |
824 |
| - pymc3_random_discrete(pm.Categorical, {"p": Simplex(s)}, ref_rand=ref_rand) |
825 |
| - |
826 | 845 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
827 | 846 | def test_constant_dist(self):
|
828 | 847 | def ref_rand(size, c):
|
829 | 848 | return c * np.ones(size, dtype=int)
|
830 | 849 |
|
831 | 850 | pymc3_random_discrete(pm.Constant, {"c": I}, ref_rand=ref_rand)
|
832 | 851 |
|
833 |
| - @pytest.mark.skip(reason="This test is covered by Aesara") |
834 |
| - def test_mv_normal(self): |
835 |
| - def ref_rand(size, mu, cov): |
836 |
| - return st.multivariate_normal.rvs(mean=mu, cov=cov, size=size) |
837 |
| - |
838 |
| - def ref_rand_tau(size, mu, tau): |
839 |
| - return ref_rand(size, mu, linalg.inv(tau)) |
840 |
| - |
841 |
| - def ref_rand_chol(size, mu, chol): |
842 |
| - return ref_rand(size, mu, np.dot(chol, chol.T)) |
843 |
| - |
844 |
| - def ref_rand_uchol(size, mu, chol): |
845 |
| - return ref_rand(size, mu, np.dot(chol.T, chol)) |
846 |
| - |
847 |
| - for n in [2, 3]: |
848 |
| - pymc3_random( |
849 |
| - pm.MvNormal, |
850 |
| - {"mu": Vector(R, n), "cov": PdMatrix(n)}, |
851 |
| - size=100, |
852 |
| - valuedomain=Vector(R, n), |
853 |
| - ref_rand=ref_rand, |
854 |
| - ) |
855 |
| - pymc3_random( |
856 |
| - pm.MvNormal, |
857 |
| - {"mu": Vector(R, n), "tau": PdMatrix(n)}, |
858 |
| - size=100, |
859 |
| - valuedomain=Vector(R, n), |
860 |
| - ref_rand=ref_rand_tau, |
861 |
| - ) |
862 |
| - pymc3_random( |
863 |
| - pm.MvNormal, |
864 |
| - {"mu": Vector(R, n), "chol": PdMatrixChol(n)}, |
865 |
| - size=100, |
866 |
| - valuedomain=Vector(R, n), |
867 |
| - ref_rand=ref_rand_chol, |
868 |
| - ) |
869 |
| - pymc3_random( |
870 |
| - pm.MvNormal, |
871 |
| - {"mu": Vector(R, n), "chol": PdMatrixCholUpper(n)}, |
872 |
| - size=100, |
873 |
| - valuedomain=Vector(R, n), |
874 |
| - ref_rand=ref_rand_uchol, |
875 |
| - extra_args={"lower": False}, |
876 |
| - ) |
877 |
| - |
878 | 852 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
879 | 853 | def test_matrix_normal(self):
|
880 | 854 | def ref_rand(size, mu, rowcov, colcov):
|
@@ -1017,20 +991,6 @@ def ref_rand(size, nu, Sigma, mu):
|
1017 | 991 | ref_rand=ref_rand,
|
1018 | 992 | )
|
1019 | 993 |
|
1020 |
| - @pytest.mark.skip(reason="This test is covered by Aesara") |
1021 |
| - def test_dirichlet(self): |
1022 |
| - def ref_rand(size, a): |
1023 |
| - return st.dirichlet.rvs(a, size=size) |
1024 |
| - |
1025 |
| - for n in [2, 3]: |
1026 |
| - pymc3_random( |
1027 |
| - pm.Dirichlet, |
1028 |
| - {"a": Vector(Rplus, n)}, |
1029 |
| - valuedomain=Simplex(n), |
1030 |
| - size=100, |
1031 |
| - ref_rand=ref_rand, |
1032 |
| - ) |
1033 |
| - |
1034 | 994 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
1035 | 995 | def test_dirichlet_multinomial(self):
|
1036 | 996 | def ref_rand(size, a, n):
|
@@ -1098,20 +1058,6 @@ def test_dirichlet_multinomial_dist_ShapeError(self, n, a, shape, expectation):
|
1098 | 1058 | with expectation:
|
1099 | 1059 | m.random()
|
1100 | 1060 |
|
1101 |
| - @pytest.mark.skip(reason="This test is covered by Aesara") |
1102 |
| - def test_multinomial(self): |
1103 |
| - def ref_rand(size, p, n): |
1104 |
| - return nr.multinomial(pvals=p, n=n, size=size) |
1105 |
| - |
1106 |
| - for n in [2, 3]: |
1107 |
| - pymc3_random_discrete( |
1108 |
| - pm.Multinomial, |
1109 |
| - {"p": Simplex(n), "n": Nat}, |
1110 |
| - valuedomain=Vector(Nat, n), |
1111 |
| - size=100, |
1112 |
| - ref_rand=ref_rand, |
1113 |
| - ) |
1114 |
| - |
1115 | 1061 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
1116 | 1062 | def test_gumbel(self):
|
1117 | 1063 | def ref_rand(size, mu, beta):
|
|
0 commit comments