Skip to content

Commit 1d2f592

Browse files
ricardoV94twiecki
authored andcommitted
Raise ValueError when Domain has no values
1 parent 598dd9d commit 1d2f592

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

pymc/tests/test_distributions.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def polyagamma_cdf(*args, **kwargs):
3535
raise RuntimeError("polyagamma package is not installed!")
3636

3737

38+
from contextlib import ExitStack as does_not_raise
39+
3840
import pytest
3941
import scipy.stats
4042
import scipy.stats.distributions as sp
@@ -155,6 +157,14 @@ def __init__(self, vals, dtype=None, edges=None, shape=None):
155157
if edges is None:
156158
edges = array(vals[0]), array(vals[-1])
157159
vals = vals[1:-1]
160+
161+
if not vals:
162+
raise ValueError(
163+
f"Domain has no values left after removing edges: {edges}.\n"
164+
"You can duplicate the edge values or explicitly specify the edges with the edge keyword.\n"
165+
f"For example: `Domain([{edges[0]}, {edges[0]}, {edges[1]}, {edges[1]}])`"
166+
)
167+
158168
if shape is None:
159169
shape = avals[0].shape
160170

@@ -192,6 +202,22 @@ def __neg__(self):
192202
return Domain([-v for v in self.vals], self.dtype, (-self.lower, -self.upper), self.shape)
193203

194204

205+
@pytest.mark.parametrize(
206+
"values, edges, expectation",
207+
[
208+
([], None, pytest.raises(IndexError)),
209+
([], (0, 0), pytest.raises(ValueError)),
210+
([0], None, pytest.raises(ValueError)),
211+
([0], (0, 0), does_not_raise()),
212+
([-1, 1], None, pytest.raises(ValueError)),
213+
([-1, 0, 1], None, does_not_raise()),
214+
],
215+
)
216+
def test_domain(values, edges, expectation):
217+
with expectation:
218+
Domain(values, edges=edges)
219+
220+
195221
def product(domains, n_samples=-1):
196222
"""Get an iterator over a product of domains.
197223
@@ -2423,7 +2449,7 @@ def test_categorical_valid_p(self):
24232449
def test_categorical(self, n):
24242450
self.check_logp(
24252451
Categorical,
2426-
Domain(range(n), "int64"),
2452+
Domain(range(n), dtype="int64", edges=(None, None)),
24272453
{"p": Simplex(n)},
24282454
lambda value, p: categorical_logpdf(value, p),
24292455
)
@@ -2432,7 +2458,7 @@ def test_categorical(self, n):
24322458
def test_orderedlogistic(self, n):
24332459
self.check_logp(
24342460
OrderedLogistic,
2435-
Domain(range(n), "int64"),
2461+
Domain(range(n), dtype="int64", edges=(None, None)),
24362462
{"eta": R, "cutpoints": Vector(R, n - 1)},
24372463
lambda value, eta, cutpoints: orderedlogistic_logpdf(value, eta, cutpoints),
24382464
)
@@ -2441,7 +2467,7 @@ def test_orderedlogistic(self, n):
24412467
def test_orderedprobit(self, n):
24422468
self.check_logp(
24432469
OrderedProbit,
2444-
Domain(range(n), "int64"),
2470+
Domain(range(n), dtype="int64", edges=(None, None)),
24452471
{"eta": Runif, "cutpoints": UnitSortedVector(n - 1)},
24462472
lambda value, eta, cutpoints: orderedprobit_logpdf(value, eta, cutpoints),
24472473
)

pymc/tests/test_distributions_random.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,16 @@ def pymc_random(
6464
dist,
6565
paramdomains,
6666
ref_rand,
67-
valuedomain=Domain([0]),
67+
valuedomain=None,
6868
size=10000,
6969
alpha=0.05,
7070
fails=10,
7171
extra_args=None,
7272
model_args=None,
7373
):
74+
if valuedomain is None:
75+
valuedomain = Domain([0], edges=(None, None))
76+
7477
if model_args is None:
7578
model_args = {}
7679

@@ -104,12 +107,15 @@ def pymc_random(
104107
def pymc_random_discrete(
105108
dist,
106109
paramdomains,
107-
valuedomain=Domain([0]),
110+
valuedomain=None,
108111
ref_rand=None,
109112
size=100000,
110113
alpha=0.05,
111114
fails=20,
112115
):
116+
if valuedomain is None:
117+
valuedomain = Domain([0], edges=(None, None))
118+
113119
model, param_vars = build_model(dist, valuedomain, paramdomains)
114120
model_dist = change_rv_size(model.named_vars["value"], size, expand=True)
115121
pymc_rand = aesara.function([], model_dist)

0 commit comments

Comments
 (0)