Skip to content

Commit 6159d72

Browse files
committed
Speedup check_logcdf test
1 parent ad9b919 commit 6159d72

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

pymc3/tests/test_distributions.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,10 @@ def check_logcdf(
750750
if not skip_paramdomain_inside_edge_test:
751751
domains = paramdomains.copy()
752752
domains["value"] = domain
753+
754+
model, param_vars = build_model(pymc3_dist, domain, paramdomains)
755+
pymc3_logcdf = model.fastfn(logpt(model["value"], cdf=True))
756+
753757
if decimal is None:
754758
decimal = select_by_precision(float64=6, float32=3)
755759

@@ -758,17 +762,23 @@ def check_logcdf(
758762
if skip_params_fn(params):
759763
continue
760764
scipy_cdf = scipy_logcdf(**params)
765+
766+
scipy_eval = scipy_logcdf(**params)
761767
value = params.pop("value")
762-
with Model() as m:
763-
dist = pymc3_dist("y", **params)
768+
769+
# Update shared parameter variables in pymc3_logcdf function
770+
for param_name, param_value in params.items():
771+
param_vars[param_name].set_value(param_value)
772+
773+
pymc3_eval = pymc3_logcdf({"value": value})
774+
764775
params["value"] = value # for displaying in err_msg
765-
with aesara.config.change_flags(on_opt_error="raise", mode=Mode("py")):
766-
assert_almost_equal(
767-
logcdf(dist, value).eval(),
768-
scipy_cdf,
769-
decimal=decimal,
770-
err_msg=str(params),
771-
)
776+
assert_almost_equal(
777+
pymc3_eval,
778+
scipy_eval,
779+
decimal=decimal,
780+
err_msg=str(params),
781+
)
772782

773783
valid_value = domain.vals[0]
774784
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}

0 commit comments

Comments
 (0)