Skip to content

Commit 11ebcb6

Browse files
authored
Merge pull request #19173 from czgdp1807/never_copy
ENH: Add support for copy modes to NumPy Original NumPy Commit: 7125cdfa21081b91907a65b6c888791cf150a0b2
2 parents aec7e8b + 17acdd7 commit 11ebcb6

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

array_api_strict/_creation_functions.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def asarray(
4141
*,
4242
dtype: Optional[Dtype] = None,
4343
device: Optional[Device] = None,
44-
copy: Optional[bool] = None,
44+
copy: Optional[Union[bool, np._CopyMode]] = None,
4545
) -> Array:
4646
"""
4747
Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
@@ -55,13 +55,13 @@ def asarray(
5555
_check_valid_dtype(dtype)
5656
if device not in ["cpu", None]:
5757
raise ValueError(f"Unsupported device {device!r}")
58-
if copy is False:
58+
if copy in (False, np._CopyMode.IF_NEEDED):
5959
# Note: copy=False is not yet implemented in np.asarray
6060
raise NotImplementedError("copy=False is not yet implemented")
6161
if isinstance(obj, Array):
6262
if dtype is not None and obj.dtype != dtype:
6363
copy = True
64-
if copy is True:
64+
if copy in (True, np._CopyMode.ALWAYS):
6565
return Array._new(np.array(obj._array, copy=True, dtype=dtype))
6666
return obj
6767
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):

array_api_strict/tests/test_creation_functions.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,18 @@ def test_asarray_copy():
4343
a[0] = 0
4444
assert all(b[0] == 1)
4545
assert all(a[0] == 0)
46-
# Once copy=False is implemented, replace this with
47-
# a = asarray([1])
48-
# b = asarray(a, copy=False)
49-
# a[0] = 0
50-
# assert all(b[0] == 0)
46+
a = asarray([1])
47+
b = asarray(a, copy=np._CopyMode.ALWAYS)
48+
a[0] = 0
49+
assert all(b[0] == 1)
50+
assert all(a[0] == 0)
51+
a = asarray([1])
52+
b = asarray(a, copy=np._CopyMode.NEVER)
53+
a[0] = 0
54+
assert all(b[0] == 0)
5155
assert_raises(NotImplementedError, lambda: asarray(a, copy=False))
56+
assert_raises(NotImplementedError,
57+
lambda: asarray(a, copy=np._CopyMode.IF_NEEDED))
5258

5359

5460
def test_arange_errors():

0 commit comments

Comments
 (0)