diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index d0b67388..efe2f377 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -549,7 +549,7 @@ def isclose( xp=xp, ) if equal_nan: - out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out) + out = xp.where(xp.isnan(a) & xp.isnan(b), True, out) return out if xp.isdtype(a.dtype, "bool") or xp.isdtype(b.dtype, "bool"): diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index f592eb45..e5ec16a6 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -13,6 +13,7 @@ from ._utils._compat import ( array_namespace, + is_array_api_strict_namespace, is_cupy_namespace, is_dask_namespace, is_pydata_sparse_namespace, @@ -105,8 +106,18 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None: actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - # JAX uses `np.testing` - np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType] + actual_np = None + desired_np = None + if is_array_api_strict_namespace(xp): + # __array__ doesn't work on array-api-strict device arrays + # We need to convert to the CPU device first + actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) + desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE"))) + + # JAX/Dask arrays work with `np.testing` + actual_np = actual if actual_np is None else actual_np + desired_np = desired if desired_np is None else desired_np + np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType] def xp_assert_close( @@ -169,14 +180,25 @@ def xp_assert_close( actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - # JAX uses `np.testing` + actual_np = None + desired_np = None + if is_array_api_strict_namespace(xp): + # __array__ doesn't work on array-api-strict device arrays + # We need to convert to the CPU device first + actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) + desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE"))) + + # JAX/Dask arrays work with `np.testing` + actual_np = actual if actual_np is None else actual_np + desired_np = desired if desired_np is None else desired_np + assert isinstance(rtol, float) np.testing.assert_allclose( # pyright: ignore[reportCallIssue] - actual, # pyright: ignore[reportArgumentType] - desired, # pyright: ignore[reportArgumentType] + actual_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + desired_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] rtol=rtol, atol=atol, - err_msg=err_msg, # type: ignore[call-overload] + err_msg=err_msg, ) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index b93cc7c9..46591ed6 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -716,6 +716,16 @@ def test_xp(self, xp: ModuleType): b = xp.asarray([1e-9, 1e-4]) xp_assert_equal(isclose(a, b, xp=xp), xp.asarray([True, False])) + @pytest.mark.parametrize("equal_nan", [True, False]) + def test_device(self, xp: ModuleType, device: Device, equal_nan: bool): + a = xp.asarray([0.0, 0.0, xp.nan], device=device) + b = xp.asarray([1e-9, 1e-4, xp.nan], device=device) + res = isclose(a, b, equal_nan=equal_nan) + assert get_device(res) == device + xp_assert_equal( + isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan]) + ) + class TestKron: def test_basic(self, xp: ModuleType):