Skip to content

Commit a3ab0f1

Browse files
committed
Add scalar parameter and value bound checks in TestMatchesScipy.check_logp
1 parent 97587bf commit a3ab0f1

File tree

1 file changed

+93
-32
lines changed

1 file changed

+93
-32
lines changed

pymc/tests/test_distributions.py

Lines changed: 93 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -634,6 +634,7 @@ def check_logp(
634634
n_samples=100,
635635
extra_args=None,
636636
scipy_args=None,
637+
skip_paramdomain_outside_edge_test=False,
637638
):
638639
"""
639640
Generic test for PyMC logp methods
@@ -679,14 +680,39 @@ def logp_reference(args):
679680
args.update(scipy_args)
680681
return scipy_logp(**args)
681682

683+
def _model_input_dict(model, param_vars, pt):
684+
"""Create a dict with only the necessary, transformed logp inputs."""
685+
pt_d = {}
686+
for k, v in pt.items():
687+
rv_var = model.named_vars.get(k)
688+
nv = param_vars.get(k, rv_var)
689+
nv = getattr(nv.tag, "value_var", nv)
690+
691+
transform = getattr(nv.tag, "transform", None)
692+
if transform:
693+
# todo: the compiled graph behind this should be cached and
694+
# reused (if it isn't already).
695+
v = transform.forward(rv_var, v).eval()
696+
697+
if nv.name in param_vars:
698+
# update the shared parameter variables in `param_vars`
699+
param_vars[nv.name].set_value(v)
700+
else:
701+
# create an argument entry for the (potentially
702+
# transformed) "value" variable
703+
pt_d[nv.name] = v
704+
705+
return pt_d
706+
682707
model, param_vars = build_model(pymc_dist, domain, paramdomains, extra_args)
683708
logp_pymc = model.fastlogp_nojac
684709

710+
# Test that the supported value and parameter domains match scipy
685711
domains = paramdomains.copy()
686712
domains["value"] = domain
687713
for pt in product(domains, n_samples=n_samples):
688714
pt = dict(pt)
689-
pt_d = self._model_input_dict(model, param_vars, pt)
715+
pt_d = _model_input_dict(model, param_vars, pt)
690716
pt_logp = Point(pt_d, model=model)
691717
pt_ref = Point(pt, filter_model_vars=False, model=model)
692718
assert_almost_equal(
@@ -696,29 +722,58 @@ def logp_reference(args):
696722
err_msg=str(pt),
697723
)
698724

699-
def _model_input_dict(self, model, param_vars, pt):
700-
"""Create a dict with only the necessary, transformed logp inputs."""
701-
pt_d = {}
702-
for k, v in pt.items():
703-
rv_var = model.named_vars.get(k)
704-
nv = param_vars.get(k, rv_var)
705-
nv = getattr(nv.tag, "value_var", nv)
706-
707-
transform = getattr(nv.tag, "transform", None)
708-
if transform:
709-
# todo: the compiled graph behind this should be cached and
710-
# reused (if it isn't already).
711-
v = transform.forward(rv_var, v).eval()
712-
713-
if nv.name in param_vars:
714-
# update the shared parameter variables in `param_vars`
715-
param_vars[nv.name].set_value(v)
716-
else:
717-
# create an argument entry for the (potentially
718-
# transformed) "value" variable
719-
pt_d[nv.name] = v
725+
valid_value = domain.vals[0]
726+
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
727+
valid_dist = pymc_dist.dist(**valid_params, **extra_args)
720728

721-
return pt_d
729+
# Test that the pymc distribution raises ParameterValueError for scalar parameters outside
730+
# the supported domain edges (excluding edges)
731+
if not skip_paramdomain_outside_edge_test:
732+
# Step1: collect potentially invalid parameters
733+
invalid_params = {param: [None, None] for param in paramdomains}
734+
for param, paramdomain in paramdomains.items():
735+
if np.ndim(paramdomain.lower) != 0:
736+
continue
737+
if np.isfinite(paramdomain.lower):
738+
invalid_params[param][0] = paramdomain.lower - 1
739+
if np.isfinite(paramdomain.upper):
740+
invalid_params[param][1] = paramdomain.upper + 1
741+
742+
# Step2: test invalid parameters, one at a time
743+
for invalid_param, invalid_edges in invalid_params.items():
744+
for invalid_edge in invalid_edges:
745+
if invalid_edge is None:
746+
continue
747+
test_params = valid_params.copy() # Shallow copy should be okay
748+
test_params[invalid_param] = at.as_tensor_variable(invalid_edge)
749+
# We need to remove `Assert`s introduced by checks like
750+
# `assert_negative_support` and disable test values;
751+
# otherwise, we won't be able to create the `RandomVariable`
752+
with aesara.config.change_flags(compute_test_value="off"):
753+
invalid_dist = pymc_dist.dist(**test_params, **extra_args)
754+
with aesara.config.change_flags(mode=Mode("py")):
755+
with pytest.raises(ParameterValueError):
756+
logp(invalid_dist, valid_value).eval()
757+
pytest.fail(f"test_params={test_params}, valid_value={valid_value}")
758+
759+
# Test that values outside of scalar domain support evaluate to -np.inf
760+
if np.ndim(domain.lower) != 0:
761+
return
762+
invalid_values = [None, None]
763+
if np.isfinite(domain.lower):
764+
invalid_values[0] = domain.lower - 1
765+
if np.isfinite(domain.upper):
766+
invalid_values[1] = domain.upper + 1
767+
768+
for invalid_value in invalid_values:
769+
if invalid_value is None:
770+
continue
771+
with aesara.config.change_flags(mode=Mode("py")):
772+
assert_equal(
773+
logp(valid_dist, invalid_value).eval(),
774+
-np.inf,
775+
err_msg=str(invalid_value),
776+
)
722777

723778
def check_logcdf(
724779
self,
@@ -826,7 +881,7 @@ def check_logcdf(
826881
for invalid_edge in invalid_edges:
827882
if invalid_edge is not None:
828883
test_params = valid_params.copy() # Shallow copy should be okay
829-
test_params[invalid_param] = invalid_edge
884+
test_params[invalid_param] = at.as_tensor_variable(invalid_edge)
830885
# We need to remove `Assert`s introduced by checks like
831886
# `assert_negative_support` and disable test values;
832887
# otherwise, we won't be able to create the
@@ -934,6 +989,7 @@ def test_uniform(self):
934989
Runif,
935990
{"lower": -Rplusunif, "upper": Rplusunif},
936991
lambda value, lower, upper: sp.uniform.logpdf(value, lower, upper - lower),
992+
skip_paramdomain_outside_edge_test=True,
937993
)
938994
self.check_logcdf(
939995
Uniform,
@@ -954,6 +1010,7 @@ def test_triangular(self):
9541010
Runif,
9551011
{"lower": -Rplusunif, "c": Runif, "upper": Rplusunif},
9561012
lambda value, c, lower, upper: sp.triang.logpdf(value, c - lower, lower, upper - lower),
1013+
skip_paramdomain_outside_edge_test=True,
9571014
)
9581015
self.check_logcdf(
9591016
Triangular,
@@ -1007,6 +1064,7 @@ def test_discrete_unif(self):
10071064
Rdunif,
10081065
{"lower": -Rplusdunif, "upper": Rplusdunif},
10091066
lambda value, lower, upper: sp.randint.logpmf(value, lower, upper + 1),
1067+
skip_paramdomain_outside_edge_test=True,
10101068
)
10111069
self.check_logcdf(
10121070
DiscreteUniform,
@@ -1017,7 +1075,7 @@ def test_discrete_unif(self):
10171075
)
10181076
self.check_selfconsistency_discrete_logcdf(
10191077
DiscreteUniform,
1020-
Rdunif,
1078+
Domain([-10, 0, 10], "int64"),
10211079
{"lower": -Rplusdunif, "upper": Rplusdunif},
10221080
)
10231081
# Custom logp / logcdf check for invalid parameters
@@ -1029,7 +1087,7 @@ def test_discrete_unif(self):
10291087
logcdf(invalid_dist, 2).eval()
10301088

10311089
def test_flat(self):
1032-
self.check_logp(Flat, Runif, {}, lambda value: 0)
1090+
self.check_logp(Flat, R, {}, lambda value: 0)
10331091
with Model():
10341092
x = Flat("a")
10351093
self.check_logcdf(Flat, R, {}, lambda value: np.log(0.5))
@@ -1074,6 +1132,7 @@ def scipy_logp(value, mu, sigma, lower, upper):
10741132
{"mu": R, "sigma": Rplusbig, "lower": -Rplusbig, "upper": Rplusbig},
10751133
scipy_logp,
10761134
decimal=select_by_precision(float64=6, float32=1),
1135+
skip_paramdomain_outside_edge_test=True,
10771136
)
10781137

10791138
self.check_logp(
@@ -1082,6 +1141,7 @@ def scipy_logp(value, mu, sigma, lower, upper):
10821141
{"mu": R, "sigma": Rplusbig, "upper": Rplusbig},
10831142
functools.partial(scipy_logp, lower=-np.inf),
10841143
decimal=select_by_precision(float64=6, float32=1),
1144+
skip_paramdomain_outside_edge_test=True,
10851145
)
10861146

10871147
self.check_logp(
@@ -1090,6 +1150,7 @@ def scipy_logp(value, mu, sigma, lower, upper):
10901150
{"mu": R, "sigma": Rplusbig, "lower": -Rplusbig},
10911151
functools.partial(scipy_logp, upper=np.inf),
10921152
decimal=select_by_precision(float64=6, float32=1),
1153+
skip_paramdomain_outside_edge_test=True,
10931154
)
10941155

10951156
def test_half_normal(self):
@@ -1679,13 +1740,13 @@ def test_discrete_weibull(self):
16791740
self.check_logp(
16801741
DiscreteWeibull,
16811742
Nat,
1682-
{"q": Unit, "beta": Rplusdunif},
1743+
{"q": Unit, "beta": NatSmall},
16831744
discrete_weibull_logpmf,
16841745
)
16851746
self.check_selfconsistency_discrete_logcdf(
16861747
DiscreteWeibull,
16871748
Nat,
1688-
{"q": Unit, "beta": Rplusdunif},
1749+
{"q": Unit, "beta": NatSmall},
16891750
)
16901751

16911752
def test_poisson(self):
@@ -2088,7 +2149,7 @@ def test_wishart(self, n):
20882149
self.check_logp(
20892150
Wishart,
20902151
PdMatrix(n),
2091-
{"nu": Domain([3, 4, 2000]), "V": PdMatrix(n)},
2152+
{"nu": Domain([0, 3, 4, np.inf], "int64"), "V": PdMatrix(n)},
20922153
lambda value, nu, V: scipy.stats.wishart.logpdf(value, np.int(nu), V),
20932154
)
20942155

@@ -2255,7 +2316,7 @@ def test_categorical_valid_p(self, p):
22552316
def test_categorical(self, n):
22562317
self.check_logp(
22572318
Categorical,
2258-
Domain(range(n), dtype="int64", edges=(None, None)),
2319+
Domain(range(n), dtype="int64", edges=(0, n)),
22592320
{"p": Simplex(n)},
22602321
lambda value, p: categorical_logpdf(value, p),
22612322
)
@@ -2368,8 +2429,8 @@ def test_ex_gaussian_cdf_outside_edges(self):
23682429
def test_vonmises(self):
23692430
self.check_logp(
23702431
VonMises,
2371-
R,
2372-
{"mu": Circ, "kappa": Rplus},
2432+
Circ,
2433+
{"mu": R, "kappa": Rplus},
23732434
lambda value, mu, kappa: floatX(sp.vonmises.logpdf(value, kappa, loc=mu)),
23742435
)
23752436

0 commit comments

Comments
 (0)