Skip to content

Commit 926f3a5

Browse files
authored
Merge pull request #163 from honno/skip-undefined-dtype-assert
Skip asserting to undefined default dtypes in `test_sum`
2 parents 17114d2 + a7ca3f5 commit 926f3a5

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

array_api_tests/test_statistical_functions.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,13 @@ def test_prod(x, data):
147147
_dtype = dh.default_float
148148
else:
149149
_dtype = dtype
150-
# We ignore asserting the out dtype if what we expect is undefined
151-
# See https://github.com/data-apis/array-api-tests/issues/106
152-
if not isinstance(_dtype, _UndefinedStub):
150+
if isinstance(_dtype, _UndefinedStub):
151+
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
152+
# uint32 or uint64), we skip testing the output dtype.
153+
# See https://github.com/data-apis/array-api-tests/issues/106
154+
if _dtype in dh.uint_dtypes:
155+
assert dh.is_int_dtype(out.dtype) # sanity check
156+
else:
153157
ph.assert_dtype("prod", x.dtype, out.dtype, _dtype)
154158
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
155159
ph.assert_keepdimable_shape(
@@ -248,7 +252,14 @@ def test_sum(x, data):
248252
_dtype = dh.default_float
249253
else:
250254
_dtype = dtype
251-
ph.assert_dtype("sum", x.dtype, out.dtype, _dtype)
255+
if isinstance(_dtype, _UndefinedStub):
256+
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
257+
# uint32 or uint64), we skip testing the output dtype.
258+
# See https://github.com/data-apis/array-api-tests/issues/160
259+
if _dtype in dh.uint_dtypes:
260+
assert dh.is_int_dtype(out.dtype) # sanity check
261+
else:
262+
ph.assert_dtype("sum", x.dtype, out.dtype, _dtype)
252263
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
253264
ph.assert_keepdimable_shape(
254265
"sum", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw

0 commit comments

Comments
 (0)