Skip to content

Commit eace9ed

Browse files
committed
Fix RandomWalk dist checks
1 parent 85cfc99 commit eace9ed

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

pymc/distributions/timeseries.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,12 @@ def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> at.TensorVari
102102

103103
if not (
104104
isinstance(innovation_dist, at.TensorVariable)
105-
and init_dist.owner is not None
106-
and isinstance(init_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
107-
# TODO: Lift univariate constraint on inovvation_dist
108-
and init_dist.owner.op.ndim_supp == 0
105+
and innovation_dist.owner is not None
106+
and isinstance(innovation_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
107+
and innovation_dist.owner.op.ndim_supp == 0
109108
):
110-
raise TypeError("init_dist must be a univariate distribution variable")
111-
check_dist_not_registered(init_dist)
109+
raise TypeError("innovation_dist must be a univariate distribution variable")
110+
check_dist_not_registered(innovation_dist)
112111

113112
return super().dist([init_dist, innovation_dist, steps], **kwargs)
114113

pymc/tests/distributions/test_timeseries.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,13 @@
2727
from pymc.distributions.logprob import logp
2828
from pymc.distributions.multivariate import Dirichlet
2929
from pymc.distributions.shape_utils import change_dist_size, to_tuple
30-
from pymc.distributions.timeseries import AR, GARCH11, EulerMaruyama, GaussianRandomWalk
30+
from pymc.distributions.timeseries import (
31+
AR,
32+
GARCH11,
33+
EulerMaruyama,
34+
GaussianRandomWalk,
35+
RandomWalk,
36+
)
3137
from pymc.model import Model
3238
from pymc.sampling import draw, sample, sample_posterior_predictive
3339
from pymc.tests.distributions.util import (
@@ -40,6 +46,43 @@
4046
from pymc.tests.helpers import SeededTest, select_by_precision
4147

4248

49+
class TestRandomWalk:
50+
def test_dists_types(self):
51+
init_dist = Normal.dist()
52+
innovation_dist = Normal.dist()
53+
54+
with pytest.raises(
55+
TypeError,
56+
match="init_dist must be a univariate distribution variable",
57+
):
58+
RandomWalk.dist(init_dist=5, innovation_dist=innovation_dist, steps=5)
59+
60+
with pytest.raises(
61+
TypeError,
62+
match="innovation_dist must be a univariate distribution variable",
63+
):
64+
RandomWalk.dist(init_dist=init_dist, innovation_dist=5, steps=5)
65+
66+
def test_dists_not_registered_check(self):
67+
with Model():
68+
init = Normal("init")
69+
innovation = Normal("innovation")
70+
71+
init_dist = Normal.dist()
72+
innovation_dist = Normal.dist()
73+
with pytest.raises(
74+
ValueError,
75+
match="The dist init was already registered in the current model",
76+
):
77+
RandomWalk("rw", init_dist=init, innovation_dist=innovation_dist, steps=5)
78+
79+
with pytest.raises(
80+
ValueError,
81+
match="The dist innovation was already registered in the current model",
82+
):
83+
RandomWalk("rw", init_dist=init_dist, innovation_dist=innovation, steps=5)
84+
85+
4386
class TestGaussianRandomWalk:
4487
def test_logp(self):
4588
def ref_logp(value, mu, sigma):

0 commit comments

Comments
 (0)