diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index fce2be9d2a..0329575c01 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -1246,10 +1246,9 @@ def __init__(
             The dtype of the internal accumulator.
             If ``None`` (default), we use the dtype in the list below,
             or the input dtype if its precision is higher:
-
                 - for int dtypes, we use at least int64;
                 - for uint dtypes, we use at least uint64;
-                - for float dtypes, we use at least float64;
+                - for float dtypes, we use at least floatX;
                 - for complex dtypes, we use at least complex128.
         upcast_discrete_output
             See
@@ -1333,8 +1332,8 @@ def _acc_dtype(self, idtype):
                 uint8="uint64",
                 uint16="uint64",
                 uint32="uint64",
-                float16="float32",
-                float32="float64",
+                float16=config.floatX,
+                float32=config.floatX,
                 complex64="complex128",
             ).get(idtype, idtype)
         elif acc_dtype in continuous_dtypes and idtype in discrete_dtypes:
diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py
index cda745d023..c948ec9679 100644
--- a/tests/tensor/test_extra_ops.py
+++ b/tests/tensor/test_extra_ops.py
@@ -631,7 +631,9 @@ def test_grad(self, ndim):
         a = np.random.random((10,) * ndim).astype(config.floatX)
 
         for axis in self._possible_axis(ndim):
-            utt.verify_grad(lambda x: Repeat(axis=axis)(x, 3), [a])
+            utt.verify_grad(
+                lambda x: Repeat(axis=axis)(x, 3), [a], cast_to_output_type=True
+            )
 
     def test_broadcastable(self):
         x = TensorType(config.floatX, shape=(None, 1, None))()
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index e346348406..0ce8dd252f 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -3030,8 +3030,8 @@ def test_reduce_default_acc_dtype(self):
                     uint8="uint64",
                     uint16="uint64",
                     uint32="uint64",
-                    float16="float32",
-                    float32="float64",
+                    float16=config.floatX,
+                    float32=config.floatX,
                     complex64="complex128",
                 ).get(dtype, dtype)
                 f = function([x], s, mode=self.mode)
@@ -3255,8 +3255,8 @@ def test_prod_without_zeros_default_acc_dtype(self):
                     uint8="uint64",
                     uint16="uint64",
                     uint32="uint64",
-                    float16="float32",
-                    float32="float64",
+                    float16=config.floatX,
+                    float32=config.floatX,
                     complex64="complex128",
                 ).get(dtype, dtype)
 
diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py
index 946cb48b6b..d01bfd39db 100644
--- a/tests/tensor/test_subtensor.py
+++ b/tests/tensor/test_subtensor.py
@@ -771,8 +771,18 @@ def test_grad_2d_inc_set_subtensor(self):
 
                 t = op(n[:z, :z], m)
                 gn, gm = pytensor.grad(pt_sum(t), [n, m])
-                utt.verify_grad(lambda m: op(n[:z, :z], m), [mv], mode=self.mode)
-                utt.verify_grad(lambda nn: op(nn[:z, :z], mv), [data], mode=self.mode)
+                utt.verify_grad(
+                    lambda m: op(n[:z, :z], m),
+                    [mv],
+                    mode=self.mode,
+                    cast_to_output_type=True,
+                )
+                utt.verify_grad(
+                    lambda nn: op(nn[:z, :z], mv),
+                    [data],
+                    mode=self.mode,
+                    cast_to_output_type=True,
+                )
 
     def test_grad_0d(self):
         data = np.asarray(random(2, 3), dtype=self.dtype)
diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py
index be5f2a029e..d2a556482f 100644
--- a/tests/tensor/utils.py
+++ b/tests/tensor/utils.py
@@ -592,6 +592,7 @@ def test_grad(self):
                         mode=self.mode,
                         rel_tol=_grad_rtol,
                         eps=_grad_eps,
+                        cast_to_output_type=True,
                     )
                 except Exception as exc:
                     err_msg = (