Skip to content

Commit 2f5543d

Browse files
ArmavicaricardoV94
authored andcommitted
Break up test_distributions.py
1 parent 23df308 commit 2f5543d

19 files changed

+3620
-3427
lines changed

.github/workflows/tests.yml

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,38 +38,43 @@ jobs:
3838
test-subset:
3939
- |
4040
pymc/tests/test_util.py
41-
pymc/tests/test_logprob.py
41+
pymc/tests/distributions/test_logprob.py
4242
pymc/tests/test_aesaraf.py
4343
pymc/tests/test_math.py
4444
pymc/tests/test_posdef_sym.py
4545
pymc/tests/test_ndarray_backend.py
4646
pymc/tests/test_hmc.py
4747
pymc/tests/test_func_utils.py
48-
pymc/tests/test_shape_handling.py
48+
pymc/tests/distributions/test_shape_utils.py
4949
pymc/tests/test_starting.py
50-
pymc/tests/test_mixture.py
50+
pymc/tests/distributions/test_mixture.py
5151
5252
- |
53-
pymc/tests/test_distributions.py
53+
pymc/tests/distributions/test_distribution.py
54+
pymc/tests/distributions/test_bound.py
55+
pymc/tests/distributions/test_censored.py
56+
pymc/tests/distributions/test_discrete.py
57+
pymc/tests/distributions/test_continuous.py
58+
pymc/tests/distributions/test_multivariate.py
5459
5560
- |
5661
pymc/tests/test_tuning.py
5762
pymc/tests/test_shared.py
5863
pymc/tests/test_types.py
59-
pymc/tests/test_distributions_moments.py
64+
pymc/tests/distributions/test_moments.py
6065
6166
- |
6267
pymc/tests/test_modelcontext.py
63-
pymc/tests/test_dist_math.py
68+
pymc/tests/distributions/test_dist_math.py
6469
pymc/tests/test_minibatches.py
6570
pymc/tests/test_pickling.py
66-
pymc/tests/test_transforms.py
71+
pymc/tests/distributions/test_transform.py
6772
pymc/tests/test_parallel_sampling.py
6873
pymc/tests/test_printing.py
6974
7075
- |
71-
pymc/tests/test_distributions_random.py
72-
pymc/tests/test_distributions_timeseries.py
76+
pymc/tests/distributions/test_random.py
77+
pymc/tests/distributions/test_timeseries.py
7378
pymc/tests/test_gp.py
7479
pymc/tests/test_model.py
7580
pymc/tests/test_model_graph.py
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import warnings
16+
17+
import numpy as np
18+
import pytest
19+
import scipy.stats as st
20+
21+
from aesara.tensor.random.op import RandomVariable
22+
23+
import pymc as pm
24+
25+
from pymc.distributions import joint_logp
26+
27+
28+
class TestBound:
29+
"""Tests for pm.Bound distribution"""
30+
31+
def test_continuous(self):
32+
with pm.Model() as model:
33+
dist = pm.Normal.dist(mu=0, sigma=1)
34+
with warnings.catch_warnings():
35+
warnings.filterwarnings(
36+
"ignore", "invalid value encountered in add", RuntimeWarning
37+
)
38+
UnboundedNormal = pm.Bound("unbound", dist, transform=None)
39+
InfBoundedNormal = pm.Bound(
40+
"infbound", dist, lower=-np.inf, upper=np.inf, transform=None
41+
)
42+
LowerNormal = pm.Bound("lower", dist, lower=0, transform=None)
43+
UpperNormal = pm.Bound("upper", dist, upper=0, transform=None)
44+
BoundedNormal = pm.Bound("bounded", dist, lower=1, upper=10, transform=None)
45+
LowerNormalTransform = pm.Bound("lowertrans", dist, lower=1)
46+
UpperNormalTransform = pm.Bound("uppertrans", dist, upper=10)
47+
BoundedNormalTransform = pm.Bound("boundedtrans", dist, lower=1, upper=10)
48+
49+
assert joint_logp(LowerNormal, -1).eval() == -np.inf
50+
assert joint_logp(UpperNormal, 1).eval() == -np.inf
51+
assert joint_logp(BoundedNormal, 0).eval() == -np.inf
52+
assert joint_logp(BoundedNormal, 11).eval() == -np.inf
53+
54+
assert joint_logp(UnboundedNormal, 0).eval() != -np.inf
55+
assert joint_logp(UnboundedNormal, 11).eval() != -np.inf
56+
assert joint_logp(InfBoundedNormal, 0).eval() != -np.inf
57+
assert joint_logp(InfBoundedNormal, 11).eval() != -np.inf
58+
59+
value = model.rvs_to_values[LowerNormalTransform]
60+
assert joint_logp(LowerNormalTransform, value).eval({value: -1}) != -np.inf
61+
value = model.rvs_to_values[UpperNormalTransform]
62+
assert joint_logp(UpperNormalTransform, value).eval({value: 1}) != -np.inf
63+
value = model.rvs_to_values[BoundedNormalTransform]
64+
assert joint_logp(BoundedNormalTransform, value).eval({value: 0}) != -np.inf
65+
assert joint_logp(BoundedNormalTransform, value).eval({value: 11}) != -np.inf
66+
67+
ref_dist = pm.Normal.dist(mu=0, sigma=1)
68+
assert np.allclose(joint_logp(UnboundedNormal, 5).eval(), joint_logp(ref_dist, 5).eval())
69+
assert np.allclose(joint_logp(LowerNormal, 5).eval(), joint_logp(ref_dist, 5).eval())
70+
assert np.allclose(joint_logp(UpperNormal, -5).eval(), joint_logp(ref_dist, 5).eval())
71+
assert np.allclose(joint_logp(BoundedNormal, 5).eval(), joint_logp(ref_dist, 5).eval())
72+
73+
def test_discrete(self):
74+
with pm.Model() as model:
75+
dist = pm.Poisson.dist(mu=4)
76+
with warnings.catch_warnings():
77+
warnings.filterwarnings(
78+
"ignore", "invalid value encountered in add", RuntimeWarning
79+
)
80+
UnboundedPoisson = pm.Bound("unbound", dist)
81+
LowerPoisson = pm.Bound("lower", dist, lower=1)
82+
UpperPoisson = pm.Bound("upper", dist, upper=10)
83+
BoundedPoisson = pm.Bound("bounded", dist, lower=1, upper=10)
84+
85+
assert joint_logp(LowerPoisson, 0).eval() == -np.inf
86+
assert joint_logp(UpperPoisson, 11).eval() == -np.inf
87+
assert joint_logp(BoundedPoisson, 0).eval() == -np.inf
88+
assert joint_logp(BoundedPoisson, 11).eval() == -np.inf
89+
90+
assert joint_logp(UnboundedPoisson, 0).eval() != -np.inf
91+
assert joint_logp(UnboundedPoisson, 11).eval() != -np.inf
92+
93+
ref_dist = pm.Poisson.dist(mu=4)
94+
assert np.allclose(joint_logp(UnboundedPoisson, 5).eval(), joint_logp(ref_dist, 5).eval())
95+
assert np.allclose(joint_logp(LowerPoisson, 5).eval(), joint_logp(ref_dist, 5).eval())
96+
assert np.allclose(joint_logp(UpperPoisson, 5).eval(), joint_logp(ref_dist, 5).eval())
97+
assert np.allclose(joint_logp(BoundedPoisson, 5).eval(), joint_logp(ref_dist, 5).eval())
98+
99+
def create_invalid_distribution(self):
100+
class MyNormal(RandomVariable):
101+
name = "my_normal"
102+
ndim_supp = 0
103+
ndims_params = [0, 0]
104+
dtype = "floatX"
105+
106+
my_normal = MyNormal()
107+
108+
class InvalidDistribution(pm.Distribution):
109+
rv_op = my_normal
110+
111+
@classmethod
112+
def dist(cls, mu=0, sigma=1, **kwargs):
113+
return super().dist([mu, sigma], **kwargs)
114+
115+
return InvalidDistribution
116+
117+
def test_arguments_checks(self):
118+
msg = "Observed Bound distributions are not supported"
119+
with pm.Model() as m:
120+
x = pm.Normal("x", 0, 1)
121+
with pytest.raises(ValueError, match=msg):
122+
pm.Bound("bound", x, observed=5)
123+
124+
msg = "Cannot transform discrete variable."
125+
with pm.Model() as m:
126+
x = pm.Poisson.dist(0.5)
127+
with warnings.catch_warnings():
128+
warnings.filterwarnings(
129+
"ignore", "invalid value encountered in add", RuntimeWarning
130+
)
131+
with pytest.raises(ValueError, match=msg):
132+
pm.Bound("bound", x, transform=pm.distributions.transforms.log)
133+
134+
msg = "Given dims do not exist in model coordinates."
135+
with pm.Model() as m:
136+
x = pm.Poisson.dist(0.5)
137+
with pytest.raises(ValueError, match=msg):
138+
pm.Bound("bound", x, dims="random_dims")
139+
140+
msg = "The dist x was already registered in the current model"
141+
with pm.Model() as m:
142+
x = pm.Normal("x", 0, 1)
143+
with pytest.raises(ValueError, match=msg):
144+
pm.Bound("bound", x)
145+
146+
msg = "Passing a distribution class to `Bound` is no longer supported"
147+
with pm.Model() as m:
148+
with pytest.raises(ValueError, match=msg):
149+
pm.Bound("bound", pm.Normal)
150+
151+
msg = "Bounding of MultiVariate RVs is not yet supported"
152+
with pm.Model() as m:
153+
x = pm.MvNormal.dist(np.zeros(3), np.eye(3))
154+
with pytest.raises(NotImplementedError, match=msg):
155+
pm.Bound("bound", x)
156+
157+
msg = "must be a Discrete or Continuous distribution subclass"
158+
with pm.Model() as m:
159+
x = self.create_invalid_distribution().dist()
160+
with pytest.raises(ValueError, match=msg):
161+
pm.Bound("bound", x)
162+
163+
def test_invalid_sampling(self):
164+
msg = "Cannot sample from a bounded variable"
165+
with pm.Model() as m:
166+
dist = pm.Normal.dist(mu=0, sigma=1)
167+
BoundedNormal = pm.Bound("bounded", dist, lower=1, upper=10)
168+
with pytest.raises(NotImplementedError, match=msg):
169+
pm.sample_prior_predictive()
170+
171+
def test_bound_shapes(self):
172+
with pm.Model(coords={"sample": np.ones((2, 5))}) as m:
173+
dist = pm.Normal.dist(mu=0, sigma=1)
174+
bound_sized = pm.Bound("boundedsized", dist, lower=1, upper=10, size=(4, 5))
175+
bound_shaped = pm.Bound("boundedshaped", dist, lower=1, upper=10, shape=(3, 5))
176+
bound_dims = pm.Bound("boundeddims", dist, lower=1, upper=10, dims="sample")
177+
178+
initial_point = m.initial_point()
179+
dist_size = initial_point["boundedsized_interval__"].shape
180+
dist_shape = initial_point["boundedshaped_interval__"].shape
181+
dist_dims = initial_point["boundeddims_interval__"].shape
182+
183+
assert dist_size == (4, 5)
184+
assert dist_shape == (3, 5)
185+
assert dist_dims == (2, 5)
186+
187+
def test_bound_dist(self):
188+
# Continuous
189+
bound = pm.Bound.dist(pm.Normal.dist(0, 1), lower=0)
190+
assert pm.logp(bound, -1).eval() == -np.inf
191+
assert np.isclose(pm.logp(bound, 1).eval(), st.norm(0, 1).logpdf(1))
192+
193+
# Discrete
194+
bound = pm.Bound.dist(pm.Poisson.dist(1), lower=2)
195+
assert pm.logp(bound, 1).eval() == -np.inf
196+
assert np.isclose(pm.logp(bound, 2).eval(), st.poisson(1).logpmf(2))
197+
198+
def test_array_bound(self):
199+
with pm.Model() as model:
200+
dist = pm.Normal.dist()
201+
with warnings.catch_warnings():
202+
warnings.filterwarnings(
203+
"ignore", "invalid value encountered in add", RuntimeWarning
204+
)
205+
LowerPoisson = pm.Bound("lower", dist, lower=[1, None], transform=None)
206+
UpperPoisson = pm.Bound("upper", dist, upper=[np.inf, 10], transform=None)
207+
BoundedPoisson = pm.Bound("bounded", dist, lower=[1, 2], upper=[9, 10], transform=None)
208+
209+
first, second = joint_logp(LowerPoisson, [0, 0], sum=False)[0].eval()
210+
assert first == -np.inf
211+
assert second != -np.inf
212+
213+
first, second = joint_logp(UpperPoisson, [11, 11], sum=False)[0].eval()
214+
assert first != -np.inf
215+
assert second == -np.inf
216+
217+
first, second = joint_logp(BoundedPoisson, [1, 1], sum=False)[0].eval()
218+
assert first != -np.inf
219+
assert second == -np.inf
220+
221+
first, second = joint_logp(BoundedPoisson, [10, 10], sum=False)[0].eval()
222+
assert first == -np.inf
223+
assert second != -np.inf
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pytest
17+
18+
import pymc as pm
19+
20+
21+
class TestCensored:
22+
@pytest.mark.parametrize("censored", (False, True))
23+
def test_censored_workflow(self, censored):
24+
# Based on pymc-examples/censored_data
25+
rng = np.random.default_rng(1234)
26+
size = 500
27+
true_mu = 13.0
28+
true_sigma = 5.0
29+
30+
# Set censoring limits
31+
low = 3.0
32+
high = 16.0
33+
34+
# Draw censored samples
35+
data = rng.normal(true_mu, true_sigma, size)
36+
data[data <= low] = low
37+
data[data >= high] = high
38+
39+
rng = 17092021
40+
with pm.Model() as m:
41+
mu = pm.Normal(
42+
"mu",
43+
mu=((high - low) / 2) + low,
44+
sigma=(high - low) / 2.0,
45+
initval="moment",
46+
)
47+
sigma = pm.HalfNormal("sigma", sigma=(high - low) / 2.0, initval="moment")
48+
observed = pm.Censored(
49+
"observed",
50+
pm.Normal.dist(mu=mu, sigma=sigma),
51+
lower=low if censored else None,
52+
upper=high if censored else None,
53+
observed=data,
54+
)
55+
56+
prior_pred = pm.sample_prior_predictive(random_seed=rng)
57+
posterior = pm.sample(tune=500, draws=500, random_seed=rng)
58+
posterior_pred = pm.sample_posterior_predictive(posterior, random_seed=rng)
59+
60+
expected = True if censored else False
61+
assert (9 < prior_pred.prior_predictive.mean() < 10) == expected
62+
assert (13 < posterior.posterior["mu"].mean() < 14) == expected
63+
assert (4.5 < posterior.posterior["sigma"].mean() < 5.5) == expected
64+
assert (12 < posterior_pred.posterior_predictive.mean() < 13) == expected
65+
66+
def test_censored_invalid_dist(self):
67+
with pm.Model():
68+
invalid_dist = pm.Normal
69+
with pytest.raises(
70+
ValueError,
71+
match=r"Censoring dist must be a distribution created via the",
72+
):
73+
x = pm.Censored("x", invalid_dist, lower=None, upper=None)
74+
75+
with pm.Model():
76+
mv_dist = pm.Dirichlet.dist(a=[1, 1, 1])
77+
with pytest.raises(
78+
NotImplementedError,
79+
match="Censoring of multivariate distributions has not been implemented yet",
80+
):
81+
x = pm.Censored("x", mv_dist, lower=None, upper=None)
82+
83+
with pm.Model():
84+
registered_dist = pm.Normal("dist")
85+
with pytest.raises(
86+
ValueError,
87+
match="The dist dist was already registered in the current model",
88+
):
89+
x = pm.Censored("x", registered_dist, lower=None, upper=None)
90+
91+
def test_change_size(self):
92+
base_dist = pm.Censored.dist(pm.Normal.dist(), -1, 1, size=(3, 2))
93+
94+
new_dist = pm.Censored.change_size(base_dist, (4,))
95+
assert new_dist.eval().shape == (4,)
96+
97+
new_dist = pm.Censored.change_size(base_dist, (4,), expand=True)
98+
assert new_dist.eval().shape == (4, 3, 2)
99+
100+
def test_dist_broadcasted_by_lower_upper(self):
101+
x = pm.Censored.dist(pm.Normal.dist(), lower=np.zeros((2,)), upper=None)
102+
assert tuple(x.owner.inputs[0].shape.eval()) == (2,)
103+
104+
x = pm.Censored.dist(pm.Normal.dist(), lower=np.zeros((2,)), upper=np.zeros((4, 2)))
105+
assert tuple(x.owner.inputs[0].shape.eval()) == (4, 2)
106+
107+
x = pm.Censored.dist(
108+
pm.Normal.dist(size=(3, 4, 2)), lower=np.zeros((2,)), upper=np.zeros((4, 2))
109+
)
110+
assert tuple(x.owner.inputs[0].shape.eval()) == (3, 4, 2)

0 commit comments

Comments
 (0)