@@ -147,9 +147,13 @@ def test_prod(x, data):
147
147
_dtype = dh .default_float
148
148
else :
149
149
_dtype = dtype
150
- # We ignore asserting the out dtype if what we expect is undefined
151
- # See https://github.com/data-apis/array-api-tests/issues/106
152
- if not isinstance (_dtype , _UndefinedStub ):
150
+ if isinstance (_dtype , _UndefinedStub ):
151
+ # If a default uint cannot exist (i.e. in PyTorch which doesn't support
152
+ # uint32 or uint64), we skip testing the output dtype.
153
+ # See https://github.com/data-apis/array-api-tests/issues/106
154
+ if _dtype in dh .uint_dtypes :
155
+ assert dh .is_int_dtype (out .dtype ) # sanity check
156
+ else :
153
157
ph .assert_dtype ("prod" , x .dtype , out .dtype , _dtype )
154
158
_axes = sh .normalise_axis (kw .get ("axis" , None ), x .ndim )
155
159
ph .assert_keepdimable_shape (
@@ -248,7 +252,14 @@ def test_sum(x, data):
248
252
_dtype = dh .default_float
249
253
else :
250
254
_dtype = dtype
251
- ph .assert_dtype ("sum" , x .dtype , out .dtype , _dtype )
255
+ if isinstance (_dtype , _UndefinedStub ):
256
+ # If a default uint cannot exist (i.e. in PyTorch which doesn't support
257
+ # uint32 or uint64), we skip testing the output dtype.
258
+ # See https://github.com/data-apis/array-api-tests/issues/160
259
+ if _dtype in dh .uint_dtypes :
260
+ assert dh .is_int_dtype (out .dtype ) # sanity check
261
+ else :
262
+ ph .assert_dtype ("sum" , x .dtype , out .dtype , _dtype )
252
263
_axes = sh .normalise_axis (kw .get ("axis" , None ), x .ndim )
253
264
ph .assert_keepdimable_shape (
254
265
"sum" , x .shape , out .shape , _axes , kw .get ("keepdims" , False ), ** kw
0 commit comments