47
47
from pymc .exceptions import ShapeError
48
48
from pymc .vartypes import int_types
49
49
50
- FLOATX = str (aesara .config .floatX )
51
- INTX = str (_conversion_map [FLOATX ])
52
-
53
50
54
51
def test_change_rv_size ():
55
52
loc = at .as_tensor_variable ([1 , 2 ])
@@ -176,57 +173,59 @@ def setup_class(self):
176
173
self .output_buffer = dict ()
177
174
self .func_buffer = dict ()
178
175
179
- def _input_tensors (self , shape ):
176
+ def _input_tensors (self , shape , floatX ):
177
+ intX = str (_conversion_map [floatX ])
180
178
ndim = len (shape )
181
- arr = TensorType (FLOATX , [False ] * ndim )("arr" )
182
- indices = TensorType (INTX , [False ] * ndim )("indices" )
183
- arr .tag .test_value = np .zeros (shape , dtype = FLOATX )
184
- indices .tag .test_value = np .zeros (shape , dtype = INTX )
179
+ arr = TensorType (floatX , [False ] * ndim )("arr" )
180
+ indices = TensorType (intX , [False ] * ndim )("indices" )
181
+ arr .tag .test_value = np .zeros (shape , dtype = floatX )
182
+ indices .tag .test_value = np .zeros (shape , dtype = intX )
185
183
return arr , indices
186
184
187
- def get_input_tensors (self , shape ):
185
+ def get_input_tensors (self , shape , floatX ):
188
186
ndim = len (shape )
189
187
try :
190
- return self .inputs_buffer [ndim ]
188
+ return self .inputs_buffer [( ndim , floatX ) ]
191
189
except KeyError :
192
- arr , indices = self ._input_tensors (shape )
193
- self .inputs_buffer [ndim ] = arr , indices
190
+ arr , indices = self ._input_tensors (shape , floatX )
191
+ self .inputs_buffer [( ndim , floatX ) ] = arr , indices
194
192
return arr , indices
195
193
196
194
def _output_tensor (self , arr , indices , axis ):
197
195
return take_along_axis (arr , indices , axis )
198
196
199
- def get_output_tensors (self , shape , axis ):
197
+ def get_output_tensors (self , shape , axis , floatX ):
200
198
ndim = len (shape )
201
199
try :
202
- return self .output_buffer [(ndim , axis )]
200
+ return self .output_buffer [(ndim , axis , floatX )]
203
201
except KeyError :
204
- arr , indices = self .get_input_tensors (shape )
202
+ arr , indices = self .get_input_tensors (shape , floatX )
205
203
out = self ._output_tensor (arr , indices , axis )
206
- self .output_buffer [(ndim , axis )] = out
204
+ self .output_buffer [(ndim , axis , floatX )] = out
207
205
return out
208
206
209
207
def _function (self , arr , indices , out ):
210
208
return aesara .function ([arr , indices ], [out ])
211
209
212
- def get_function (self , shape , axis ):
210
+ def get_function (self , shape , axis , floatX ):
213
211
ndim = len (shape )
214
212
try :
215
- return self .func_buffer [(ndim , axis )]
213
+ return self .func_buffer [(ndim , axis , floatX )]
216
214
except KeyError :
217
- arr , indices = self .get_input_tensors (shape )
218
- out = self .get_output_tensors (shape , axis )
215
+ arr , indices = self .get_input_tensors (shape , floatX )
216
+ out = self .get_output_tensors (shape , axis , floatX )
219
217
func = self ._function (arr , indices , out )
220
- self .func_buffer [(ndim , axis )] = func
218
+ self .func_buffer [(ndim , axis , floatX )] = func
221
219
return func
222
220
223
221
@staticmethod
224
- def get_input_values (shape , axis , samples ):
225
- arr = np .random .randn (* shape ).astype (FLOATX )
222
+ def get_input_values (shape , axis , samples , floatX ):
223
+ intX = str (_conversion_map [floatX ])
224
+ arr = np .random .randn (* shape ).astype (floatX )
226
225
size = list (shape )
227
226
size [axis ] = samples
228
227
size = tuple (size )
229
- indices = np .random .randint (low = 0 , high = shape [axis ], size = size , dtype = INTX )
228
+ indices = np .random .randint (low = 0 , high = shape [axis ], size = size , dtype = intX )
230
229
return arr , indices
231
230
232
231
@pytest .mark .parametrize (
@@ -250,10 +249,12 @@ def get_input_values(shape, axis, samples):
250
249
),
251
250
ids = str ,
252
251
)
253
- def test_take_along_axis (self , shape , axis , samples ):
254
- arr , indices = self .get_input_values (shape , axis , samples )
255
- func = self .get_function (shape , axis )
256
- assert np .allclose (np_take_along_axis (arr , indices , axis = axis ), func (arr , indices )[0 ])
252
+ @pytest .mark .parametrize ("floatX" , ["float32" , "float64" ])
253
+ def test_take_along_axis (self , shape , axis , samples , floatX ):
254
+ with aesara .config .change_flags (floatX = floatX ):
255
+ arr , indices = self .get_input_values (shape , axis , samples , floatX )
256
+ func = self .get_function (shape , axis , floatX )
257
+ assert np .allclose (np_take_along_axis (arr , indices , axis = axis ), func (arr , indices )[0 ])
257
258
258
259
@pytest .mark .parametrize (
259
260
["shape" , "axis" , "samples" ],
@@ -276,53 +277,62 @@ def test_take_along_axis(self, shape, axis, samples):
276
277
),
277
278
ids = str ,
278
279
)
279
- def test_take_along_axis_grad (self , shape , axis , samples ):
280
- if axis < 0 :
281
- _axis = len (shape ) + axis
282
- else :
283
- _axis = axis
284
- # Setup the aesara function
285
- t_arr , t_indices = self .get_input_tensors (shape )
286
- t_out2 = aesara .grad (
287
- at .sum (self ._output_tensor (t_arr ** 2 , t_indices , axis )),
288
- t_arr ,
289
- )
290
- func = aesara .function ([t_arr , t_indices ], [t_out2 ])
291
-
292
- # Test that the gradient gives the same output as what is expected
293
- arr , indices = self .get_input_values (shape , axis , samples )
294
- expected_grad = np .zeros_like (arr )
295
- slicer = [slice (None )] * len (shape )
296
- for i in range (indices .shape [axis ]):
297
- slicer [axis ] = i
298
- inds = indices [tuple (slicer )].reshape (shape [:_axis ] + (1 ,) + shape [_axis + 1 :])
299
- inds = _make_along_axis_idx (shape , inds , _axis )
300
- expected_grad [inds ] += 1
301
- expected_grad *= 2 * arr
302
- out = func (arr , indices )[0 ]
303
- assert np .allclose (out , expected_grad )
280
+ @pytest .mark .parametrize ("floatX" , ["float32" , "float64" ])
281
+ def test_take_along_axis_grad (self , shape , axis , samples , floatX ):
282
+ with aesara .config .change_flags (floatX = floatX ):
283
+ if axis < 0 :
284
+ _axis = len (shape ) + axis
285
+ else :
286
+ _axis = axis
287
+ # Setup the aesara function
288
+ t_arr , t_indices = self .get_input_tensors (shape , floatX )
289
+ t_out2 = aesara .grad (
290
+ at .sum (self ._output_tensor (t_arr ** 2 , t_indices , axis )),
291
+ t_arr ,
292
+ )
293
+ func = aesara .function ([t_arr , t_indices ], [t_out2 ])
294
+
295
+ # Test that the gradient gives the same output as what is expected
296
+ arr , indices = self .get_input_values (shape , axis , samples , floatX )
297
+ expected_grad = np .zeros_like (arr )
298
+ slicer = [slice (None )] * len (shape )
299
+ for i in range (indices .shape [axis ]):
300
+ slicer [axis ] = i
301
+ inds = indices [tuple (slicer )].reshape (shape [:_axis ] + (1 ,) + shape [_axis + 1 :])
302
+ inds = _make_along_axis_idx (shape , inds , _axis )
303
+ expected_grad [inds ] += 1
304
+ expected_grad *= 2 * arr
305
+ out = func (arr , indices )[0 ]
306
+ assert np .allclose (out , expected_grad )
304
307
305
308
@pytest .mark .parametrize ("axis" , [- 4 , 4 ], ids = str )
306
- def test_axis_failure (self , axis ):
307
- arr , indices = self .get_input_tensors ((3 , 1 ))
308
- with pytest .raises (ValueError ):
309
- take_along_axis (arr , indices , axis = axis )
310
-
311
- def test_ndim_failure (self ):
312
- arr = TensorType (FLOATX , [False ] * 3 )("arr" )
313
- indices = TensorType (INTX , [False ] * 2 )("indices" )
314
- arr .tag .test_value = np .zeros ((1 ,) * arr .ndim , dtype = FLOATX )
315
- indices .tag .test_value = np .zeros ((1 ,) * indices .ndim , dtype = INTX )
316
- with pytest .raises (ValueError ):
317
- take_along_axis (arr , indices )
318
-
319
- def test_dtype_failure (self ):
320
- arr = TensorType (FLOATX , [False ] * 3 )("arr" )
321
- indices = TensorType (FLOATX , [False ] * 3 )("indices" )
322
- arr .tag .test_value = np .zeros ((1 ,) * arr .ndim , dtype = FLOATX )
323
- indices .tag .test_value = np .zeros ((1 ,) * indices .ndim , dtype = FLOATX )
324
- with pytest .raises (IndexError ):
325
- take_along_axis (arr , indices )
309
+ @pytest .mark .parametrize ("floatX" , ["float32" , "float64" ])
310
+ def test_axis_failure (self , axis , floatX ):
311
+ with aesara .config .change_flags (floatX = floatX ):
312
+ arr , indices = self .get_input_tensors ((3 , 1 ), floatX )
313
+ with pytest .raises (ValueError ):
314
+ take_along_axis (arr , indices , axis = axis )
315
+
316
+ @pytest .mark .parametrize ("floatX" , ["float32" , "float64" ])
317
+ def test_ndim_failure (self , floatX ):
318
+ with aesara .config .change_flags (floatX = floatX ):
319
+ intX = str (_conversion_map [floatX ])
320
+ arr = TensorType (floatX , [False ] * 3 )("arr" )
321
+ indices = TensorType (intX , [False ] * 2 )("indices" )
322
+ arr .tag .test_value = np .zeros ((1 ,) * arr .ndim , dtype = floatX )
323
+ indices .tag .test_value = np .zeros ((1 ,) * indices .ndim , dtype = intX )
324
+ with pytest .raises (ValueError ):
325
+ take_along_axis (arr , indices )
326
+
327
+ @pytest .mark .parametrize ("floatX" , ["float32" , "float64" ])
328
+ def test_dtype_failure (self , floatX ):
329
+ with aesara .config .change_flags (floatX = floatX ):
330
+ arr = TensorType (floatX , [False ] * 3 )("arr" )
331
+ indices = TensorType (floatX , [False ] * 3 )("indices" )
332
+ arr .tag .test_value = np .zeros ((1 ,) * arr .ndim , dtype = floatX )
333
+ indices .tag .test_value = np .zeros ((1 ,) * indices .ndim , dtype = floatX )
334
+ with pytest .raises (IndexError ):
335
+ take_along_axis (arr , indices )
326
336
327
337
328
338
def test_extract_obs_data ():
0 commit comments