Skip to content

Commit 86fd042

Browse files
committed
Add workaround for HyperGeometric logcdf failure
1 parent c5206f4 commit 86fd042

File tree

1 file changed

+19
-21
lines changed

1 file changed

+19
-21
lines changed

pymc3/tests/test_distributions.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,6 @@ def check_logp(
594594
n_samples=100,
595595
extra_args=None,
596596
scipy_args=None,
597-
skip_params_fn=lambda x: False,
598597
):
599598
"""
600599
Generic test for PyMC3 logp methods
@@ -626,9 +625,6 @@ def check_logp(
626625
the pymc3 distribution logp is calculated
627626
scipy_args : Dictionary with extra arguments needed to call scipy logp method
628627
Usually the same as extra_args
629-
skip_params_fn: Callable
630-
A function that takes a ``dict`` of the test points and returns a
631-
boolean indicating whether or not to perform the test.
632628
"""
633629
if decimal is None:
634630
decimal = select_by_precision(float64=6, float32=3)
@@ -650,8 +646,6 @@ def logp_reference(args):
650646
domains["value"] = domain
651647
for pt in product(domains, n_samples=n_samples):
652648
pt = dict(pt)
653-
if skip_params_fn(pt):
654-
continue
655649
pt_d = self._model_input_dict(model, param_vars, pt)
656650
pt_logp = Point(pt_d, model=model)
657651
pt_ref = Point(pt, filter_model_vars=False, model=model)
@@ -696,7 +690,7 @@ def check_logcdf(
696690
n_samples=100,
697691
skip_paramdomain_inside_edge_test=False,
698692
skip_paramdomain_outside_edge_test=False,
699-
skip_params_fn=lambda x: False,
693+
skip_nan=False,
700694
):
701695
"""
702696
Generic test for PyMC3 logcdf methods
@@ -737,9 +731,8 @@ def check_logcdf(
737731
skip_paramdomain_outside_edge_test : Bool
738732
Whether to run test 2., which checks that pymc3 distribution logcdf
739733
returns -inf for invalid parameter values outside the supported domain edge
740-
skip_params_fn: Callable
741-
A function that takes a ``dict`` of the test points and returns a
742-
boolean indicating whether or not to perform the test.
734+
skip_nan: Bool
735+
Whether to skip comparison when pymc3 logcdf method evaluates to nan
743736
744737
Returns
745738
-------
@@ -759,9 +752,6 @@ def check_logcdf(
759752

760753
for pt in product(domains, n_samples=n_samples):
761754
params = dict(pt)
762-
if skip_params_fn(params):
763-
continue
764-
765755
scipy_eval = scipy_logcdf(**params)
766756

767757
value = params.pop("value")
@@ -770,6 +760,9 @@ def check_logcdf(
770760
param_vars[param_name].set_value(param_value)
771761
pymc3_eval = pymc3_logcdf({"value": value})
772762

763+
if skip_nan and np.isnan(pymc3_eval):
764+
continue
765+
773766
params["value"] = value # for displaying in err_msg
774767
assert_almost_equal(
775768
pymc3_eval,
@@ -851,7 +844,7 @@ def check_selfconsistency_discrete_logcdf(
851844
paramdomains,
852845
decimal=None,
853846
n_samples=100,
854-
skip_params_fn=lambda x: False,
847+
skip_nan=False,
855848
):
856849
"""
857850
Check that logcdf of discrete distributions matches sum of logps up to value
@@ -870,19 +863,25 @@ def check_selfconsistency_discrete_logcdf(
870863

871864
for pt in product(domains, n_samples=n_samples):
872865
params = dict(pt)
873-
if skip_params_fn(params):
874-
continue
875866
value = params.pop("value")
876867
values = np.arange(domain.lower, value + 1)
877868

878869
# Update shared parameter variables in logp/logcdf function
879870
for param_name, param_value in params.items():
880871
param_vars[param_name].set_value(param_value)
881872

873+
logcdf_eval = dist_logcdf({"value": value})
874+
if skip_nan and np.isnan(logcdf_eval):
875+
continue
876+
877+
logp_logsumexp_eval = scipy.special.logsumexp(
878+
[dist_logp({"value": value}) for value in values]
879+
)
880+
882881
with aesara.config.change_flags(mode=Mode("py")):
883882
assert_almost_equal(
884-
dist_logcdf({"value": value}),
885-
scipy.special.logsumexp([dist_logp({"value": value}) for value in values]),
883+
logcdf_eval,
884+
logp_logsumexp_eval,
886885
decimal=decimal,
887886
err_msg=str(pt),
888887
)
@@ -1233,20 +1232,19 @@ def modified_scipy_hypergeom_logcdf(value, N, k, n):
12331232
Nat,
12341233
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
12351234
modified_scipy_hypergeom_logpmf,
1236-
skip_params_fn=lambda x: x["N"] < x["n"] or x["N"] < x["k"],
12371235
)
12381236
self.check_logcdf(
12391237
HyperGeometric,
12401238
Nat,
12411239
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
12421240
modified_scipy_hypergeom_logcdf,
1243-
skip_params_fn=lambda x: x["N"] < x["n"] or x["N"] < x["k"],
1241+
skip_nan=True, # TODO: Remove once aesara/issues/461 is solved
12441242
)
12451243
self.check_selfconsistency_discrete_logcdf(
12461244
HyperGeometric,
12471245
Nat,
12481246
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
1249-
skip_params_fn=lambda x: x["N"] < x["n"] or x["N"] < x["k"],
1247+
skip_nan=True, # TODO: Remove once aesara/issues/461 is solved
12501248
)
12511249

12521250
def test_negative_binomial(self):

0 commit comments

Comments
 (0)