File tree 2 files changed +14
-8
lines changed
2 files changed +14
-8
lines changed Original file line number Diff line number Diff line change @@ -41,7 +41,7 @@ def asarray(
41
41
* ,
42
42
dtype : Optional [Dtype ] = None ,
43
43
device : Optional [Device ] = None ,
44
- copy : Optional [bool ] = None ,
44
+ copy : Optional [Union [ bool , np . _CopyMode ] ] = None ,
45
45
) -> Array :
46
46
"""
47
47
Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
@@ -55,13 +55,13 @@ def asarray(
55
55
_check_valid_dtype (dtype )
56
56
if device not in ["cpu" , None ]:
57
57
raise ValueError (f"Unsupported device { device !r} " )
58
- if copy is False :
58
+ if copy in ( False , np . _CopyMode . IF_NEEDED ) :
59
59
# Note: copy=False is not yet implemented in np.asarray
60
60
raise NotImplementedError ("copy=False is not yet implemented" )
61
61
if isinstance (obj , Array ):
62
62
if dtype is not None and obj .dtype != dtype :
63
63
copy = True
64
- if copy is True :
64
+ if copy in ( True , np . _CopyMode . ALWAYS ) :
65
65
return Array ._new (np .array (obj ._array , copy = True , dtype = dtype ))
66
66
return obj
67
67
if dtype is None and isinstance (obj , int ) and (obj > 2 ** 64 or obj < - (2 ** 63 )):
Original file line number Diff line number Diff line change @@ -43,12 +43,18 @@ def test_asarray_copy():
43
43
a [0 ] = 0
44
44
assert all (b [0 ] == 1 )
45
45
assert all (a [0 ] == 0 )
46
- # Once copy=False is implemented, replace this with
47
- # a = asarray([1])
48
- # b = asarray(a, copy=False)
49
- # a[0] = 0
50
- # assert all(b[0] == 0)
46
+ a = asarray ([1 ])
47
+ b = asarray (a , copy = np ._CopyMode .ALWAYS )
48
+ a [0 ] = 0
49
+ assert all (b [0 ] == 1 )
50
+ assert all (a [0 ] == 0 )
51
+ a = asarray ([1 ])
52
+ b = asarray (a , copy = np ._CopyMode .NEVER )
53
+ a [0 ] = 0
54
+ assert all (b [0 ] == 0 )
51
55
assert_raises (NotImplementedError , lambda : asarray (a , copy = False ))
56
+ assert_raises (NotImplementedError ,
57
+ lambda : asarray (a , copy = np ._CopyMode .IF_NEEDED ))
52
58
53
59
54
60
def test_arange_errors ():
You can’t perform that action at this time.
0 commit comments