@@ -35,6 +35,8 @@ def polyagamma_cdf(*args, **kwargs):
35
35
raise RuntimeError ("polyagamma package is not installed!" )
36
36
37
37
38
+ from contextlib import ExitStack as does_not_raise
39
+
38
40
import pytest
39
41
import scipy .stats
40
42
import scipy .stats .distributions as sp
@@ -155,6 +157,14 @@ def __init__(self, vals, dtype=None, edges=None, shape=None):
155
157
if edges is None :
156
158
edges = array (vals [0 ]), array (vals [- 1 ])
157
159
vals = vals [1 :- 1 ]
160
+
161
+ if not vals :
162
+ raise ValueError (
163
+ f"Domain has no values left after removing edges: { edges } .\n "
164
+ "You can duplicate the edge values or explicitly specify the edges with the edge keyword.\n "
165
+ f"For example: `Domain([{ edges [0 ]} , { edges [0 ]} , { edges [1 ]} , { edges [1 ]} ])`"
166
+ )
167
+
158
168
if shape is None :
159
169
shape = avals [0 ].shape
160
170
@@ -192,6 +202,22 @@ def __neg__(self):
192
202
return Domain ([- v for v in self .vals ], self .dtype , (- self .lower , - self .upper ), self .shape )
193
203
194
204
205
+ @pytest .mark .parametrize (
206
+ "values, edges, expectation" ,
207
+ [
208
+ ([], None , pytest .raises (IndexError )),
209
+ ([], (0 , 0 ), pytest .raises (ValueError )),
210
+ ([0 ], None , pytest .raises (ValueError )),
211
+ ([0 ], (0 , 0 ), does_not_raise ()),
212
+ ([- 1 , 1 ], None , pytest .raises (ValueError )),
213
+ ([- 1 , 0 , 1 ], None , does_not_raise ()),
214
+ ],
215
+ )
216
+ def test_domain (values , edges , expectation ):
217
+ with expectation :
218
+ Domain (values , edges = edges )
219
+
220
+
195
221
def product (domains , n_samples = - 1 ):
196
222
"""Get an iterator over a product of domains.
197
223
@@ -2423,7 +2449,7 @@ def test_categorical_valid_p(self):
2423
2449
def test_categorical (self , n ):
2424
2450
self .check_logp (
2425
2451
Categorical ,
2426
- Domain (range (n ), "int64" ),
2452
+ Domain (range (n ), dtype = "int64" , edges = ( None , None ) ),
2427
2453
{"p" : Simplex (n )},
2428
2454
lambda value , p : categorical_logpdf (value , p ),
2429
2455
)
@@ -2432,7 +2458,7 @@ def test_categorical(self, n):
2432
2458
def test_orderedlogistic (self , n ):
2433
2459
self .check_logp (
2434
2460
OrderedLogistic ,
2435
- Domain (range (n ), "int64" ),
2461
+ Domain (range (n ), dtype = "int64" , edges = ( None , None ) ),
2436
2462
{"eta" : R , "cutpoints" : Vector (R , n - 1 )},
2437
2463
lambda value , eta , cutpoints : orderedlogistic_logpdf (value , eta , cutpoints ),
2438
2464
)
@@ -2441,7 +2467,7 @@ def test_orderedlogistic(self, n):
2441
2467
def test_orderedprobit (self , n ):
2442
2468
self .check_logp (
2443
2469
OrderedProbit ,
2444
- Domain (range (n ), "int64" ),
2470
+ Domain (range (n ), dtype = "int64" , edges = ( None , None ) ),
2445
2471
{"eta" : Runif , "cutpoints" : UnitSortedVector (n - 1 )},
2446
2472
lambda value , eta , cutpoints : orderedprobit_logpdf (value , eta , cutpoints ),
2447
2473
)
0 commit comments