Skip to content

Commit fe67c98

Browse files
committed
Speedup check_selfconsistency_discrete_logcdf test
1 parent 6159d72 commit fe67c98

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

pymc3/tests/test_distributions.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@
103103
logpt,
104104
logpt_sum,
105105
)
106-
from pymc3.math import kronecker, logsumexp
106+
from pymc3.math import kronecker
107107
from pymc3.model import Deterministic, Model, Point
108108
from pymc3.tests.helpers import select_by_precision
109109
from pymc3.vartypes import continuous_types
@@ -858,24 +858,33 @@ def check_selfconsistency_discrete_logcdf(
858858
"""
859859
Check that logcdf of discrete distributions matches sum of logps up to value
860860
"""
861+
# This test only works for scalar random variables
862+
assert distribution.rv_op.ndim_supp == 0
863+
861864
domains = paramdomains.copy()
862865
domains["value"] = domain
863866
if decimal is None:
864867
decimal = select_by_precision(float64=6, float32=3)
868+
869+
model, param_vars = build_model(distribution, domain, paramdomains)
870+
dist_logcdf = model.fastfn(logpt(model["value"], cdf=True))
871+
dist_logp = model.fastfn(logpt(model["value"]))
872+
865873
for pt in product(domains, n_samples=n_samples):
866874
params = dict(pt)
867875
if skip_params_fn(params):
868876
continue
869877
value = params.pop("value")
870878
values = np.arange(domain.lower, value + 1)
871-
dist = distribution.dist(**params)
872-
# This only works for scalar random variables
873-
assert dist.owner.op.ndim_supp == 0
874-
values_dist = change_rv_size(dist, values.shape)
879+
880+
# Update shared parameter variables in logp/logcdf function
881+
for param_name, param_value in params.items():
882+
param_vars[param_name].set_value(param_value)
883+
875884
with aesara.config.change_flags(mode=Mode("py")):
876885
assert_almost_equal(
877-
logcdf(dist, value).eval(),
878-
logsumexp(logpt(values_dist, values), keepdims=False).eval(),
886+
dist_logcdf({"value": value}),
887+
scipy.special.logsumexp([dist_logp({"value": value}) for value in values]),
879888
decimal=decimal,
880889
err_msg=str(pt),
881890
)

0 commit comments

Comments
 (0)