Skip to content

Commit 0bceae1

Browse files
committed
Update test_sum and test_prod to support complex dtypes
1 parent 6dc6ecd commit 0bceae1

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,19 +157,18 @@ def is_int_dtype(dtype):
157157
return dtype in all_int_dtypes
158158

159159

160-
def is_float_dtype(dtype):
160+
def is_float_dtype(dtype, real=False):
161161
# None equals NumPy's xp.float64 object, so we specifically check it here.
162162
# xp.float64 is in fact an alias of np.dtype('float64'), and its equality
163163
# with None is meant to be deprecated at some point.
164164
# See https://github.com/numpy/numpy/issues/18434
165165
if dtype is None:
166166
return False
167167
valid_dtypes = real_float_dtypes
168-
if api_version > "2021.12":
168+
if api_version > "2021.12" and not real:
169169
valid_dtypes += complex_dtypes
170170
return dtype in valid_dtypes
171171

172-
173172
def get_scalar_type(dtype: DataType) -> ScalarType:
174173
if dtype in all_int_dtypes:
175174
return int

array_api_tests/test_statistical_functions.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from . import hypothesis_helpers as hh
1212
from . import pytest_helpers as ph
1313
from . import shape_helpers as sh
14-
from . import xps
14+
from . import xps, api_version
1515
from ._array_module import _UndefinedStub
1616
from .typing import DataType
1717

@@ -145,11 +145,19 @@ def test_prod(x, data):
145145
_dtype = x.dtype
146146
else:
147147
_dtype = default_dtype
148-
else:
148+
elif dh.is_float_dtype(x.dtype, real=True):
149149
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
150150
_dtype = x.dtype
151151
else:
152152
_dtype = dh.default_float
153+
elif api_version > "2021.12":
154+
# Complex dtype
155+
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]:
156+
_dtype = x.dtype
157+
else:
158+
_dtype = dh.default_complex
159+
else:
160+
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
153161
else:
154162
_dtype = dtype
155163
if _dtype is None:
@@ -253,11 +261,19 @@ def test_sum(x, data):
253261
_dtype = x.dtype
254262
else:
255263
_dtype = default_dtype
256-
else:
264+
elif dh.is_float_dtype(x.dtype, real=True):
257265
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
258266
_dtype = x.dtype
259267
else:
260268
_dtype = dh.default_float
269+
elif api_version > "2021.12":
270+
# Complex dtype
271+
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]:
272+
_dtype = x.dtype
273+
else:
274+
_dtype = dh.default_complex
275+
else:
276+
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
261277
else:
262278
_dtype = dtype
263279
if _dtype is None:

0 commit comments

Comments
 (0)