Skip to content

Commit e8a9bb2

Browse files
committed
Test RandomWalk change size and fix bug
1 parent eace9ed commit e8a9bb2

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

pymc/distributions/timeseries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def rv_op(cls, init_dist, innovation_dist, steps, size=None):
118118

119119
# If not explicit, size is determined by the shapes of the input distributions
120120
if size is None:
121-
size = at.broadcast_shape(init_dist, innovation_dist[..., 0])
121+
size = at.broadcast_shape(init_dist, at.atleast_1d(innovation_dist)[..., 0])
122122
innovation_size = tuple(size) + (steps,)
123123

124124
# Resize input distributions

pymc/tests/distributions/test_timeseries.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,19 @@ def test_dists_not_registered_check(self):
8282
):
8383
RandomWalk("rw", init_dist=init_dist, innovation_dist=innovation, steps=5)
8484

85+
def test_change_size(self):
86+
init_dist = Normal.dist()
87+
innovation_dist = Normal.dist()
88+
89+
# size = 5
90+
rw = RandomWalk.dist(init_dist=init_dist, innovation_dist=innovation_dist, shape=(5, 100))
91+
92+
new_rw = change_dist_size(rw, new_size=(7,))
93+
assert tuple(new_rw.shape.eval()) == (7, 100)
94+
95+
new_rw = change_dist_size(rw, new_size=(4, 3), expand=True)
96+
assert tuple(new_rw.shape.eval()) == (4, 3, 5, 100)
97+
8598

8699
class TestGaussianRandomWalk:
87100
def test_logp(self):

0 commit comments

Comments
 (0)