Skip to content

Commit 1a13e76

Browse files
committed
Replace device="cpu" with a special object in numpy.array_api
This way, it does not appear that "cpu" is a portable device object across different array API compatible libraries. See data-apis/array-api#626. Original NumPy Commit: 3b20ad9c5ead16282c530cf48737aa3768a77f91
1 parent 136bdd7 commit 1a13e76

File tree

5 files changed

+66
-45
lines changed

5 files changed

+66
-45
lines changed

array_api_strict/_array_object.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@
3939

4040
import numpy as np
4141

42+
# Placeholder object to represent the "cpu" device (the only device NumPy
43+
# supports).
44+
class _cpu_device:
45+
def __repr__(self):
46+
return "CPU_DEVICE"
47+
CPU_DEVICE = _cpu_device()
4248

4349
class Array:
4450
"""
@@ -1067,7 +1073,7 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
10671073
def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
10681074
if stream is not None:
10691075
raise ValueError("The stream argument to to_device() is not supported")
1070-
if device == 'cpu':
1076+
if device == CPU_DEVICE:
10711077
return self
10721078
raise ValueError(f"Unsupported device {device!r}")
10731079

@@ -1082,7 +1088,7 @@ def dtype(self) -> Dtype:
10821088

10831089
@property
10841090
def device(self) -> Device:
1085-
return "cpu"
1091+
return CPU_DEVICE
10861092

10871093
# Note: mT is new in array API spec (see matrix_transpose)
10881094
@property

array_api_strict/_creation_functions.py

+24-24
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ def asarray(
5050
"""
5151
# _array_object imports in this file are inside the functions to avoid
5252
# circular imports
53-
from ._array_object import Array
53+
from ._array_object import Array, CPU_DEVICE
5454

5555
_check_valid_dtype(dtype)
56-
if device not in ["cpu", None]:
56+
if device not in [CPU_DEVICE, None]:
5757
raise ValueError(f"Unsupported device {device!r}")
5858
if copy in (False, np._CopyMode.IF_NEEDED):
5959
# Note: copy=False is not yet implemented in np.asarray
@@ -86,10 +86,10 @@ def arange(
8686
8787
See its docstring for more information.
8888
"""
89-
from ._array_object import Array
89+
from ._array_object import Array, CPU_DEVICE
9090

9191
_check_valid_dtype(dtype)
92-
if device not in ["cpu", None]:
92+
if device not in [CPU_DEVICE, None]:
9393
raise ValueError(f"Unsupported device {device!r}")
9494
return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))
9595

@@ -105,10 +105,10 @@ def empty(
105105
106106
See its docstring for more information.
107107
"""
108-
from ._array_object import Array
108+
from ._array_object import Array, CPU_DEVICE
109109

110110
_check_valid_dtype(dtype)
111-
if device not in ["cpu", None]:
111+
if device not in [CPU_DEVICE, None]:
112112
raise ValueError(f"Unsupported device {device!r}")
113113
return Array._new(np.empty(shape, dtype=dtype))
114114

@@ -121,10 +121,10 @@ def empty_like(
121121
122122
See its docstring for more information.
123123
"""
124-
from ._array_object import Array
124+
from ._array_object import Array, CPU_DEVICE
125125

126126
_check_valid_dtype(dtype)
127-
if device not in ["cpu", None]:
127+
if device not in [CPU_DEVICE, None]:
128128
raise ValueError(f"Unsupported device {device!r}")
129129
return Array._new(np.empty_like(x._array, dtype=dtype))
130130

@@ -143,10 +143,10 @@ def eye(
143143
144144
See its docstring for more information.
145145
"""
146-
from ._array_object import Array
146+
from ._array_object import Array, CPU_DEVICE
147147

148148
_check_valid_dtype(dtype)
149-
if device not in ["cpu", None]:
149+
if device not in [CPU_DEVICE, None]:
150150
raise ValueError(f"Unsupported device {device!r}")
151151
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
152152

@@ -169,10 +169,10 @@ def full(
169169
170170
See its docstring for more information.
171171
"""
172-
from ._array_object import Array
172+
from ._array_object import Array, CPU_DEVICE
173173

174174
_check_valid_dtype(dtype)
175-
if device not in ["cpu", None]:
175+
if device not in [CPU_DEVICE, None]:
176176
raise ValueError(f"Unsupported device {device!r}")
177177
if isinstance(fill_value, Array) and fill_value.ndim == 0:
178178
fill_value = fill_value._array
@@ -197,10 +197,10 @@ def full_like(
197197
198198
See its docstring for more information.
199199
"""
200-
from ._array_object import Array
200+
from ._array_object import Array, CPU_DEVICE
201201

202202
_check_valid_dtype(dtype)
203-
if device not in ["cpu", None]:
203+
if device not in [CPU_DEVICE, None]:
204204
raise ValueError(f"Unsupported device {device!r}")
205205
res = np.full_like(x._array, fill_value, dtype=dtype)
206206
if res.dtype not in _all_dtypes:
@@ -225,10 +225,10 @@ def linspace(
225225
226226
See its docstring for more information.
227227
"""
228-
from ._array_object import Array
228+
from ._array_object import Array, CPU_DEVICE
229229

230230
_check_valid_dtype(dtype)
231-
if device not in ["cpu", None]:
231+
if device not in [CPU_DEVICE, None]:
232232
raise ValueError(f"Unsupported device {device!r}")
233233
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
234234

@@ -264,10 +264,10 @@ def ones(
264264
265265
See its docstring for more information.
266266
"""
267-
from ._array_object import Array
267+
from ._array_object import Array, CPU_DEVICE
268268

269269
_check_valid_dtype(dtype)
270-
if device not in ["cpu", None]:
270+
if device not in [CPU_DEVICE, None]:
271271
raise ValueError(f"Unsupported device {device!r}")
272272
return Array._new(np.ones(shape, dtype=dtype))
273273

@@ -280,10 +280,10 @@ def ones_like(
280280
281281
See its docstring for more information.
282282
"""
283-
from ._array_object import Array
283+
from ._array_object import Array, CPU_DEVICE
284284

285285
_check_valid_dtype(dtype)
286-
if device not in ["cpu", None]:
286+
if device not in [CPU_DEVICE, None]:
287287
raise ValueError(f"Unsupported device {device!r}")
288288
return Array._new(np.ones_like(x._array, dtype=dtype))
289289

@@ -327,10 +327,10 @@ def zeros(
327327
328328
See its docstring for more information.
329329
"""
330-
from ._array_object import Array
330+
from ._array_object import Array, CPU_DEVICE
331331

332332
_check_valid_dtype(dtype)
333-
if device not in ["cpu", None]:
333+
if device not in [CPU_DEVICE, None]:
334334
raise ValueError(f"Unsupported device {device!r}")
335335
return Array._new(np.zeros(shape, dtype=dtype))
336336

@@ -343,9 +343,9 @@ def zeros_like(
343343
344344
See its docstring for more information.
345345
"""
346-
from ._array_object import Array
346+
from ._array_object import Array, CPU_DEVICE
347347

348348
_check_valid_dtype(dtype)
349-
if device not in ["cpu", None]:
349+
if device not in [CPU_DEVICE, None]:
350350
raise ValueError(f"Unsupported device {device!r}")
351351
return Array._new(np.zeros_like(x._array, dtype=dtype))

array_api_strict/_typing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
Protocol,
3030
)
3131

32-
from ._array_object import Array
32+
from ._array_object import Array, CPU_DEVICE
3333
from numpy import (
3434
dtype,
3535
int8,
@@ -50,7 +50,7 @@ class NestedSequence(Protocol[_T_co]):
5050
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
5151
def __len__(self, /) -> int: ...
5252

53-
Device = Literal["cpu"]
53+
Device = type(CPU_DEVICE)
5454

5555
Dtype = dtype[Union[
5656
int8,

array_api_strict/tests/test_array_object.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
from .. import ones, asarray, reshape, result_type, all, equal
8-
from .._array_object import Array
8+
from .._array_object import Array, CPU_DEVICE
99
from .._dtypes import (
1010
_all_dtypes,
1111
_boolean_dtypes,
@@ -311,12 +311,15 @@ def test_python_scalar_construtors():
311311

312312
def test_device_property():
313313
a = ones((3, 4))
314-
assert a.device == 'cpu'
314+
assert a.device == CPU_DEVICE
315+
assert a.device != 'cpu'
315316

316-
assert all(equal(a.to_device('cpu'), a))
317+
assert all(equal(a.to_device(CPU_DEVICE), a))
318+
assert_raises(ValueError, lambda: a.to_device('cpu'))
317319
assert_raises(ValueError, lambda: a.to_device('gpu'))
318320

319-
assert all(equal(asarray(a, device='cpu'), a))
321+
assert all(equal(asarray(a, device=CPU_DEVICE), a))
322+
assert_raises(ValueError, lambda: asarray(a, device='cpu'))
320323
assert_raises(ValueError, lambda: asarray(a, device='gpu'))
321324

322325
def test_array_properties():

array_api_strict/tests/test_creation_functions.py

+25-13
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
zeros_like,
1919
)
2020
from .._dtypes import float32, float64
21-
from .._array_object import Array
21+
from .._array_object import Array, CPU_DEVICE
2222

2323

2424
def test_asarray_errors():
@@ -30,7 +30,8 @@ def test_asarray_errors():
3030
# Preferably this would be OverflowError
3131
# assert_raises(OverflowError, lambda: asarray([2**100]))
3232
assert_raises(TypeError, lambda: asarray([2**100]))
33-
asarray([1], device="cpu") # Doesn't error
33+
asarray([1], device=CPU_DEVICE) # Doesn't error
34+
assert_raises(ValueError, lambda: asarray([1], device="cpu"))
3435
assert_raises(ValueError, lambda: asarray([1], device="gpu"))
3536

3637
assert_raises(ValueError, lambda: asarray([1], dtype=int))
@@ -58,77 +59,88 @@ def test_asarray_copy():
5859

5960

6061
def test_arange_errors():
61-
arange(1, device="cpu") # Doesn't error
62+
arange(1, device=CPU_DEVICE) # Doesn't error
63+
assert_raises(ValueError, lambda: arange(1, device="cpu"))
6264
assert_raises(ValueError, lambda: arange(1, device="gpu"))
6365
assert_raises(ValueError, lambda: arange(1, dtype=int))
6466
assert_raises(ValueError, lambda: arange(1, dtype="i"))
6567

6668

6769
def test_empty_errors():
68-
empty((1,), device="cpu") # Doesn't error
70+
empty((1,), device=CPU_DEVICE) # Doesn't error
71+
assert_raises(ValueError, lambda: empty((1,), device="cpu"))
6972
assert_raises(ValueError, lambda: empty((1,), device="gpu"))
7073
assert_raises(ValueError, lambda: empty((1,), dtype=int))
7174
assert_raises(ValueError, lambda: empty((1,), dtype="i"))
7275

7376

7477
def test_empty_like_errors():
75-
empty_like(asarray(1), device="cpu") # Doesn't error
78+
empty_like(asarray(1), device=CPU_DEVICE) # Doesn't error
79+
assert_raises(ValueError, lambda: empty_like(asarray(1), device="cpu"))
7680
assert_raises(ValueError, lambda: empty_like(asarray(1), device="gpu"))
7781
assert_raises(ValueError, lambda: empty_like(asarray(1), dtype=int))
7882
assert_raises(ValueError, lambda: empty_like(asarray(1), dtype="i"))
7983

8084

8185
def test_eye_errors():
82-
eye(1, device="cpu") # Doesn't error
86+
eye(1, device=CPU_DEVICE) # Doesn't error
87+
assert_raises(ValueError, lambda: eye(1, device="cpu"))
8388
assert_raises(ValueError, lambda: eye(1, device="gpu"))
8489
assert_raises(ValueError, lambda: eye(1, dtype=int))
8590
assert_raises(ValueError, lambda: eye(1, dtype="i"))
8691

8792

8893
def test_full_errors():
89-
full((1,), 0, device="cpu") # Doesn't error
94+
full((1,), 0, device=CPU_DEVICE) # Doesn't error
95+
assert_raises(ValueError, lambda: full((1,), 0, device="cpu"))
9096
assert_raises(ValueError, lambda: full((1,), 0, device="gpu"))
9197
assert_raises(ValueError, lambda: full((1,), 0, dtype=int))
9298
assert_raises(ValueError, lambda: full((1,), 0, dtype="i"))
9399

94100

95101
def test_full_like_errors():
96-
full_like(asarray(1), 0, device="cpu") # Doesn't error
102+
full_like(asarray(1), 0, device=CPU_DEVICE) # Doesn't error
103+
assert_raises(ValueError, lambda: full_like(asarray(1), 0, device="cpu"))
97104
assert_raises(ValueError, lambda: full_like(asarray(1), 0, device="gpu"))
98105
assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype=int))
99106
assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype="i"))
100107

101108

102109
def test_linspace_errors():
103-
linspace(0, 1, 10, device="cpu") # Doesn't error
110+
linspace(0, 1, 10, device=CPU_DEVICE) # Doesn't error
111+
assert_raises(ValueError, lambda: linspace(0, 1, 10, device="cpu"))
104112
assert_raises(ValueError, lambda: linspace(0, 1, 10, device="gpu"))
105113
assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype=float))
106114
assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype="f"))
107115

108116

109117
def test_ones_errors():
110-
ones((1,), device="cpu") # Doesn't error
118+
ones((1,), device=CPU_DEVICE) # Doesn't error
119+
assert_raises(ValueError, lambda: ones((1,), device="cpu"))
111120
assert_raises(ValueError, lambda: ones((1,), device="gpu"))
112121
assert_raises(ValueError, lambda: ones((1,), dtype=int))
113122
assert_raises(ValueError, lambda: ones((1,), dtype="i"))
114123

115124

116125
def test_ones_like_errors():
117-
ones_like(asarray(1), device="cpu") # Doesn't error
126+
ones_like(asarray(1), device=CPU_DEVICE) # Doesn't error
127+
assert_raises(ValueError, lambda: ones_like(asarray(1), device="cpu"))
118128
assert_raises(ValueError, lambda: ones_like(asarray(1), device="gpu"))
119129
assert_raises(ValueError, lambda: ones_like(asarray(1), dtype=int))
120130
assert_raises(ValueError, lambda: ones_like(asarray(1), dtype="i"))
121131

122132

123133
def test_zeros_errors():
124-
zeros((1,), device="cpu") # Doesn't error
134+
zeros((1,), device=CPU_DEVICE) # Doesn't error
135+
assert_raises(ValueError, lambda: zeros((1,), device="cpu"))
125136
assert_raises(ValueError, lambda: zeros((1,), device="gpu"))
126137
assert_raises(ValueError, lambda: zeros((1,), dtype=int))
127138
assert_raises(ValueError, lambda: zeros((1,), dtype="i"))
128139

129140

130141
def test_zeros_like_errors():
131-
zeros_like(asarray(1), device="cpu") # Doesn't error
142+
zeros_like(asarray(1), device=CPU_DEVICE) # Doesn't error
143+
assert_raises(ValueError, lambda: zeros_like(asarray(1), device="cpu"))
132144
assert_raises(ValueError, lambda: zeros_like(asarray(1), device="gpu"))
133145
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int))
134146
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i"))

0 commit comments

Comments
 (0)