@@ -96,20 +96,22 @@ def __init__(self, *args, **kwargs):
96
96
super (BaseTestCases .BaseTestCase , self ).__init__ (* args , ** kwargs )
97
97
self .model = pm .Model ()
98
98
99
- def get_random_variable (self , shape , with_vector_params = False ):
99
+ def get_random_variable (self , shape , with_vector_params = False , name = None ):
100
100
if with_vector_params :
101
101
params = {key : value * np .ones (self .shape , dtype = np .dtype (type (value ))) for
102
102
key , value in self .params .items ()}
103
103
else :
104
104
params = self .params
105
- name = self .distribution .__name__
105
+ if name is None :
106
+ name = self .distribution .__name__
106
107
with self .model :
107
108
if shape is None :
108
109
return self .distribution (name , transform = None , ** params )
109
110
else :
110
111
return self .distribution (name , shape = shape , transform = None , ** params )
111
112
112
- def sample_random_variable (self , random_variable , size ):
113
+ @staticmethod
114
+ def sample_random_variable (random_variable , size ):
113
115
try :
114
116
return random_variable .random (size = size )
115
117
except AttributeError :
@@ -145,7 +147,7 @@ def test_parameters_1d_shape(self):
145
147
else :
146
148
expected = np .atleast_1d (size ).tolist ()
147
149
expected .append (self .shape )
148
- actual = np . atleast_1d ( self .sample_random_variable (rv , size ) ).shape
150
+ actual = self .sample_random_variable (rv , size ).shape
149
151
self .assertSequenceEqual (expected , actual )
150
152
151
153
def test_broadcast_shape (self ):
@@ -160,6 +162,26 @@ def test_broadcast_shape(self):
160
162
actual = np .atleast_1d (self .sample_random_variable (rv , size )).shape
161
163
self .assertSequenceEqual (expected , actual )
162
164
165
+ def test_different_shapes_and_sample_sizes (self ):
166
+ shapes = [(), (1 ,), (1 , 1 ), (1 , 2 ), (10 , 10 , 1 ), (10 , 10 , 2 )]
167
+ prefix = self .distribution .__name__
168
+ expected = []
169
+ actual = []
170
+ for shape in shapes :
171
+ rv = self .get_random_variable (shape , name = '%s_%s' % (prefix , shape ))
172
+ for size in (None , 1 , 5 , (4 , 5 )):
173
+ if size is None :
174
+ s = []
175
+ else :
176
+ try :
177
+ s = list (size )
178
+ except TypeError :
179
+ s = [size ]
180
+ s .extend (shape )
181
+ expected .append (tuple (s ))
182
+ actual .append (self .sample_random_variable (rv , size ).shape )
183
+ self .assertSequenceEqual (expected , actual )
184
+
163
185
164
186
class TestNormal (BaseTestCases .BaseTestCase ):
165
187
distribution = pm .Normal
0 commit comments