Skip to content

Commit 97587bf

Browse files
committed
Refactor Domain helper class
* Allow edges to be infinity in discrete domains * Automatically assign infinity edges when these are set to (None, None) * Simplex and MultiSimplex now return a Domain instance
1 parent 5971bd0 commit 97587bf

File tree

1 file changed

+32
-35
lines changed

1 file changed

+32
-35
lines changed

pymc/tests/test_distributions.py

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,20 @@ def get_lkj_cases():
152152

153153

154154
class Domain:
155-
def __init__(self, vals, dtype=None, edges=None, shape=None):
156-
avals = array(vals, dtype=dtype)
157-
if dtype is None and not str(avals.dtype).startswith("int"):
158-
avals = avals.astype(aesara.config.floatX)
159-
vals = [array(v, dtype=avals.dtype) for v in vals]
155+
def __init__(self, vals, dtype=aesara.config.floatX, edges=None, shape=None):
156+
# Infinity values must be kept as floats
157+
vals = [array(v, dtype=dtype) if np.all(np.isfinite(v)) else floatX(v) for v in vals]
160158

161159
if edges is None:
162160
edges = array(vals[0]), array(vals[-1])
163161
vals = vals[1:-1]
162+
else:
163+
edges = list(edges)
164+
if edges[0] is None:
165+
edges[0] = np.full_like(vals[0], -np.inf)
166+
if edges[1] is None:
167+
edges[1] = np.full_like(vals[0], np.inf)
168+
edges = tuple(edges)
164169

165170
if not vals:
166171
raise ValueError(
@@ -170,13 +175,12 @@ def __init__(self, vals, dtype=None, edges=None, shape=None):
170175
)
171176

172177
if shape is None:
173-
shape = avals[0].shape
178+
shape = vals[0].shape
174179

175180
self.vals = vals
176181
self.shape = shape
177-
178182
self.lower, self.upper = edges
179-
self.dtype = avals.dtype
183+
self.dtype = dtype
180184

181185
def __add__(self, other):
182186
return Domain(
@@ -251,17 +255,17 @@ def product(domains, n_samples=-1):
251255

252256
Circ = Domain([-np.pi, -2.1, -1, -0.01, 0.0, 0.01, 1, 2.1, np.pi])
253257

254-
Runif = Domain([-1, -0.4, 0, 0.4, 1])
255-
Rdunif = Domain([-10, 0, 10.0])
258+
Runif = Domain([-np.inf, -0.4, 0, 0.4, np.inf])
259+
Rdunif = Domain([-np.inf, -1, 0, 1, np.inf], "int64")
256260
Rplusunif = Domain([0, 0.5, inf])
257-
Rplusdunif = Domain([2, 10, 100], "int64")
261+
Rplusdunif = Domain([0, 10, np.inf], "int64")
258262

259-
I = Domain([-1000, -3, -2, -1, 0, 1, 2, 3, 1000], "int64")
263+
I = Domain([-np.inf, -3, -2, -1, 0, 1, 2, 3, np.inf], "int64")
260264

261-
NatSmall = Domain([0, 3, 4, 5, 1000], "int64")
262-
Nat = Domain([0, 1, 2, 3, 2000], "int64")
263-
NatBig = Domain([0, 1, 2, 3, 5000, 50000], "int64")
264-
PosNat = Domain([1, 2, 3, 2000], "int64")
265+
NatSmall = Domain([0, 3, 4, 5, np.inf], "int64")
266+
Nat = Domain([0, 1, 2, 3, np.inf], "int64")
267+
NatBig = Domain([0, 1, 2, 3, 5000, np.inf], "int64")
268+
PosNat = Domain([1, 2, 3, np.inf], "int64")
265269

266270
Bool = Domain([0, 0, 1, 1], "int64")
267271

@@ -523,20 +527,16 @@ def orderedprobit_logpdf(value, eta, cutpoints):
523527
return np.where(np.all(ps >= 0), np.log(p), -np.inf)
524528

525529

526-
class Simplex:
527-
def __init__(self, n):
528-
self.vals = list(simplex_values(n))
529-
self.shape = (n,)
530-
self.dtype = Unit.dtype
530+
def Simplex(n):
531+
return Domain(simplex_values(n), shape=(n,), dtype=Unit.dtype, edges=(None, None))
531532

532533

533-
class MultiSimplex:
534-
def __init__(self, n_dependent, n_independent):
535-
self.vals = []
536-
for simplex_value in itertools.product(simplex_values(n_dependent), repeat=n_independent):
537-
self.vals.append(np.vstack(simplex_value))
538-
self.shape = (n_independent, n_dependent)
539-
self.dtype = Unit.dtype
534+
def MultiSimplex(n_dependent, n_independent):
535+
vals = []
536+
for simplex_value in itertools.product(simplex_values(n_dependent), repeat=n_independent):
537+
vals.append(np.vstack(simplex_value))
538+
539+
return Domain(vals, dtype=Unit.dtype, shape=(n_independent, n_dependent))
540540

541541

542542
def PdMatrix(n):
@@ -811,18 +811,15 @@ def check_logcdf(
811811
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
812812
valid_dist = pymc_dist.dist(**valid_params)
813813

814-
# Natural domains do not have inf as the upper edge, but should also be ignored
815-
nat_domains = (NatSmall, Nat, NatBig, PosNat)
816-
817-
# Test pymc distribution gives -inf for parameters outside the
818-
# supported domain edges (excluding edgse)
814+
# Test pymc distribution raises ParameterValueError for parameters outside the
815+
# supported domain edges (excluding edges)
819816
if not skip_paramdomain_outside_edge_test:
820817
# Step1: collect potential invalid parameters
821818
invalid_params = {param: [None, None] for param in paramdomains}
822819
for param, paramdomain in paramdomains.items():
823820
if np.isfinite(paramdomain.lower):
824821
invalid_params[param][0] = paramdomain.lower - 1
825-
if np.isfinite(paramdomain.upper) and paramdomain not in nat_domains:
822+
if np.isfinite(paramdomain.upper):
826823
invalid_params[param][1] = paramdomain.upper + 1
827824
# Step2: test invalid parameters, one a time
828825
for invalid_param, invalid_edges in invalid_params.items():
@@ -851,7 +848,7 @@ def check_logcdf(
851848
)
852849

853850
# Test that values above domain edge evaluate to 0
854-
if domain not in nat_domains and np.isfinite(domain.upper):
851+
if np.isfinite(domain.upper):
855852
above_domain = domain.upper + 1
856853
with aesara.config.change_flags(mode=Mode("py")):
857854
assert_equal(

0 commit comments

Comments
 (0)