@@ -152,15 +152,20 @@ def get_lkj_cases():
152
152
153
153
154
154
class Domain :
155
- def __init__ (self , vals , dtype = None , edges = None , shape = None ):
156
- avals = array (vals , dtype = dtype )
157
- if dtype is None and not str (avals .dtype ).startswith ("int" ):
158
- avals = avals .astype (aesara .config .floatX )
159
- vals = [array (v , dtype = avals .dtype ) for v in vals ]
155
+ def __init__ (self , vals , dtype = aesara .config .floatX , edges = None , shape = None ):
156
+ # Infinity values must be kept as floats
157
+ vals = [array (v , dtype = dtype ) if np .all (np .isfinite (v )) else floatX (v ) for v in vals ]
160
158
161
159
if edges is None :
162
160
edges = array (vals [0 ]), array (vals [- 1 ])
163
161
vals = vals [1 :- 1 ]
162
+ else :
163
+ edges = list (edges )
164
+ if edges [0 ] is None :
165
+ edges [0 ] = np .full_like (vals [0 ], - np .inf )
166
+ if edges [1 ] is None :
167
+ edges [1 ] = np .full_like (vals [0 ], np .inf )
168
+ edges = tuple (edges )
164
169
165
170
if not vals :
166
171
raise ValueError (
@@ -170,13 +175,12 @@ def __init__(self, vals, dtype=None, edges=None, shape=None):
170
175
)
171
176
172
177
if shape is None :
173
- shape = avals [0 ].shape
178
+ shape = vals [0 ].shape
174
179
175
180
self .vals = vals
176
181
self .shape = shape
177
-
178
182
self .lower , self .upper = edges
179
- self .dtype = avals . dtype
183
+ self .dtype = dtype
180
184
181
185
def __add__ (self , other ):
182
186
return Domain (
@@ -251,17 +255,17 @@ def product(domains, n_samples=-1):
251
255
252
256
Circ = Domain ([- np .pi , - 2.1 , - 1 , - 0.01 , 0.0 , 0.01 , 1 , 2.1 , np .pi ])
253
257
254
- Runif = Domain ([- 1 , - 0.4 , 0 , 0.4 , 1 ])
255
- Rdunif = Domain ([- 10 , 0 , 10.0 ] )
258
+ Runif = Domain ([- np . inf , - 0.4 , 0 , 0.4 , np . inf ])
259
+ Rdunif = Domain ([- np . inf , - 1 , 0 , 1 , np . inf ], "int64" )
256
260
Rplusunif = Domain ([0 , 0.5 , inf ])
257
- Rplusdunif = Domain ([2 , 10 , 100 ], "int64" )
261
+ Rplusdunif = Domain ([0 , 10 , np . inf ], "int64" )
258
262
259
- I = Domain ([- 1000 , - 3 , - 2 , - 1 , 0 , 1 , 2 , 3 , 1000 ], "int64" )
263
+ I = Domain ([- np . inf , - 3 , - 2 , - 1 , 0 , 1 , 2 , 3 , np . inf ], "int64" )
260
264
261
- NatSmall = Domain ([0 , 3 , 4 , 5 , 1000 ], "int64" )
262
- Nat = Domain ([0 , 1 , 2 , 3 , 2000 ], "int64" )
263
- NatBig = Domain ([0 , 1 , 2 , 3 , 5000 , 50000 ], "int64" )
264
- PosNat = Domain ([1 , 2 , 3 , 2000 ], "int64" )
265
+ NatSmall = Domain ([0 , 3 , 4 , 5 , np . inf ], "int64" )
266
+ Nat = Domain ([0 , 1 , 2 , 3 , np . inf ], "int64" )
267
+ NatBig = Domain ([0 , 1 , 2 , 3 , 5000 , np . inf ], "int64" )
268
+ PosNat = Domain ([1 , 2 , 3 , np . inf ], "int64" )
265
269
266
270
Bool = Domain ([0 , 0 , 1 , 1 ], "int64" )
267
271
@@ -523,20 +527,16 @@ def orderedprobit_logpdf(value, eta, cutpoints):
523
527
return np .where (np .all (ps >= 0 ), np .log (p ), - np .inf )
524
528
525
529
526
- class Simplex :
527
- def __init__ (self , n ):
528
- self .vals = list (simplex_values (n ))
529
- self .shape = (n ,)
530
- self .dtype = Unit .dtype
530
+ def Simplex (n ):
531
+ return Domain (simplex_values (n ), shape = (n ,), dtype = Unit .dtype , edges = (None , None ))
531
532
532
533
533
- class MultiSimplex :
534
- def __init__ (self , n_dependent , n_independent ):
535
- self .vals = []
536
- for simplex_value in itertools .product (simplex_values (n_dependent ), repeat = n_independent ):
537
- self .vals .append (np .vstack (simplex_value ))
538
- self .shape = (n_independent , n_dependent )
539
- self .dtype = Unit .dtype
534
+ def MultiSimplex (n_dependent , n_independent ):
535
+ vals = []
536
+ for simplex_value in itertools .product (simplex_values (n_dependent ), repeat = n_independent ):
537
+ vals .append (np .vstack (simplex_value ))
538
+
539
+ return Domain (vals , dtype = Unit .dtype , shape = (n_independent , n_dependent ))
540
540
541
541
542
542
def PdMatrix (n ):
@@ -811,18 +811,15 @@ def check_logcdf(
811
811
valid_params = {param : paramdomain .vals [0 ] for param , paramdomain in paramdomains .items ()}
812
812
valid_dist = pymc_dist .dist (** valid_params )
813
813
814
- # Natural domains do not have inf as the upper edge, but should also be ignored
815
- nat_domains = (NatSmall , Nat , NatBig , PosNat )
816
-
817
- # Test pymc distribution gives -inf for parameters outside the
818
- # supported domain edges (excluding edgse)
814
+ # Test pymc distribution raises ParameterValueError for parameters outside the
815
+ # supported domain edges (excluding edges)
819
816
if not skip_paramdomain_outside_edge_test :
820
817
# Step1: collect potential invalid parameters
821
818
invalid_params = {param : [None , None ] for param in paramdomains }
822
819
for param , paramdomain in paramdomains .items ():
823
820
if np .isfinite (paramdomain .lower ):
824
821
invalid_params [param ][0 ] = paramdomain .lower - 1
825
- if np .isfinite (paramdomain .upper ) and paramdomain not in nat_domains :
822
+ if np .isfinite (paramdomain .upper ):
826
823
invalid_params [param ][1 ] = paramdomain .upper + 1
827
824
# Step2: test invalid parameters, one a time
828
825
for invalid_param , invalid_edges in invalid_params .items ():
@@ -851,7 +848,7 @@ def check_logcdf(
851
848
)
852
849
853
850
# Test that values above domain edge evaluate to 0
854
- if domain not in nat_domains and np .isfinite (domain .upper ):
851
+ if np .isfinite (domain .upper ):
855
852
above_domain = domain .upper + 1
856
853
with aesara .config .change_flags (mode = Mode ("py" )):
857
854
assert_equal (
0 commit comments