Skip to content

Commit a52063e

Browse files
authored
Merge pull request #20527 from asmeurer/array_api-__array__
ENH: Add __array__ to the array_api Array object Original NumPy Commit: b5331ea9ece515d45e5a5adcb8e06117c9d33569
2 parents 472e03f + 1540ff8 commit a52063e

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

array_api_strict/_array_object.py

+12
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
if TYPE_CHECKING:
3535
from ._typing import Any, PyCapsule, Device, Dtype
36+
import numpy.typing as npt
3637

3738
import numpy as np
3839

@@ -108,6 +109,17 @@ def __repr__(self: Array, /) -> str:
108109
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
109110
return prefix + mid + suffix
110111

112+
# This function is not required by the spec, but we implement it here for
113+
# convenience so that np.asarray(np.array_api.Array) will work.
114+
def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]:
115+
"""
116+
Warning: this method is NOT part of the array API spec. Implementers
117+
of other libraries need not include it, and users should not assume it
118+
will be present in other implementations.
119+
120+
"""
121+
return np.asarray(self._array, dtype=dtype)
122+
111123
# These are various helper functions to make the array behavior match the
112124
# spec in places where it either deviates from or is more strict than
113125
# NumPy behavior

array_api_strict/tests/test_array_object.py

+7
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,10 @@ def test_array_properties():
315315
assert a.mT.shape == (1, 3, 2)
316316
assert isinstance(b.mT, Array)
317317
assert b.mT.shape == (3, 2)
318+
319+
def test___array__():
320+
a = ones((2, 3), dtype=int16)
321+
assert np.asarray(a) is a._array
322+
b = np.asarray(a, dtype=np.float64)
323+
assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64)))
324+
assert b.dtype == np.float64

0 commit comments

Comments
 (0)