Skip to content

Commit edb829e

Browse files
committed
Speed up check_logcdf and check_selfconsistency_discrete_logcdf tests
Also reverts the reductions of test n_samples that had been introduced because of speed issues
1 parent 691be5e commit edb829e

File tree

1 file changed

+42
-34
lines changed

1 file changed

+42
-34
lines changed

pymc3/tests/test_distributions.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
import pymc3 as pm
3838

39-
from pymc3.aesaraf import change_rv_size, floatX, intX
39+
from pymc3.aesaraf import floatX, intX
4040
from pymc3.distributions import (
4141
AR1,
4242
CAR,
@@ -102,9 +102,10 @@
102102
continuous,
103103
logcdf,
104104
logp,
105+
logpt,
105106
logpt_sum,
106107
)
107-
from pymc3.math import kronecker, logsumexp
108+
from pymc3.math import kronecker
108109
from pymc3.model import Deterministic, Model, Point
109110
from pymc3.tests.helpers import select_by_precision
110111
from pymc3.vartypes import continuous_types
@@ -751,25 +752,33 @@ def check_logcdf(
751752
if not skip_paramdomain_inside_edge_test:
752753
domains = paramdomains.copy()
753754
domains["value"] = domain
755+
756+
model, param_vars = build_model(pymc3_dist, domain, paramdomains)
757+
pymc3_logcdf = model.fastfn(logpt(model["value"], cdf=True))
758+
754759
if decimal is None:
755760
decimal = select_by_precision(float64=6, float32=3)
756761

757762
for pt in product(domains, n_samples=n_samples):
758763
params = dict(pt)
759764
if skip_params_fn(params):
760765
continue
761-
scipy_cdf = scipy_logcdf(**params)
766+
767+
scipy_eval = scipy_logcdf(**params)
768+
762769
value = params.pop("value")
763-
with Model() as m:
764-
dist = pymc3_dist("y", **params)
770+
# Update shared parameter variables in pymc3_logcdf function
771+
for param_name, param_value in params.items():
772+
param_vars[param_name].set_value(param_value)
773+
pymc3_eval = pymc3_logcdf({"value": value})
774+
765775
params["value"] = value # for displaying in err_msg
766-
with aesara.config.change_flags(on_opt_error="raise", mode=Mode("py")):
767-
assert_almost_equal(
768-
logcdf(dist, value).eval(),
769-
scipy_cdf,
770-
decimal=decimal,
771-
err_msg=str(params),
772-
)
776+
assert_almost_equal(
777+
pymc3_eval,
778+
scipy_eval,
779+
decimal=decimal,
780+
err_msg=str(params),
781+
)
773782

774783
valid_value = domain.vals[0]
775784
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
@@ -849,24 +858,33 @@ def check_selfconsistency_discrete_logcdf(
849858
"""
850859
Check that logcdf of discrete distributions matches sum of logps up to value
851860
"""
861+
# This test only works for scalar random variables
862+
assert distribution.rv_op.ndim_supp == 0
863+
852864
domains = paramdomains.copy()
853865
domains["value"] = domain
854866
if decimal is None:
855867
decimal = select_by_precision(float64=6, float32=3)
868+
869+
model, param_vars = build_model(distribution, domain, paramdomains)
870+
dist_logcdf = model.fastfn(logpt(model["value"], cdf=True))
871+
dist_logp = model.fastfn(logpt(model["value"]))
872+
856873
for pt in product(domains, n_samples=n_samples):
857874
params = dict(pt)
858875
if skip_params_fn(params):
859876
continue
860877
value = params.pop("value")
861878
values = np.arange(domain.lower, value + 1)
862-
dist = distribution.dist(**params)
863-
# This only works for scalar random variables
864-
assert dist.owner.op.ndim_supp == 0
865-
values_dist = change_rv_size(dist, values.shape)
879+
880+
# Update shared parameter variables in logp/logcdf function
881+
for param_name, param_value in params.items():
882+
param_vars[param_name].set_value(param_value)
883+
866884
with aesara.config.change_flags(mode=Mode("py")):
867885
assert_almost_equal(
868-
logcdf(dist, value).eval(),
869-
logsumexp(logp(values_dist, values), keepdims=False).eval(),
886+
dist_logcdf({"value": value}),
887+
scipy.special.logsumexp([dist_logp({"value": value}) for value in values]),
870888
decimal=decimal,
871889
err_msg=str(pt),
872890
)
@@ -1140,13 +1158,17 @@ def test_beta(self):
11401158
{"alpha": Rplus, "beta": Rplus},
11411159
lambda value, alpha, beta: sp.beta.logpdf(value, alpha, beta),
11421160
)
1143-
self.check_logp(Beta, Unit, {"mu": Unit, "sigma": Rplus}, beta_mu_sigma)
1161+
self.check_logp(
1162+
Beta,
1163+
Unit,
1164+
{"mu": Unit, "sigma": Rplus},
1165+
beta_mu_sigma,
1166+
)
11441167
self.check_logcdf(
11451168
Beta,
11461169
Unit,
11471170
{"alpha": Rplus, "beta": Rplus},
11481171
lambda value, alpha, beta: sp.beta.logcdf(value, alpha, beta),
1149-
n_samples=10,
11501172
decimal=select_by_precision(float64=5, float32=3),
11511173
)
11521174

@@ -1269,20 +1291,17 @@ def scipy_mu_alpha_logcdf(value, mu, alpha):
12691291
Nat,
12701292
{"mu": Rplus, "alpha": Rplus},
12711293
scipy_mu_alpha_logcdf,
1272-
n_samples=5,
12731294
)
12741295
self.check_logcdf(
12751296
NegativeBinomial,
12761297
Nat,
12771298
{"p": Unit, "n": Rplus},
12781299
lambda value, p, n: sp.nbinom.logcdf(value, n, p),
1279-
n_samples=5,
12801300
)
12811301
self.check_selfconsistency_discrete_logcdf(
12821302
NegativeBinomial,
12831303
Nat,
12841304
{"mu": Rplus, "alpha": Rplus},
1285-
n_samples=10,
12861305
)
12871306

12881307
@pytest.mark.parametrize(
@@ -1340,7 +1359,6 @@ def test_lognormal(self):
13401359
Rplus,
13411360
{"mu": R, "sigma": Rplusbig},
13421361
lambda value, mu, sigma: floatX(sp.lognorm.logpdf(value, sigma, 0, np.exp(mu))),
1343-
n_samples=5, # Just testing alternative parametrization
13441362
)
13451363
self.check_logcdf(
13461364
Lognormal,
@@ -1353,7 +1371,6 @@ def test_lognormal(self):
13531371
Rplus,
13541372
{"mu": R, "sigma": Rplusbig},
13551373
lambda value, mu, sigma: sp.lognorm.logcdf(value, sigma, 0, np.exp(mu)),
1356-
n_samples=5, # Just testing alternative parametrization
13571374
)
13581375

13591376
def test_t(self):
@@ -1368,14 +1385,12 @@ def test_t(self):
13681385
R,
13691386
{"nu": Rplus, "mu": R, "sigma": Rplus},
13701387
lambda value, nu, mu, sigma: sp.t.logpdf(value, nu, mu, sigma),
1371-
n_samples=5, # Just testing alternative parametrization
13721388
)
13731389
self.check_logcdf(
13741390
StudentT,
13751391
R,
13761392
{"nu": Rplus, "mu": R, "lam": Rplus},
13771393
lambda value, nu, mu, lam: sp.t.logcdf(value, nu, mu, lam ** -0.5),
1378-
n_samples=10, # relies on slow incomplete beta
13791394
)
13801395
# TODO: reenable when PR #4736 is merged
13811396
"""
@@ -1384,7 +1399,6 @@ def test_t(self):
13841399
R,
13851400
{"nu": Rplus, "mu": R, "sigma": Rplus},
13861401
lambda value, nu, mu, sigma: sp.t.logcdf(value, nu, mu, sigma),
1387-
n_samples=5, # Just testing alternative parametrization
13881402
)
13891403
"""
13901404

@@ -1561,13 +1575,11 @@ def test_binomial(self):
15611575
Nat,
15621576
{"n": NatSmall, "p": Unit},
15631577
lambda value, n, p: sp.binom.logcdf(value, n, p),
1564-
n_samples=10,
15651578
)
15661579
self.check_selfconsistency_discrete_logcdf(
15671580
Binomial,
15681581
Nat,
15691582
{"n": NatSmall, "p": Unit},
1570-
n_samples=10,
15711583
)
15721584

15731585
@pytest.mark.xfail(reason="checkd tests has not been refactored")
@@ -1769,14 +1781,12 @@ def logcdf_fn(value, psi, mu, alpha):
17691781
Nat,
17701782
{"psi": Unit, "mu": Rplusbig, "alpha": Rplusbig},
17711783
logcdf_fn,
1772-
n_samples=10,
17731784
)
17741785

17751786
self.check_selfconsistency_discrete_logcdf(
17761787
ZeroInflatedNegativeBinomial,
17771788
Nat,
17781789
{"psi": Unit, "mu": Rplusbig, "alpha": Rplusbig},
1779-
n_samples=10,
17801790
)
17811791

17821792
@pytest.mark.xfail(reason="Test not refactored yet")
@@ -1809,14 +1819,12 @@ def logcdf_fn(value, psi, n, p):
18091819
Nat,
18101820
{"psi": Unit, "n": NatSmall, "p": Unit},
18111821
logcdf_fn,
1812-
n_samples=10,
18131822
)
18141823

18151824
self.check_selfconsistency_discrete_logcdf(
18161825
ZeroInflatedBinomial,
18171826
Nat,
18181827
{"n": NatSmall, "p": Unit, "psi": Unit},
1819-
n_samples=10,
18201828
)
18211829

18221830
@pytest.mark.parametrize("n", [1, 2, 3])

0 commit comments

Comments
 (0)