Skip to content

Commit c80657e

Browse files
committed
Group GaussianRandomWalk tests in single class
1 parent 54425cd commit c80657e

File tree

1 file changed

+75
-77
lines changed

1 file changed

+75
-77
lines changed

pymc/tests/test_distributions_timeseries.py

Lines changed: 75 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -32,102 +32,100 @@
3232
from pymc.tests.test_distributions_random import BaseTestDistributionRandom
3333

3434

35-
class TestGaussianRandomWalkRandom(BaseTestDistributionRandom):
36-
# Override default size for test class
37-
size = None
38-
39-
pymc_dist = pm.GaussianRandomWalk
40-
pymc_dist_params = {"mu": 1.0, "sigma": 2, "init": pm.Constant.dist(0), "steps": 4}
41-
expected_rv_op_params = {"mu": 1.0, "sigma": 2, "init": pm.Constant.dist(0), "steps": 4}
42-
43-
checks_to_run = [
44-
"check_pymc_params_match_rv_op",
45-
"check_rv_inferred_size",
46-
]
47-
48-
def check_rv_inferred_size(self):
49-
steps = self.pymc_dist_params["steps"]
50-
sizes_to_check = [None, (), 1, (1,)]
51-
sizes_expected = [(steps + 1,), (steps + 1,), (1, steps + 1), (1, steps + 1)]
52-
53-
for size, expected in zip(sizes_to_check, sizes_expected):
54-
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
55-
expected_symbolic = tuple(pymc_rv.shape.eval())
56-
assert expected_symbolic == expected
57-
58-
def test_steps_scalar_check(self):
59-
with pytest.raises(ValueError, match="steps must be an integer scalar"):
60-
self.pymc_dist.dist(steps=[1])
61-
62-
63-
def test_gaussianrandomwalk_inference():
64-
mu, sigma, steps = 2, 1, 1000
65-
obs = np.concatenate([[0], np.random.normal(mu, sigma, size=steps)]).cumsum()
35+
class TestGaussianRandomWalk:
36+
class TestGaussianRandomWalkRandom(BaseTestDistributionRandom):
37+
# Override default size for test class
38+
size = None
39+
40+
pymc_dist = pm.GaussianRandomWalk
41+
pymc_dist_params = {"mu": 1.0, "sigma": 2, "init": pm.Constant.dist(0), "steps": 4}
42+
expected_rv_op_params = {"mu": 1.0, "sigma": 2, "init": pm.Constant.dist(0), "steps": 4}
43+
44+
checks_to_run = [
45+
"check_pymc_params_match_rv_op",
46+
"check_rv_inferred_size",
47+
]
6648

67-
with pm.Model():
68-
_mu = pm.Uniform("mu", -10, 10)
69-
_sigma = pm.Uniform("sigma", 0, 10)
49+
def check_rv_inferred_size(self):
50+
steps = self.pymc_dist_params["steps"]
51+
sizes_to_check = [None, (), 1, (1,)]
52+
sizes_expected = [(steps + 1,), (steps + 1,), (1, steps + 1), (1, steps + 1)]
7053

71-
obs_data = pm.MutableData("obs_data", obs)
72-
grw = GaussianRandomWalk("grw", _mu, _sigma, steps=steps, observed=obs_data)
54+
for size, expected in zip(sizes_to_check, sizes_expected):
55+
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
56+
expected_symbolic = tuple(pymc_rv.shape.eval())
57+
assert expected_symbolic == expected
7358

74-
trace = pm.sample(chains=1)
59+
def test_steps_scalar_check(self):
60+
with pytest.raises(ValueError, match="steps must be an integer scalar"):
61+
self.pymc_dist.dist(steps=[1])
7562

76-
recovered_mu = trace.posterior["mu"].mean()
77-
recovered_sigma = trace.posterior["sigma"].mean()
78-
np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2)
63+
def test_gaussianrandomwalk_inference(self):
64+
mu, sigma, steps = 2, 1, 1000
65+
obs = np.concatenate([[0], np.random.normal(mu, sigma, size=steps)]).cumsum()
7966

67+
with pm.Model():
68+
_mu = pm.Uniform("mu", -10, 10)
69+
_sigma = pm.Uniform("sigma", 0, 10)
8070

81-
@pytest.mark.parametrize("init", [None, pm.Normal.dist()])
82-
def test_gaussian_random_walk_init_dist_shape(init):
83-
"""Test that init_dist is properly resized"""
84-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init)
85-
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
71+
obs_data = pm.MutableData("obs_data", obs)
72+
grw = GaussianRandomWalk("grw", _mu, _sigma, steps=steps, observed=obs_data)
8673

87-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, size=(5,))
88-
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
74+
trace = pm.sample(chains=1)
8975

90-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, shape=1)
91-
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
76+
recovered_mu = trace.posterior["mu"].mean()
77+
recovered_sigma = trace.posterior["sigma"].mean()
78+
np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2)
9279

93-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, shape=(5, 1))
94-
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
80+
@pytest.mark.parametrize("init", [None, pm.Normal.dist()])
81+
def test_gaussian_random_walk_init_dist_shape(self, init):
82+
"""Test that init_dist is properly resized"""
83+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init)
84+
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
9585

96-
grw = pm.GaussianRandomWalk.dist(mu=[0, 0], sigma=1, steps=1, init=init)
97-
assert tuple(grw.owner.inputs[-2].shape.eval()) == (2,)
86+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, size=(5,))
87+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
9888

99-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=[1, 1], steps=1, init=init)
100-
assert tuple(grw.owner.inputs[-2].shape.eval()) == (2,)
89+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, shape=1)
90+
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
10191

102-
grw = pm.GaussianRandomWalk.dist(mu=np.zeros((3, 1)), sigma=[1, 1], steps=1, init=init)
103-
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3, 2)
92+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, shape=(5, 1))
93+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
10494

95+
grw = pm.GaussianRandomWalk.dist(mu=[0, 0], sigma=1, steps=1, init=init)
96+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (2,)
10597

106-
def test_shape_ellipsis():
107-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=5, init=pm.Normal.dist(), shape=(3, ...))
108-
assert tuple(grw.shape.eval()) == (3, 6)
109-
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3,)
98+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=[1, 1], steps=1, init=init)
99+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (2,)
110100

101+
grw = pm.GaussianRandomWalk.dist(mu=np.zeros((3, 1)), sigma=[1, 1], steps=1, init=init)
102+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3, 2)
111103

112-
def test_gaussianrandomwalk_broadcasted_by_init_dist():
113-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=4, init=pm.Normal.dist(size=(2, 3)))
114-
assert tuple(grw.shape.eval()) == (2, 3, 5)
115-
assert grw.eval().shape == (2, 3, 5)
104+
def test_shape_ellipsis(self):
105+
grw = pm.GaussianRandomWalk.dist(
106+
mu=0, sigma=1, steps=5, init=pm.Normal.dist(), shape=(3, ...)
107+
)
108+
assert tuple(grw.shape.eval()) == (3, 6)
109+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3,)
116110

111+
def test_gaussianrandomwalk_broadcasted_by_init_dist(self):
112+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=4, init=pm.Normal.dist(size=(2, 3)))
113+
assert tuple(grw.shape.eval()) == (2, 3, 5)
114+
assert grw.eval().shape == (2, 3, 5)
117115

118-
@pytest.mark.parametrize(
119-
"init",
120-
[
121-
pm.HalfNormal.dist(sigma=2),
122-
pm.StudentT.dist(nu=4, mu=1, sigma=0.5),
123-
],
124-
)
125-
def test_gaussian_random_walk_init_dist_logp(init):
126-
grw = pm.GaussianRandomWalk.dist(init=init, steps=1)
127-
assert np.isclose(
128-
pm.logp(grw, [0, 0]).eval(),
129-
pm.logp(init, 0).eval() + scipy.stats.norm.logpdf(0),
116+
@pytest.mark.parametrize(
117+
"init",
118+
[
119+
pm.HalfNormal.dist(sigma=2),
120+
pm.StudentT.dist(nu=4, mu=1, sigma=0.5),
121+
],
130122
)
123+
def test_gaussian_random_walk_init_dist_logp(self, init):
124+
grw = pm.GaussianRandomWalk.dist(init=init, steps=1)
125+
assert np.isclose(
126+
pm.logp(grw, [0, 0]).eval(),
127+
pm.logp(init, 0).eval() + scipy.stats.norm.logpdf(0),
128+
)
131129

132130

133131
@pytest.mark.xfail(reason="Timeseries not refactored")

0 commit comments

Comments
 (0)