Skip to content

Commit 07a248a

Browse files
ferrinetwiecki
authored andcommitted
shape problem when sampling
1 parent 140a80c commit 07a248a

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,20 +96,22 @@ def __init__(self, *args, **kwargs):
9696
super(BaseTestCases.BaseTestCase, self).__init__(*args, **kwargs)
9797
self.model = pm.Model()
9898

99-
def get_random_variable(self, shape, with_vector_params=False):
99+
def get_random_variable(self, shape, with_vector_params=False, name=None):
100100
if with_vector_params:
101101
params = {key: value * np.ones(self.shape, dtype=np.dtype(type(value))) for
102102
key, value in self.params.items()}
103103
else:
104104
params = self.params
105-
name = self.distribution.__name__
105+
if name is None:
106+
name = self.distribution.__name__
106107
with self.model:
107108
if shape is None:
108109
return self.distribution(name, transform=None, **params)
109110
else:
110111
return self.distribution(name, shape=shape, transform=None, **params)
111112

112-
def sample_random_variable(self, random_variable, size):
113+
@staticmethod
114+
def sample_random_variable(random_variable, size):
113115
try:
114116
return random_variable.random(size=size)
115117
except AttributeError:
@@ -145,7 +147,7 @@ def test_parameters_1d_shape(self):
145147
else:
146148
expected = np.atleast_1d(size).tolist()
147149
expected.append(self.shape)
148-
actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape
150+
actual = self.sample_random_variable(rv, size).shape
149151
self.assertSequenceEqual(expected, actual)
150152

151153
def test_broadcast_shape(self):
@@ -160,6 +162,26 @@ def test_broadcast_shape(self):
160162
actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape
161163
self.assertSequenceEqual(expected, actual)
162164

165+
def test_different_shapes_and_sample_sizes(self):
166+
shapes = [(), (1,), (1, 1), (1, 2), (10, 10, 1), (10, 10, 2)]
167+
prefix = self.distribution.__name__
168+
expected = []
169+
actual = []
170+
for shape in shapes:
171+
rv = self.get_random_variable(shape, name='%s_%s' % (prefix, shape))
172+
for size in (None, 1, 5, (4, 5)):
173+
if size is None:
174+
s = []
175+
else:
176+
try:
177+
s = list(size)
178+
except TypeError:
179+
s = [size]
180+
s.extend(shape)
181+
expected.append(tuple(s))
182+
actual.append(self.sample_random_variable(rv, size).shape)
183+
self.assertSequenceEqual(expected, actual)
184+
163185

164186
class TestNormal(BaseTestCases.BaseTestCase):
165187
distribution = pm.Normal

0 commit comments

Comments
 (0)