Skip to content

Commit 3c8c283

Browse files
Update tests following distributions refactoring
The distributions refactoring moves the random variable sampling to aesara. This relies on numpy and scipy random variables implementation. So, now the only thing we care about testing is that the parametrization on the PyMC side is sendible given the one on the Aesara side (effectively the numpy/scipy one) More details can be found on issue pymc-devs#4554 pymc-devs#4554
1 parent de8ee52 commit 3c8c283

File tree

1 file changed

+71
-64
lines changed

1 file changed

+71
-64
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 71 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@
2323
import pytest
2424
import scipy.stats as st
2525

26+
from numpy.testing import assert_almost_equal
2627
from scipy import linalg
2728
from scipy.special import expit
2829

2930
import pymc3 as pm
3031

3132
from pymc3.aesaraf import change_rv_size, floatX, intX
32-
from pymc3.distributions.dist_math import clipped_beta_rvs
3333
from pymc3.distributions.shape_utils import to_tuple
3434
from pymc3.exceptions import ShapeError
3535
from pymc3.tests.helpers import SeededTest
@@ -536,6 +536,76 @@ def test_dirichlet_random_shape(self, shape, size):
536536
assert pm.Dirichlet.dist(a=np.ones(shape)).random(size=size).shape == out_shape
537537

538538

539+
class TestCorrectParametrizationMappingPymcToScipy(SeededTest):
540+
@staticmethod
541+
def get_inputs_from_apply_node_outputs(outputs):
542+
parents = outputs.get_parents()
543+
if not parents:
544+
raise Exception("Parent Apply node missing for output")
545+
# I am assuming there will always only be 1 Apply parent node in this context
546+
return parents[0].inputs
547+
548+
def test_pymc_params_match_rv_ones(
549+
self, pymc_params, expected_aesara_params, pymc_dist, decimal=6
550+
):
551+
pymc_dist_output = pymc_dist.dist(**dict(pymc_params))
552+
aesera_dist_inputs = self.get_inputs_from_apply_node_outputs(pymc_dist_output)[3:]
553+
assert len(expected_aesara_params) == len(aesera_dist_inputs)
554+
for (expected_name, expected_value), actual_variable in zip(
555+
expected_aesara_params, aesera_dist_inputs
556+
):
557+
assert_almost_equal(expected_value, actual_variable.eval(), decimal=decimal)
558+
559+
def test_normal(self):
560+
params = [("mu", 5.0), ("sigma", 10.0)]
561+
self.test_pymc_params_match_rv_ones(params, params, pm.Normal)
562+
563+
def test_uniform(self):
564+
params = [("lower", 0.5), ("upper", 1.5)]
565+
self.test_pymc_params_match_rv_ones(params, params, pm.Uniform)
566+
567+
def test_half_normal(self):
568+
params, expected_aesara_params = [("sigma", 10.0)], [("mean", 0), ("sigma", 10.0)]
569+
self.test_pymc_params_match_rv_ones(params, expected_aesara_params, pm.HalfNormal)
570+
571+
def test_beta_alpha_beta(self):
572+
params = [("alpha", 2.0), ("beta", 5.0)]
573+
self.test_pymc_params_match_rv_ones(params, params, pm.Beta)
574+
575+
def test_beta_mu_sigma(self):
576+
params = [("mu", 2.0), ("sigma", 5.0)]
577+
expected_alpha, expected_beta = pm.Beta.get_alpha_beta(mu=params[0][1], sigma=params[1][1])
578+
expected_params = [("alpha", expected_alpha), ("beta", expected_beta)]
579+
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Beta)
580+
581+
@pytest.mark.skip(reason="Expected to fail due to bug")
582+
def test_exponential(self):
583+
params = [("lam", 10.0)]
584+
expected_params = [("lam", 1 / params[0][1])]
585+
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Exponential)
586+
587+
def test_cauchy(self):
588+
params = [("alpha", 2.0), ("beta", 5.0)]
589+
self.test_pymc_params_match_rv_ones(params, params, pm.Cauchy)
590+
591+
def test_half_cauchy(self):
592+
params = [("alpha", 2.0), ("beta", 5.0)]
593+
self.test_pymc_params_match_rv_ones(params, params, pm.HalfCauchy)
594+
595+
@pytest.mark.skip(reason="Expected to fail due to bug")
596+
def test_gamma_alpha_beta(self):
597+
params = [("alpha", 2.0), ("beta", 5.0)]
598+
expected_params = [("alpha", params[0][1]), ("beta", 1 / params[1][1])]
599+
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
600+
601+
@pytest.mark.skip(reason="Expected to fail due to bug")
602+
def test_gamma_mu_sigma(self):
603+
params = [("mu", 2.0), ("sigma", 5.0)]
604+
expected_alpha, expected_beta = pm.Gamma.get_alpha_beta(mu=params[0][1], sigma=params[1][1])
605+
expected_params = [("alpha", expected_alpha), ("beta", 1 / expected_beta)]
606+
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
607+
608+
539609
class TestScalarParameterSamples(SeededTest):
540610
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
541611
def test_bounded(self):
@@ -547,20 +617,6 @@ def ref_rand(size, tau):
547617

548618
pymc3_random(BoundedNormal, {"tau": Rplus}, ref_rand=ref_rand)
549619

550-
@pytest.mark.skip(reason="This test is covered by Aesara")
551-
def test_uniform(self):
552-
def ref_rand(size, lower, upper):
553-
return st.uniform.rvs(size=size, loc=lower, scale=upper - lower)
554-
555-
pymc3_random(pm.Uniform, {"lower": -Rplus, "upper": Rplus}, ref_rand=ref_rand)
556-
557-
@pytest.mark.skip(reason="This test is covered by Aesara")
558-
def test_normal(self):
559-
def ref_rand(size, mu, sigma):
560-
return st.norm.rvs(size=size, loc=mu, scale=sigma)
561-
562-
pymc3_random(pm.Normal, {"mu": R, "sigma": Rplus}, ref_rand=ref_rand)
563-
564620
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
565621
def test_truncated_normal(self):
566622
def ref_rand(size, mu, sigma, lower, upper):
@@ -599,13 +655,6 @@ def ref_rand(size, alpha, mu, sigma):
599655

600656
pymc3_random(pm.SkewNormal, {"mu": R, "sigma": Rplus, "alpha": R}, ref_rand=ref_rand)
601657

602-
@pytest.mark.skip(reason="This test is covered by Aesara")
603-
def test_half_normal(self):
604-
def ref_rand(size, tau):
605-
return st.halfnorm.rvs(size=size, loc=0, scale=tau ** -0.5)
606-
607-
pymc3_random(pm.HalfNormal, {"tau": Rplus}, ref_rand=ref_rand)
608-
609658
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
610659
def test_wald(self):
611660
# Cannot do anything too exciting as scipy wald is a
@@ -619,20 +668,6 @@ def ref_rand(size, mu, lam, alpha):
619668
ref_rand=ref_rand,
620669
)
621670

622-
@pytest.mark.skip(reason="This test is covered by Aesara")
623-
def test_beta(self):
624-
def ref_rand(size, alpha, beta):
625-
return clipped_beta_rvs(a=alpha, b=beta, size=size)
626-
627-
pymc3_random(pm.Beta, {"alpha": Rplus, "beta": Rplus}, ref_rand=ref_rand)
628-
629-
@pytest.mark.skip(reason="This test is covered by Aesara")
630-
def test_exponential(self):
631-
def ref_rand(size, lam):
632-
return nr.exponential(scale=1.0 / lam, size=size)
633-
634-
pymc3_random(pm.Exponential, {"lam": Rplus}, ref_rand=ref_rand)
635-
636671
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
637672
def test_laplace(self):
638673
def ref_rand(size, mu, b):
@@ -666,34 +701,6 @@ def ref_rand(size, nu, mu, lam):
666701

667702
pymc3_random(pm.StudentT, {"nu": Rplus, "mu": R, "lam": Rplus}, ref_rand=ref_rand)
668703

669-
@pytest.mark.skip(reason="This test is covered by Aesara")
670-
def test_cauchy(self):
671-
def ref_rand(size, alpha, beta):
672-
return st.cauchy.rvs(alpha, beta, size=size)
673-
674-
pymc3_random(pm.Cauchy, {"alpha": R, "beta": Rplusbig}, ref_rand=ref_rand)
675-
676-
@pytest.mark.skip(reason="This test is covered by Aesara")
677-
def test_half_cauchy(self):
678-
def ref_rand(size, beta):
679-
return st.halfcauchy.rvs(scale=beta, size=size)
680-
681-
pymc3_random(pm.HalfCauchy, {"beta": Rplusbig}, ref_rand=ref_rand)
682-
683-
@pytest.mark.skip(reason="This test is covered by Aesara")
684-
def test_gamma_alpha_beta(self):
685-
def ref_rand(size, alpha, beta):
686-
return st.gamma.rvs(alpha, scale=1.0 / beta, size=size)
687-
688-
pymc3_random(pm.Gamma, {"alpha": Rplusbig, "beta": Rplusbig}, ref_rand=ref_rand)
689-
690-
@pytest.mark.skip(reason="This test is covered by Aesara")
691-
def test_gamma_mu_sigma(self):
692-
def ref_rand(size, mu, sigma):
693-
return st.gamma.rvs(mu ** 2 / sigma ** 2, scale=sigma ** 2 / mu, size=size)
694-
695-
pymc3_random(pm.Gamma, {"mu": Rplusbig, "sigma": Rplusbig}, ref_rand=ref_rand)
696-
697704
@pytest.mark.skip(reason="This test is covered by Aesara")
698705
def test_inverse_gamma(self):
699706
def ref_rand(size, alpha, beta):

0 commit comments

Comments
 (0)