Skip to content

Commit 6f9db94

Browse files
committed
Factor out dtype logic from test_sum() and test_prod() and apply it to test_trace()
1 parent a4d419f commit 6f9db94

File tree

3 files changed

+57
-69
lines changed

3 files changed

+57
-69
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,47 @@ def as_real_dtype(dtype):
242242
else:
243243
raise ValueError("as_real_dtype requires a floating-point dtype")
244244

245+
def accumulation_result_dtype(x_dtype, dtype_kwarg):
246+
"""
247+
Result dtype logic for sum(), prod(), and trace()
248+
249+
Note: may return None if a default uint cannot exist (e.g., for pytorch
250+
which doesn't support uint32 or uint64). See https://github.com/data-apis/array-api-tests/issues/106
251+
252+
"""
253+
if dtype_kwarg is None:
254+
if is_int_dtype(x_dtype):
255+
if x_dtype in uint_dtypes:
256+
default_dtype = default_uint
257+
else:
258+
default_dtype = default_int
259+
if default_dtype is None:
260+
_dtype = None
261+
else:
262+
m, M = dtype_ranges[x_dtype]
263+
d_m, d_M = dtype_ranges[default_dtype]
264+
if m < d_m or M > d_M:
265+
_dtype = x_dtype
266+
else:
267+
_dtype = default_dtype
268+
elif is_float_dtype(x_dtype, include_complex=False):
269+
if dtype_nbits[x_dtype] > dtype_nbits[default_float]:
270+
_dtype = x_dtype
271+
else:
272+
_dtype = default_float
273+
elif api_version > "2021.12":
274+
# Complex dtype
275+
if dtype_nbits[x_dtype] > dtype_nbits[default_complex]:
276+
_dtype = x_dtype
277+
else:
278+
_dtype = default_complex
279+
else:
280+
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
281+
else:
282+
_dtype = dtype_kwarg
283+
284+
return _dtype
285+
245286
if not hasattr(xp, "asarray"):
246287
default_int = xp.int32
247288
default_float = xp.float32

array_api_tests/test_linalg.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -795,11 +795,16 @@ def test_tensordot(x1, x2, kw):
795795
def test_trace(x, kw):
796796
res = linalg.trace(x, **kw)
797797

798-
# TODO: trace() should promote in some cases. See
799-
# https://github.com/data-apis/array-api/issues/202. See also the dtype
800-
# argument to sum() below.
801-
802-
# assert res.dtype == x.dtype, "trace() returned the wrong dtype"
798+
dtype = kw.get("dtype", None)
799+
expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
800+
if expected_dtype is None:
801+
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
802+
# uint32 or uint64), we skip testing the output dtype.
803+
# See https://github.com/data-apis/array-api-tests/issues/160
804+
if x.dtype in dh.uint_dtypes:
805+
assert dh.is_int_dtype(res.dtype) # sanity check
806+
else:
807+
ph.assert_dtype("trace", in_dtype=x.dtype, out_dtype=res.dtype, expected=expected_dtype)
803808

804809
n, m = x.shape[-2:]
805810
ph.assert_result_shape('trace', x.shape, res.shape, expected=x.shape[:-2])

array_api_tests/test_statistical_functions.py

Lines changed: 6 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -130,44 +130,15 @@ def test_prod(x, data):
130130
out = xp.prod(x, **kw)
131131

132132
dtype = kw.get("dtype", None)
133-
if dtype is None:
134-
if dh.is_int_dtype(x.dtype):
135-
if x.dtype in dh.uint_dtypes:
136-
default_dtype = dh.default_uint
137-
else:
138-
default_dtype = dh.default_int
139-
if default_dtype is None:
140-
_dtype = None
141-
else:
142-
m, M = dh.dtype_ranges[x.dtype]
143-
d_m, d_M = dh.dtype_ranges[default_dtype]
144-
if m < d_m or M > d_M:
145-
_dtype = x.dtype
146-
else:
147-
_dtype = default_dtype
148-
elif dh.is_float_dtype(x.dtype, include_complex=False):
149-
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
150-
_dtype = x.dtype
151-
else:
152-
_dtype = dh.default_float
153-
elif api_version > "2021.12":
154-
# Complex dtype
155-
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]:
156-
_dtype = x.dtype
157-
else:
158-
_dtype = dh.default_complex
159-
else:
160-
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
161-
else:
162-
_dtype = dtype
163-
if _dtype is None:
133+
expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
134+
if expected_dtype is None:
164135
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
165136
# uint32 or uint64), we skip testing the output dtype.
166137
# See https://github.com/data-apis/array-api-tests/issues/106
167138
if x.dtype in dh.uint_dtypes:
168139
assert dh.is_int_dtype(out.dtype) # sanity check
169140
else:
170-
ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype)
141+
ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
171142
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
172143
ph.assert_keepdimable_shape(
173144
"prod", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
@@ -246,44 +217,15 @@ def test_sum(x, data):
246217
out = xp.sum(x, **kw)
247218

248219
dtype = kw.get("dtype", None)
249-
if dtype is None:
250-
if dh.is_int_dtype(x.dtype):
251-
if x.dtype in dh.uint_dtypes:
252-
default_dtype = dh.default_uint
253-
else:
254-
default_dtype = dh.default_int
255-
if default_dtype is None:
256-
_dtype = None
257-
else:
258-
m, M = dh.dtype_ranges[x.dtype]
259-
d_m, d_M = dh.dtype_ranges[default_dtype]
260-
if m < d_m or M > d_M:
261-
_dtype = x.dtype
262-
else:
263-
_dtype = default_dtype
264-
elif dh.is_float_dtype(x.dtype, include_complex=False):
265-
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
266-
_dtype = x.dtype
267-
else:
268-
_dtype = dh.default_float
269-
elif api_version > "2021.12":
270-
# Complex dtype
271-
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]:
272-
_dtype = x.dtype
273-
else:
274-
_dtype = dh.default_complex
275-
else:
276-
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
277-
else:
278-
_dtype = dtype
279-
if _dtype is None:
220+
expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
221+
if expected_dtype is None:
280222
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
281223
# uint32 or uint64), we skip testing the output dtype.
282224
# See https://github.com/data-apis/array-api-tests/issues/160
283225
if x.dtype in dh.uint_dtypes:
284226
assert dh.is_int_dtype(out.dtype) # sanity check
285227
else:
286-
ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype)
228+
ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
287229
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
288230
ph.assert_keepdimable_shape(
289231
"sum", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw

0 commit comments

Comments
 (0)