diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index c88d0a53..390245c7 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -147,9 +147,13 @@ def test_prod(x, data): _dtype = dh.default_float else: _dtype = dtype - # We ignore asserting the out dtype if what we expect is undefined - # See https://github.com/data-apis/array-api-tests/issues/106 - if not isinstance(_dtype, _UndefinedStub): + if isinstance(_dtype, _UndefinedStub): + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. + # See https://github.com/data-apis/array-api-tests/issues/106 + if _dtype in dh.uint_dtypes: + assert dh.is_int_dtype(out.dtype) # sanity check + else: ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( @@ -248,7 +252,14 @@ def test_sum(x, data): _dtype = dh.default_float else: _dtype = dtype - ph.assert_dtype("sum", x.dtype, out.dtype, _dtype) + if isinstance(_dtype, _UndefinedStub): + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. + # See https://github.com/data-apis/array-api-tests/issues/160 + if _dtype in dh.uint_dtypes: + assert dh.is_int_dtype(out.dtype) # sanity check + else: + ph.assert_dtype("sum", x.dtype, out.dtype, _dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( "sum", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw