
Commit 436f94a

Parametrize some more tests by floatX
1 parent 0fb7999 commit 436f94a

File tree: 5 files changed (+165 / -143 lines)
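The diffs below all apply the same pattern: each affected test gains a pytest parametrization over floatX and runs its body under aesara.config.change_flags, so both precisions are exercised in one test session. A minimal stand-alone sketch of that pattern (the test name and assertion here are illustrative, not part of this commit):

import aesara
import numpy as np
import pytest


@pytest.mark.parametrize("floatX", ["float32", "float64"])
def test_shared_dtype_follows_floatX(floatX):
    # change_flags restores the previous floatX when the block exits,
    # so each parametrized case sees a clean configuration.
    with aesara.config.change_flags(floatX=floatX):
        x = aesara.shared(np.ones(3, dtype=floatX))
        assert x.dtype == floatX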


pymc/distributions/simulator.py

Lines changed: 5 additions & 4 deletions
@@ -295,11 +295,9 @@ def __call__(self, epsilon, obs_data, sim_data):
         return self.d_n * np.sum(-np.log(nu_d / self.rho_d) / epsilon) + self.log_r
 
 
-scalarX = at.dscalar if aesara.config.floatX == "float64" else at.fscalar
-vectorX = at.dvector if aesara.config.floatX == "float64" else at.fvector
-
-
 def create_sum_stat_op_from_fn(fn):
+    vectorX = at.dvector if aesara.config.floatX == "float64" else at.fvector
+
     # Check if callable returns TensorVariable with dummy inputs
     try:
         res = fn(vectorX())
@@ -322,6 +320,9 @@ def perform(self, node, inputs, outputs):
 
 
 def create_distance_op_from_fn(fn):
+    scalarX = at.dscalar if aesara.config.floatX == "float64" else at.fscalar
+    vectorX = at.dvector if aesara.config.floatX == "float64" else at.fvector
+
     # Check if callable returns TensorVariable with dummy inputs
    try:
         res = fn(scalarX(), vectorX(), vectorX())
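A side note on why scalarX and vectorX move from module level into the function bodies (my reading of the diff, shown with an illustrative helper name): a module-level expression reads aesara.config.floatX once at import time, so tests that later switch the flag would still get the original dtype. Re-evaluating inside the function picks up whatever floatX is active at call time, for example:

import aesara
import aesara.tensor as at


def make_vectorX():
    # Hypothetical helper: re-reads the config on every call instead of
    # freezing the dtype at import time.
    return at.dvector() if aesara.config.floatX == "float64" else at.fvector()


with aesara.config.change_flags(floatX="float32"):
    assert make_vectorX().dtype == "float32"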

pymc/tests/test_aesaraf.py

Lines changed: 83 additions & 73 deletions
@@ -47,9 +47,6 @@
 from pymc.exceptions import ShapeError
 from pymc.vartypes import int_types
 
-FLOATX = str(aesara.config.floatX)
-INTX = str(_conversion_map[FLOATX])
-
 
 def test_change_rv_size():
     loc = at.as_tensor_variable([1, 2])
@@ -176,57 +173,59 @@ def setup_class(self):
         self.output_buffer = dict()
         self.func_buffer = dict()
 
-    def _input_tensors(self, shape):
+    def _input_tensors(self, shape, floatX):
+        intX = str(_conversion_map[floatX])
         ndim = len(shape)
-        arr = TensorType(FLOATX, [False] * ndim)("arr")
-        indices = TensorType(INTX, [False] * ndim)("indices")
-        arr.tag.test_value = np.zeros(shape, dtype=FLOATX)
-        indices.tag.test_value = np.zeros(shape, dtype=INTX)
+        arr = TensorType(floatX, [False] * ndim)("arr")
+        indices = TensorType(intX, [False] * ndim)("indices")
+        arr.tag.test_value = np.zeros(shape, dtype=floatX)
+        indices.tag.test_value = np.zeros(shape, dtype=intX)
         return arr, indices
 
-    def get_input_tensors(self, shape):
+    def get_input_tensors(self, shape, floatX):
         ndim = len(shape)
         try:
-            return self.inputs_buffer[ndim]
+            return self.inputs_buffer[(ndim, floatX)]
         except KeyError:
-            arr, indices = self._input_tensors(shape)
-            self.inputs_buffer[ndim] = arr, indices
+            arr, indices = self._input_tensors(shape, floatX)
+            self.inputs_buffer[(ndim, floatX)] = arr, indices
         return arr, indices
 
     def _output_tensor(self, arr, indices, axis):
         return take_along_axis(arr, indices, axis)
 
-    def get_output_tensors(self, shape, axis):
+    def get_output_tensors(self, shape, axis, floatX):
         ndim = len(shape)
         try:
-            return self.output_buffer[(ndim, axis)]
+            return self.output_buffer[(ndim, axis, floatX)]
         except KeyError:
-            arr, indices = self.get_input_tensors(shape)
+            arr, indices = self.get_input_tensors(shape, floatX)
             out = self._output_tensor(arr, indices, axis)
-            self.output_buffer[(ndim, axis)] = out
+            self.output_buffer[(ndim, axis, floatX)] = out
         return out
 
     def _function(self, arr, indices, out):
         return aesara.function([arr, indices], [out])
 
-    def get_function(self, shape, axis):
+    def get_function(self, shape, axis, floatX):
        ndim = len(shape)
         try:
-            return self.func_buffer[(ndim, axis)]
+            return self.func_buffer[(ndim, axis, floatX)]
         except KeyError:
-            arr, indices = self.get_input_tensors(shape)
-            out = self.get_output_tensors(shape, axis)
+            arr, indices = self.get_input_tensors(shape, floatX)
+            out = self.get_output_tensors(shape, axis, floatX)
             func = self._function(arr, indices, out)
-            self.func_buffer[(ndim, axis)] = func
+            self.func_buffer[(ndim, axis, floatX)] = func
         return func
 
     @staticmethod
-    def get_input_values(shape, axis, samples):
-        arr = np.random.randn(*shape).astype(FLOATX)
+    def get_input_values(shape, axis, samples, floatX):
+        intX = str(_conversion_map[floatX])
+        arr = np.random.randn(*shape).astype(floatX)
         size = list(shape)
         size[axis] = samples
         size = tuple(size)
-        indices = np.random.randint(low=0, high=shape[axis], size=size, dtype=INTX)
+        indices = np.random.randint(low=0, high=shape[axis], size=size, dtype=intX)
         return arr, indices
 
     @pytest.mark.parametrize(
@@ -250,10 +249,12 @@ def get_input_values(shape, axis, samples):
         ),
         ids=str,
     )
-    def test_take_along_axis(self, shape, axis, samples):
-        arr, indices = self.get_input_values(shape, axis, samples)
-        func = self.get_function(shape, axis)
-        assert np.allclose(np_take_along_axis(arr, indices, axis=axis), func(arr, indices)[0])
+    @pytest.mark.parametrize("floatX", ["float32", "float64"])
+    def test_take_along_axis(self, shape, axis, samples, floatX):
+        with aesara.config.change_flags(floatX=floatX):
+            arr, indices = self.get_input_values(shape, axis, samples, floatX)
+            func = self.get_function(shape, axis, floatX)
+            assert np.allclose(np_take_along_axis(arr, indices, axis=axis), func(arr, indices)[0])
 
     @pytest.mark.parametrize(
         ["shape", "axis", "samples"],
@@ -276,53 +277,62 @@ def test_take_along_axis(self, shape, axis, samples):
         ),
         ids=str,
     )
-    def test_take_along_axis_grad(self, shape, axis, samples):
-        if axis < 0:
-            _axis = len(shape) + axis
-        else:
-            _axis = axis
-        # Setup the aesara function
-        t_arr, t_indices = self.get_input_tensors(shape)
-        t_out2 = aesara.grad(
-            at.sum(self._output_tensor(t_arr**2, t_indices, axis)),
-            t_arr,
-        )
-        func = aesara.function([t_arr, t_indices], [t_out2])
-
-        # Test that the gradient gives the same output as what is expected
-        arr, indices = self.get_input_values(shape, axis, samples)
-        expected_grad = np.zeros_like(arr)
-        slicer = [slice(None)] * len(shape)
-        for i in range(indices.shape[axis]):
-            slicer[axis] = i
-            inds = indices[tuple(slicer)].reshape(shape[:_axis] + (1,) + shape[_axis + 1 :])
-            inds = _make_along_axis_idx(shape, inds, _axis)
-            expected_grad[inds] += 1
-        expected_grad *= 2 * arr
-        out = func(arr, indices)[0]
-        assert np.allclose(out, expected_grad)
+    @pytest.mark.parametrize("floatX", ["float32", "float64"])
+    def test_take_along_axis_grad(self, shape, axis, samples, floatX):
+        with aesara.config.change_flags(floatX=floatX):
+            if axis < 0:
+                _axis = len(shape) + axis
+            else:
+                _axis = axis
+            # Setup the aesara function
+            t_arr, t_indices = self.get_input_tensors(shape, floatX)
+            t_out2 = aesara.grad(
+                at.sum(self._output_tensor(t_arr**2, t_indices, axis)),
+                t_arr,
+            )
+            func = aesara.function([t_arr, t_indices], [t_out2])
+
+            # Test that the gradient gives the same output as what is expected
+            arr, indices = self.get_input_values(shape, axis, samples, floatX)
+            expected_grad = np.zeros_like(arr)
+            slicer = [slice(None)] * len(shape)
+            for i in range(indices.shape[axis]):
+                slicer[axis] = i
+                inds = indices[tuple(slicer)].reshape(shape[:_axis] + (1,) + shape[_axis + 1 :])
+                inds = _make_along_axis_idx(shape, inds, _axis)
+                expected_grad[inds] += 1
+            expected_grad *= 2 * arr
+            out = func(arr, indices)[0]
+            assert np.allclose(out, expected_grad)
 
     @pytest.mark.parametrize("axis", [-4, 4], ids=str)
-    def test_axis_failure(self, axis):
-        arr, indices = self.get_input_tensors((3, 1))
-        with pytest.raises(ValueError):
-            take_along_axis(arr, indices, axis=axis)
-
-    def test_ndim_failure(self):
-        arr = TensorType(FLOATX, [False] * 3)("arr")
-        indices = TensorType(INTX, [False] * 2)("indices")
-        arr.tag.test_value = np.zeros((1,) * arr.ndim, dtype=FLOATX)
-        indices.tag.test_value = np.zeros((1,) * indices.ndim, dtype=INTX)
-        with pytest.raises(ValueError):
-            take_along_axis(arr, indices)
-
-    def test_dtype_failure(self):
-        arr = TensorType(FLOATX, [False] * 3)("arr")
-        indices = TensorType(FLOATX, [False] * 3)("indices")
-        arr.tag.test_value = np.zeros((1,) * arr.ndim, dtype=FLOATX)
-        indices.tag.test_value = np.zeros((1,) * indices.ndim, dtype=FLOATX)
-        with pytest.raises(IndexError):
-            take_along_axis(arr, indices)
+    @pytest.mark.parametrize("floatX", ["float32", "float64"])
+    def test_axis_failure(self, axis, floatX):
+        with aesara.config.change_flags(floatX=floatX):
+            arr, indices = self.get_input_tensors((3, 1), floatX)
+            with pytest.raises(ValueError):
+                take_along_axis(arr, indices, axis=axis)
+
+    @pytest.mark.parametrize("floatX", ["float32", "float64"])
+    def test_ndim_failure(self, floatX):
+        with aesara.config.change_flags(floatX=floatX):
+            intX = str(_conversion_map[floatX])
+            arr = TensorType(floatX, [False] * 3)("arr")
+            indices = TensorType(intX, [False] * 2)("indices")
+            arr.tag.test_value = np.zeros((1,) * arr.ndim, dtype=floatX)
+            indices.tag.test_value = np.zeros((1,) * indices.ndim, dtype=intX)
+            with pytest.raises(ValueError):
+                take_along_axis(arr, indices)
+
+    @pytest.mark.parametrize("floatX", ["float32", "float64"])
+    def test_dtype_failure(self, floatX):
+        with aesara.config.change_flags(floatX=floatX):
+            arr = TensorType(floatX, [False] * 3)("arr")
+            indices = TensorType(floatX, [False] * 3)("indices")
+            arr.tag.test_value = np.zeros((1,) * arr.ndim, dtype=floatX)
+            indices.tag.test_value = np.zeros((1,) * indices.ndim, dtype=floatX)
+            with pytest.raises(IndexError):
                take_along_axis(arr, indices)
 
 
def test_extract_obs_data():
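The cache keys in this test class gain a floatX component for the same reason the tests gain the parameter: a buffer keyed only on ndim or (ndim, axis) would hand back a tensor or compiled function built for one precision when the test is re-run under the other. A tiny sketch of the hazard, with a hypothetical build callable (not part of the commit):

# Hypothetical cache mirroring the buffers above: without floatX in the key,
# a float64 entry would be silently reused by a float32 run.
_buffer = {}


def get_cached(ndim, floatX, build):
    try:
        return _buffer[(ndim, floatX)]
    except KeyError:
        _buffer[(ndim, floatX)] = build()
        return _buffer[(ndim, floatX)]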

pymc/tests/test_distributions_random.py

Lines changed: 11 additions & 6 deletions
@@ -14,6 +14,7 @@
 import functools
 import itertools
 import re
+import sys
 
 from typing import Callable, List, Optional
 
@@ -1571,14 +1572,18 @@ def constant_rng_fn(self, size, c):
         "check_pymc_params_match_rv_op",
         "check_pymc_draws_match_reference",
         "check_rv_size",
-        "check_dtype",
     ]
 
-    def check_dtype(self):
-        assert pm.Constant.dist(2**4).dtype == "int8"
-        assert pm.Constant.dist(2**16).dtype == "int32"
-        assert pm.Constant.dist(2**32).dtype == "int64"
-        assert pm.Constant.dist(2.0).dtype == aesara.config.floatX
+    @pytest.mark.parametrize("floatX", ["float32", "float64"])
+    @pytest.mark.xfail(
+        sys.platform == "win32", reason="https://github.com/aesara-devs/aesara/issues/871"
+    )
+    def test_dtype(self, floatX):
+        with aesara.config.change_flags(floatX=floatX):
+            assert pm.Constant.dist(2**4).dtype == "int8"
+            assert pm.Constant.dist(2**16).dtype == "int32"
+            assert pm.Constant.dist(2**32).dtype == "int64"
+            assert pm.Constant.dist(2.0).dtype == floatX
 
 
 class TestOrderedLogistic(BaseTestDistributionRandom):
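For readers unfamiliar with the conditional xfail used above: passing a boolean condition as the first argument makes the marker apply only when the condition holds, so the test still runs everywhere, but a Windows failure is reported as an expected failure instead of breaking the suite. A minimal sketch (the platform check and reason string mirror the diff; the test body is illustrative):

import sys

import pytest


@pytest.mark.xfail(
    sys.platform == "win32", reason="https://github.com/aesara-devs/aesara/issues/871"
)
def test_runs_everywhere():
    # On non-Windows platforms this is an ordinary test; on Windows a
    # failure is recorded as xfail rather than an error.
    assert 1 + 1 == 2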

pymc/tests/test_mixture.py

Lines changed: 20 additions & 18 deletions
@@ -659,24 +659,26 @@ def test_iterable_single_component_warning(self):
         with pytest.warns(UserWarning, match="Single component will be treated as a mixture"):
             Mixture.dist(w=[0.5, 0.5], comp_dists=[Normal.dist(size=2)])
 
-    def test_mixture_dtype(self):
-        mix_dtype = Mixture.dist(
-            w=[0.5, 0.5],
-            comp_dists=[
-                Multinomial.dist(n=5, p=[0.5, 0.5]),
-                Multinomial.dist(n=5, p=[0.5, 0.5]),
-            ],
-        ).dtype
-        assert mix_dtype == "int64"
-
-        mix_dtype = Mixture.dist(
-            w=[0.5, 0.5],
-            comp_dists=[
-                Dirichlet.dist(a=[0.5, 0.5]),
-                Dirichlet.dist(a=[0.5, 0.5]),
-            ],
-        ).dtype
-        assert mix_dtype == aesara.config.floatX
+    @pytest.mark.parametrize("floatX", ["float32", "float64"])
+    def test_mixture_dtype(self, floatX):
+        with aesara.config.change_flags(floatX=floatX):
+            mix_dtype = Mixture.dist(
+                w=[0.5, 0.5],
+                comp_dists=[
+                    Multinomial.dist(n=5, p=[0.5, 0.5]),
+                    Multinomial.dist(n=5, p=[0.5, 0.5]),
+                ],
+            ).dtype
+            assert mix_dtype == "int64"
+
+            mix_dtype = Mixture.dist(
+                w=[0.5, 0.5],
+                comp_dists=[
+                    Dirichlet.dist(a=[0.5, 0.5]),
+                    Dirichlet.dist(a=[0.5, 0.5]),
+                ],
+            ).dtype
+            assert mix_dtype == floatX
 
     @pytest.mark.parametrize(
         "comp_dists, expected_shape",
