Skip to content

Commit 7bf2004

Browse files
Change tests for more refactored distributions.
More details can be found on issue #4554 #4554
1 parent 3c8c283 commit 7bf2004

File tree

2 files changed

+50
-51
lines changed

2 files changed

+50
-51
lines changed

pymc3/distributions/discrete.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -731,24 +731,23 @@ def NegBinom(a, m, x):
731731

732732
@classmethod
733733
def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
734-
n, p = cls.get_mu_alpha(mu, alpha, p, n)
734+
n, p = cls.get_n_p(mu, alpha, p, n)
735735
n = at.as_tensor_variable(floatX(n))
736736
p = at.as_tensor_variable(floatX(p))
737737
return super().dist([n, p], *args, **kwargs)
738738

739739
@classmethod
740-
def get_mu_alpha(cls, mu=None, alpha=None, p=None, n=None):
740+
def get_n_p(cls, mu=None, alpha=None, p=None, n=None):
741741
if n is None:
742742
if alpha is not None:
743-
n = at.as_tensor_variable(floatX(alpha))
743+
n = alpha
744744
else:
745745
raise ValueError("Incompatible parametrization. Must specify either alpha or n.")
746746
elif alpha is not None:
747747
raise ValueError("Incompatible parametrization. Can't specify both alpha and n.")
748748

749749
if p is None:
750750
if mu is not None:
751-
mu = at.as_tensor_variable(floatX(mu))
752751
p = n / (mu + n)
753752
else:
754753
raise ValueError("Incompatible parametrization. Must specify either mu or p.")

pymc3/tests/test_distributions_random.py

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -545,9 +545,7 @@ def get_inputs_from_apply_node_outputs(outputs):
545545
# I am assuming there will always only be 1 Apply parent node in this context
546546
return parents[0].inputs
547547

548-
def test_pymc_params_match_rv_ones(
549-
self, pymc_params, expected_aesara_params, pymc_dist, decimal=6
550-
):
548+
def _pymc_params_match_rv_ones(self, pymc_params, expected_aesara_params, pymc_dist, decimal=6):
551549
pymc_dist_output = pymc_dist.dist(**dict(pymc_params))
552550
aesera_dist_inputs = self.get_inputs_from_apply_node_outputs(pymc_dist_output)[3:]
553551
assert len(expected_aesara_params) == len(aesera_dist_inputs)
@@ -558,52 +556,88 @@ def test_pymc_params_match_rv_ones(
558556

559557
def test_normal(self):
560558
params = [("mu", 5.0), ("sigma", 10.0)]
561-
self.test_pymc_params_match_rv_ones(params, params, pm.Normal)
559+
self._pymc_params_match_rv_ones(params, params, pm.Normal)
562560

563561
def test_uniform(self):
564562
params = [("lower", 0.5), ("upper", 1.5)]
565-
self.test_pymc_params_match_rv_ones(params, params, pm.Uniform)
563+
self._pymc_params_match_rv_ones(params, params, pm.Uniform)
566564

567565
def test_half_normal(self):
568566
params, expected_aesara_params = [("sigma", 10.0)], [("mean", 0), ("sigma", 10.0)]
569-
self.test_pymc_params_match_rv_ones(params, expected_aesara_params, pm.HalfNormal)
567+
self._pymc_params_match_rv_ones(params, expected_aesara_params, pm.HalfNormal)
570568

571569
def test_beta_alpha_beta(self):
572570
params = [("alpha", 2.0), ("beta", 5.0)]
573-
self.test_pymc_params_match_rv_ones(params, params, pm.Beta)
571+
self._pymc_params_match_rv_ones(params, params, pm.Beta)
574572

575573
def test_beta_mu_sigma(self):
576574
params = [("mu", 2.0), ("sigma", 5.0)]
577575
expected_alpha, expected_beta = pm.Beta.get_alpha_beta(mu=params[0][1], sigma=params[1][1])
578576
expected_params = [("alpha", expected_alpha), ("beta", expected_beta)]
579-
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Beta)
577+
self._pymc_params_match_rv_ones(params, expected_params, pm.Beta)
580578

581579
@pytest.mark.skip(reason="Expected to fail due to bug")
582580
def test_exponential(self):
583581
params = [("lam", 10.0)]
584582
expected_params = [("lam", 1 / params[0][1])]
585-
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Exponential)
583+
self._pymc_params_match_rv_ones(params, expected_params, pm.Exponential)
586584

587585
def test_cauchy(self):
588586
params = [("alpha", 2.0), ("beta", 5.0)]
589-
self.test_pymc_params_match_rv_ones(params, params, pm.Cauchy)
587+
self._pymc_params_match_rv_ones(params, params, pm.Cauchy)
590588

591589
def test_half_cauchy(self):
592590
params = [("alpha", 2.0), ("beta", 5.0)]
593-
self.test_pymc_params_match_rv_ones(params, params, pm.HalfCauchy)
591+
self._pymc_params_match_rv_ones(params, params, pm.HalfCauchy)
594592

595593
@pytest.mark.skip(reason="Expected to fail due to bug")
596594
def test_gamma_alpha_beta(self):
597595
params = [("alpha", 2.0), ("beta", 5.0)]
598596
expected_params = [("alpha", params[0][1]), ("beta", 1 / params[1][1])]
599-
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
597+
self._pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
600598

601599
@pytest.mark.skip(reason="Expected to fail due to bug")
602600
def test_gamma_mu_sigma(self):
603601
params = [("mu", 2.0), ("sigma", 5.0)]
604602
expected_alpha, expected_beta = pm.Gamma.get_alpha_beta(mu=params[0][1], sigma=params[1][1])
605603
expected_params = [("alpha", expected_alpha), ("beta", 1 / expected_beta)]
606-
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
604+
self._pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
605+
606+
def test_inverse_gamma_alpha_beta(self):
607+
params = [("alpha", 2.0), ("beta", 5.0)]
608+
self._pymc_params_match_rv_ones(params, params, pm.InverseGamma)
609+
610+
def test_inverse_gamma_mu_sigma(self):
611+
params = [("mu", 2.0), ("sigma", 5.0)]
612+
expected_alpha, expected_beta = pm.InverseGamma._get_alpha_beta(
613+
mu=params[0][1], sigma=params[1][1], alpha=None, beta=None
614+
)
615+
expected_params = [("alpha", expected_alpha), ("beta", expected_beta)]
616+
self._pymc_params_match_rv_ones(params, expected_params, pm.InverseGamma)
617+
618+
def test_binomial(self):
619+
params = [("n", 100), ("p", 0.33)]
620+
self._pymc_params_match_rv_ones(params, params, pm.Binomial)
621+
622+
def test_negative_binomial(self):
623+
params = [("n", 100), ("p", 0.33)]
624+
self._pymc_params_match_rv_ones(params, params, pm.NegativeBinomial)
625+
626+
def test_negative_binomial_mu_sigma(self):
627+
params = [("mu", 5.0), ("alpha", 8.0)]
628+
expected_n, expected_p = pm.NegativeBinomial.get_n_p(
629+
mu=params[0][1], alpha=params[1][1], n=None, p=None
630+
)
631+
expected_params = [("n", expected_n), ("p", expected_p)]
632+
self._pymc_params_match_rv_ones(params, expected_params, pm.NegativeBinomial)
633+
634+
def test_bernoulli(self):
635+
params = [("p", 0.33)]
636+
self._pymc_params_match_rv_ones(params, params, pm.Bernoulli)
637+
638+
def test_poisson(self):
639+
params = [("mu", 4)]
640+
self._pymc_params_match_rv_ones(params, params, pm.Poisson)
607641

608642

609643
class TestScalarParameterSamples(SeededTest):
@@ -701,13 +735,6 @@ def ref_rand(size, nu, mu, lam):
701735

702736
pymc3_random(pm.StudentT, {"nu": Rplus, "mu": R, "lam": Rplus}, ref_rand=ref_rand)
703737

704-
@pytest.mark.skip(reason="This test is covered by Aesara")
705-
def test_inverse_gamma(self):
706-
def ref_rand(size, alpha, beta):
707-
return st.invgamma.rvs(a=alpha, scale=beta, size=size)
708-
709-
pymc3_random(pm.InverseGamma, {"alpha": Rplus, "beta": Rplus}, ref_rand=ref_rand)
710-
711738
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
712739
def test_pareto(self):
713740
def ref_rand(size, alpha, m):
@@ -754,10 +781,6 @@ def test_half_flat(self):
754781
with pytest.raises(ValueError):
755782
f.random(1)
756783

757-
@pytest.mark.skip(reason="This test is covered by Aesara")
758-
def test_binomial(self):
759-
pymc3_random_discrete(pm.Binomial, {"n": Nat, "p": Unit}, ref_rand=st.binom.rvs)
760-
761784
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
762785
@pytest.mark.xfail(
763786
sys.platform.startswith("win"),
@@ -771,29 +794,6 @@ def test_beta_binomial(self):
771794
def _beta_bin(self, n, alpha, beta, size=None):
772795
return st.binom.rvs(n, st.beta.rvs(a=alpha, b=beta, size=size))
773796

774-
@pytest.mark.skip(reason="This test is covered by Aesara")
775-
def test_bernoulli(self):
776-
pymc3_random_discrete(
777-
pm.Bernoulli, {"p": Unit}, ref_rand=lambda size, p=None: st.bernoulli.rvs(p, size=size)
778-
)
779-
780-
@pytest.mark.skip(reason="This test is covered by Aesara")
781-
def test_poisson(self):
782-
pymc3_random_discrete(pm.Poisson, {"mu": Rplusbig}, size=500, ref_rand=st.poisson.rvs)
783-
784-
@pytest.mark.skip(reason="This test is covered by Aesara")
785-
def test_negative_binomial(self):
786-
def ref_rand(size, alpha, mu):
787-
return st.nbinom.rvs(alpha, alpha / (mu + alpha), size=size)
788-
789-
pymc3_random_discrete(
790-
pm.NegativeBinomial,
791-
{"mu": Rplusbig, "alpha": Rplusbig},
792-
size=100,
793-
fails=50,
794-
ref_rand=ref_rand,
795-
)
796-
797797
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
798798
def test_geometric(self):
799799
pymc3_random_discrete(pm.Geometric, {"p": Unit}, size=500, fails=50, ref_rand=nr.geometric)

0 commit comments

Comments
 (0)