Skip to content

Commit 5132c92

Browse files
Update tests following distributions refactoring
The distributions refactoring moves the random variable sampling to aesara. This relies on numpy and scipy random variables implementation. So, now the only thing we care about testing is that the parametrization on the PyMC side is sendible given the one on the Aesara side (effectively the numpy/scipy one) More details can be found on issue pymc-devs#4554 pymc-devs#4554
1 parent 2e7d042 commit 5132c92

File tree

1 file changed

+71
-42
lines changed

1 file changed

+71
-42
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 71 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
import pytest
2525
import scipy.stats as st
2626

27+
from numpy.testing import assert_almost_equal
2728
from scipy import linalg
2829
from scipy.special import expit
2930

3031
import pymc3 as pm
3132

3233
from pymc3.aesaraf import change_rv_size, floatX, intX
33-
from pymc3.distributions.dist_math import clipped_beta_rvs
3434
from pymc3.distributions.shape_utils import to_tuple
3535
from pymc3.exceptions import ShapeError
3636
from pymc3.tests.helpers import SeededTest, select_by_precision
@@ -524,6 +524,76 @@ def test_dirichlet_random_shape(self, shape, size):
524524
assert pm.Dirichlet.dist(a=np.ones(shape)).random(size=size).shape == out_shape
525525

526526

527+
class TestCorrectParametrizationMappingPymcToScipy(SeededTest):
528+
@staticmethod
529+
def get_inputs_from_apply_node_outputs(outputs):
530+
parents = outputs.get_parents()
531+
if not parents:
532+
raise Exception("Parent Apply node missing for output")
533+
# I am assuming there will always only be 1 Apply parent node in this context
534+
return parents[0].inputs
535+
536+
def test_pymc_params_match_rv_ones(
537+
self, pymc_params, expected_aesara_params, pymc_dist, decimal=6
538+
):
539+
pymc_dist_output = pymc_dist.dist(**dict(pymc_params))
540+
aesera_dist_inputs = self.get_inputs_from_apply_node_outputs(pymc_dist_output)[3:]
541+
assert len(expected_aesara_params) == len(aesera_dist_inputs)
542+
for (expected_name, expected_value), actual_variable in zip(
543+
expected_aesara_params, aesera_dist_inputs
544+
):
545+
assert_almost_equal(expected_value, actual_variable.eval(), decimal=decimal)
546+
547+
def test_normal(self):
548+
params = [("mu", 5.0), ("sigma", 10.0)]
549+
self.test_pymc_params_match_rv_ones(params, params, pm.Normal)
550+
551+
def test_uniform(self):
552+
params = [("lower", 0.5), ("upper", 1.5)]
553+
self.test_pymc_params_match_rv_ones(params, params, pm.Uniform)
554+
555+
def test_half_normal(self):
556+
params, expected_aesara_params = [("sigma", 10.0)], [("mean", 0), ("sigma", 10.0)]
557+
self.test_pymc_params_match_rv_ones(params, expected_aesara_params, pm.HalfNormal)
558+
559+
def test_beta_alpha_beta(self):
560+
params = [("alpha", 2.0), ("beta", 5.0)]
561+
self.test_pymc_params_match_rv_ones(params, params, pm.Beta)
562+
563+
def test_beta_mu_sigma(self):
564+
params = [("mu", 2.0), ("sigma", 5.0)]
565+
expected_alpha, expected_beta = pm.Beta.get_alpha_beta(mu=params[0][1], sigma=params[1][1])
566+
expected_params = [("alpha", expected_alpha), ("beta", expected_beta)]
567+
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Beta)
568+
569+
@pytest.mark.skip(reason="Expected to fail due to bug")
570+
def test_exponential(self):
571+
params = [("lam", 10.0)]
572+
expected_params = [("lam", 1 / params[0][1])]
573+
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Exponential)
574+
575+
def test_cauchy(self):
576+
params = [("alpha", 2.0), ("beta", 5.0)]
577+
self.test_pymc_params_match_rv_ones(params, params, pm.Cauchy)
578+
579+
def test_half_cauchy(self):
580+
params = [("alpha", 2.0), ("beta", 5.0)]
581+
self.test_pymc_params_match_rv_ones(params, params, pm.HalfCauchy)
582+
583+
@pytest.mark.skip(reason="Expected to fail due to bug")
584+
def test_gamma_alpha_beta(self):
585+
params = [("alpha", 2.0), ("beta", 5.0)]
586+
expected_params = [("alpha", params[0][1]), ("beta", 1 / params[1][1])]
587+
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
588+
589+
@pytest.mark.skip(reason="Expected to fail due to bug")
590+
def test_gamma_mu_sigma(self):
591+
params = [("mu", 2.0), ("sigma", 5.0)]
592+
expected_alpha, expected_beta = pm.Gamma.get_alpha_beta(mu=params[0][1], sigma=params[1][1])
593+
expected_params = [("alpha", expected_alpha), ("beta", 1 / expected_beta)]
594+
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
595+
596+
527597
class TestScalarParameterSamples(SeededTest):
528598
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
529599
def test_bounded(self):
@@ -535,20 +605,6 @@ def ref_rand(size, tau):
535605

536606
pymc3_random(BoundedNormal, {"tau": Rplus}, ref_rand=ref_rand)
537607

538-
@pytest.mark.skip(reason="This test is covered by Aesara")
539-
def test_uniform(self):
540-
def ref_rand(size, lower, upper):
541-
return st.uniform.rvs(size=size, loc=lower, scale=upper - lower)
542-
543-
pymc3_random(pm.Uniform, {"lower": -Rplus, "upper": Rplus}, ref_rand=ref_rand)
544-
545-
@pytest.mark.skip(reason="This test is covered by Aesara")
546-
def test_normal(self):
547-
def ref_rand(size, mu, sigma):
548-
return st.norm.rvs(size=size, loc=mu, scale=sigma)
549-
550-
pymc3_random(pm.Normal, {"mu": R, "sigma": Rplus}, ref_rand=ref_rand)
551-
552608
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
553609
def test_truncated_normal(self):
554610
def ref_rand(size, mu, sigma, lower, upper):
@@ -587,13 +643,6 @@ def ref_rand(size, alpha, mu, sigma):
587643

588644
pymc3_random(pm.SkewNormal, {"mu": R, "sigma": Rplus, "alpha": R}, ref_rand=ref_rand)
589645

590-
@pytest.mark.skip(reason="This test is covered by Aesara")
591-
def test_half_normal(self):
592-
def ref_rand(size, tau):
593-
return st.halfnorm.rvs(size=size, loc=0, scale=tau ** -0.5)
594-
595-
pymc3_random(pm.HalfNormal, {"tau": Rplus}, ref_rand=ref_rand)
596-
597646
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
598647
def test_wald(self):
599648
# Cannot do anything too exciting as scipy wald is a
@@ -607,13 +656,6 @@ def ref_rand(size, mu, lam, alpha):
607656
ref_rand=ref_rand,
608657
)
609658

610-
@pytest.mark.skip(reason="This test is covered by Aesara")
611-
def test_beta(self):
612-
def ref_rand(size, alpha, beta):
613-
return clipped_beta_rvs(a=alpha, b=beta, size=size)
614-
615-
pymc3_random(pm.Beta, {"alpha": Rplus, "beta": Rplus}, ref_rand=ref_rand)
616-
617659
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
618660
def test_laplace(self):
619661
def ref_rand(size, mu, b):
@@ -648,20 +690,7 @@ def ref_rand(size, nu, mu, lam):
648690
pymc3_random(pm.StudentT, {"nu": Rplus, "mu": R, "lam": Rplus}, ref_rand=ref_rand)
649691

650692
@pytest.mark.skip(reason="This test is covered by Aesara")
651-
def test_cauchy(self):
652-
def ref_rand(size, alpha, beta):
653-
return st.cauchy.rvs(alpha, beta, size=size)
654-
655-
pymc3_random(pm.Cauchy, {"alpha": R, "beta": Rplusbig}, ref_rand=ref_rand)
656693

657-
@pytest.mark.skip(reason="This test is covered by Aesara")
658-
def test_half_cauchy(self):
659-
def ref_rand(size, beta):
660-
return st.halfcauchy.rvs(scale=beta, size=size)
661-
662-
pymc3_random(pm.HalfCauchy, {"beta": Rplusbig}, ref_rand=ref_rand)
663-
664-
@pytest.mark.skip(reason="This test is covered by Aesara")
665694
def test_inverse_gamma(self):
666695
def ref_rand(size, alpha, beta):
667696
return st.invgamma.rvs(a=alpha, scale=beta, size=size)

0 commit comments

Comments
 (0)