Skip to content

Commit cfe4d71

Browse files
authored
Add duckarray test for np.array_api (#8391)
1 parent f63ede9 commit cfe4d71

File tree

2 files changed

+38
-17
lines changed

2 files changed

+38
-17
lines changed

xarray/namedarray/_typing.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,6 @@ def shape(self) -> _Shape:
9393
def dtype(self) -> _DType_co:
9494
...
9595

96-
@overload
97-
def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]:
98-
...
99-
100-
@overload
101-
def __array__(self, dtype: _DType, /) -> np.ndarray[Any, _DType]:
102-
...
103-
10496

10597
@runtime_checkable
10698
class _arrayfunction(
@@ -112,6 +104,19 @@ class _arrayfunction(
112104
Corresponds to np.ndarray.
113105
"""
114106

107+
@overload
108+
def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]:
109+
...
110+
111+
@overload
112+
def __array__(self, dtype: _DType, /) -> np.ndarray[Any, _DType]:
113+
...
114+
115+
def __array__(
116+
self, dtype: _DType | None = ..., /
117+
) -> np.ndarray[Any, _DType] | np.ndarray[Any, _DType_co]:
118+
...
119+
115120
# TODO: Should return the same subclass but with a new dtype generic.
116121
# https://github.com/python/typing/issues/548
117122
def __array_ufunc__(

xarray/tests/test_namedarray.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import copy
4+
import warnings
45
from collections.abc import Mapping
56
from typing import TYPE_CHECKING, Any, Generic, cast, overload
67

@@ -66,13 +67,13 @@ def test_namedarray_init() -> None:
6667
expected = np.array([1, 2], dtype=dtype)
6768
actual: NamedArray[Any, np.dtype[np.int8]]
6869
actual = NamedArray(("x",), expected)
69-
assert np.array_equal(actual.data, expected)
70+
assert np.array_equal(np.asarray(actual.data), expected)
7071

7172
with pytest.raises(AttributeError):
7273
expected2 = [1, 2]
7374
actual2: NamedArray[Any, Any]
7475
actual2 = NamedArray(("x",), expected2) # type: ignore[arg-type]
75-
assert np.array_equal(actual2.data, expected2)
76+
assert np.array_equal(np.asarray(actual2.data), expected2)
7677

7778

7879
@pytest.mark.parametrize(
@@ -101,7 +102,7 @@ def test_from_array(
101102
else:
102103
actual = from_array(dims, data)
103104

104-
assert np.array_equal(actual.data, expected)
105+
assert np.array_equal(np.asarray(actual.data), expected)
105106

106107

107108
def test_from_array_with_masked_array() -> None:
@@ -114,7 +115,8 @@ def test_from_array_with_masked_array() -> None:
114115
def test_from_array_with_0d_object() -> None:
115116
data = np.empty((), dtype=object)
116117
data[()] = (10, 12, 12)
117-
np.array_equal(from_array((), data).data, data)
118+
narr = from_array((), data)
119+
np.array_equal(np.asarray(narr.data), data)
118120

119121

120122
# TODO: Make xr.core.indexing.ExplicitlyIndexed pass as a subclass of_arrayfunction_or_api
@@ -140,7 +142,7 @@ def test_properties() -> None:
140142
named_array: NamedArray[Any, Any]
141143
named_array = NamedArray(["x", "y"], data, {"key": "value"})
142144
assert named_array.dims == ("x", "y")
143-
assert np.array_equal(named_array.data, data)
145+
assert np.array_equal(np.asarray(named_array.data), data)
144146
assert named_array.attrs == {"key": "value"}
145147
assert named_array.ndim == 2
146148
assert named_array.sizes == {"x": 2, "y": 5}
@@ -162,7 +164,7 @@ def test_attrs() -> None:
162164
def test_data(random_inputs: np.ndarray[Any, Any]) -> None:
163165
named_array: NamedArray[Any, Any]
164166
named_array = NamedArray(["x", "y", "z"], random_inputs)
165-
assert np.array_equal(named_array.data, random_inputs)
167+
assert np.array_equal(np.asarray(named_array.data), random_inputs)
166168
with pytest.raises(ValueError):
167169
named_array.data = np.random.random((3, 4)).astype(np.float64)
168170

@@ -181,11 +183,11 @@ def test_real_and_imag() -> None:
181183
named_array = NamedArray(["x"], arr)
182184

183185
actual_real: duckarray[Any, np.dtype[np.float64]] = named_array.real.data
184-
assert np.array_equal(actual_real, expected_real)
186+
assert np.array_equal(np.asarray(actual_real), expected_real)
185187
assert actual_real.dtype == expected_real.dtype
186188

187189
actual_imag: duckarray[Any, np.dtype[np.float64]] = named_array.imag.data
188-
assert np.array_equal(actual_imag, expected_imag)
190+
assert np.array_equal(np.asarray(actual_imag), expected_imag)
189191
assert actual_imag.dtype == expected_imag.dtype
190192

191193

@@ -214,7 +216,7 @@ def test_0d_object() -> None:
214216
named_array = from_array([], (10, 12, 12))
215217
expected_data = np.empty((), dtype=object)
216218
expected_data[()] = (10, 12, 12)
217-
assert np.array_equal(named_array.data, expected_data)
219+
assert np.array_equal(np.asarray(named_array.data), expected_data)
218220

219221
assert named_array.dims == ()
220222
assert named_array.sizes == {}
@@ -294,6 +296,20 @@ def test_duck_array_typevar(a: duckarray[Any, _DType]) -> duckarray[Any, _DType]
294296
test_duck_array_typevar(numpy_a)
295297
test_duck_array_typevar(custom_a)
296298

299+
# Test numpy's array api:
300+
with warnings.catch_warnings():
301+
warnings.filterwarnings(
302+
"ignore",
303+
r"The numpy.array_api submodule is still experimental",
304+
category=UserWarning,
305+
)
306+
import numpy.array_api as nxp
307+
308+
# TODO: nxp doesn't use dtype typevars, so can only use Any for the moment:
309+
arrayapi_a: duckarray[Any, Any] # duckarray[Any, np.dtype[np.int64]]
310+
arrayapi_a = nxp.asarray([2.1, 4], dtype=np.dtype(np.int64))
311+
test_duck_array_typevar(arrayapi_a)
312+
297313

298314
def test_new_namedarray() -> None:
299315
dtype_float = np.dtype(np.float32)

0 commit comments

Comments
 (0)