From 9b09796516adde3d89e50559fd3a5431cf09a14b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 24 Mar 2025 10:11:05 +0100 Subject: [PATCH] MAINT: simplify the count_nonzero strategy It used to inline a workaround for the lack of unsigned ints on torch, but this should be done with an env var, ARRAY_API_TESTS_SKIP_DTYPES, instead. --- array_api_tests/test_searching_functions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 412085c5..3accb2c6 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -88,8 +88,6 @@ def test_argmin(x, data): ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected) -# XXX: dtype= stanza below is to work around unsigned int dtypes in torch -# (count_nonzero_cpu not implemented for uint32 etc) # XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on # the problem is tha for ints >iinfo(int32) it runs into essentially this: # >>> jnp.asarray[2147483648], dtype=jnp.int64) @@ -99,7 +97,7 @@ def test_argmin(x, data): @pytest.mark.min_version("2024.12") @given( x=hh.arrays( - dtype=st.sampled_from(dh.int_dtypes + dh.real_float_dtypes + dh.complex_dtypes + (xp.bool,)), + dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1, min_side=1), elements={"allow_nan": False}, ),