Skip to content

Commit be8a4b1

Browse files
committed
BUG: Fix isclose multidevice
1 parent 1e3614e commit be8a4b1

File tree

5 files changed

+17
-5
lines changed

5 files changed

+17
-5
lines changed

src/array_api_extra/_lib/_backends.py

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
2424
"""
2525

2626
ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
27+
ARRAY_API_STRICT_DEVICE1 = "array_api_strict", _compat.is_array_api_strict_namespace
2728
NUMPY = "numpy", _compat.is_numpy_namespace
2829
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
2930
CUPY = "cupy", _compat.is_cupy_namespace

src/array_api_extra/_lib/_funcs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ def isclose(
549549
xp=xp,
550550
)
551551
if equal_nan:
552-
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
552+
out = xp.where(xp.isnan(a) & xp.isnan(b), True, out)
553553
return out
554554

555555
if xp.isdtype(a.dtype, "bool") or xp.isdtype(b.dtype, "bool"):

src/array_api_extra/_lib/_testing.py

+6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import cast
1111

1212
import pytest
13+
from array_api_compat import is_array_api_strict_namespace
1314

1415
from ._utils._compat import (
1516
array_namespace,
@@ -106,6 +107,11 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
106107
desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
107108

108109
# JAX uses `np.testing`
110+
if is_array_api_strict_namespace(xp):
111+
# Have to move to CPU for array API strict devices before
112+
# we're allowed to convert into numpy
113+
actual = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
114+
desired = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
109115
np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
110116

111117

tests/conftest.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def xp(
113113
The current array namespace.
114114
"""
115115
if library == Backend.NUMPY_READONLY:
116-
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
116+
return NumPyReadOnly(), None # type: ignore[return-value] # pyright: ignore[reportReturnType]
117117
xp = pytest.importorskip(library.value)
118118
# Possibly wrap module with array_api_compat
119119
xp = array_namespace(xp.empty(0))
@@ -131,7 +131,11 @@ def xp(
131131
# suppress unused-ignore to run mypy in -e lint as well as -e dev
132132
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
133133

134-
return xp
134+
device = None
135+
if library == Backend.ARRAY_API_STRICT_DEVICE1:
136+
import array_api_strict
137+
device = array_api_strict.Device("device1")
138+
return xp, device
135139

136140

137141
@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`

tests/test_funcs.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -633,8 +633,9 @@ def test_some_inf(self, xp: ModuleType):
633633
xp_assert_equal(actual, xp.asarray([True, True, True, False, False]))
634634

635635
def test_equal_nan(self, xp: ModuleType):
636-
a = xp.asarray([xp.nan, xp.nan, 1.0])
637-
b = xp.asarray([xp.nan, 1.0, xp.nan])
636+
xp, device = xp
637+
a = xp.asarray([xp.nan, xp.nan, 1.0], device=device)
638+
b = xp.asarray([xp.nan, 1.0, xp.nan], device=device)
638639
xp_assert_equal(isclose(a, b), xp.asarray([False, False, False]))
639640
xp_assert_equal(isclose(a, b, equal_nan=True), xp.asarray([True, False, False]))
640641

0 commit comments

Comments
 (0)