Skip to content

Commit 4c9dd0e

Browse files
committed
Add CuPy inspection APIs
I'm not sure if all the details here are correct. See data-apis#127 (comment).
1 parent 11cb6ef commit 4c9dd0e

File tree

2 files changed

+333
-4
lines changed

2 files changed

+333
-4
lines changed

array_api_compat/cupy/_aliases.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from ..common import _aliases
66
from .._internal import get_xp
77

8+
from ._info import __array_namespace_info__
9+
810
from typing import TYPE_CHECKING
911
if TYPE_CHECKING:
1012
from typing import Optional, Union
@@ -123,9 +125,10 @@ def asarray(
123125
else:
124126
unstack = get_xp(cp)(_aliases.unstack)
125127

126-
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
127-
'acosh', 'asin', 'asinh', 'atan', 'atan2',
128-
'atanh', 'bitwise_left_shift', 'bitwise_invert',
129-
'bitwise_right_shift', 'concat', 'pow']
128+
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
129+
'acos', 'acosh', 'asin', 'asinh', 'atan',
130+
'atan2', 'atanh', 'bitwise_left_shift',
131+
'bitwise_invert', 'bitwise_right_shift',
132+
'concat', 'pow']
130133

131134
_all_ignore = ['cp', 'get_xp']

array_api_compat/cupy/_info.py

+326
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
"""
2+
Array API Inspection namespace
3+
4+
This is the namespace for inspection functions as defined by the array API
5+
standard. See
6+
https://data-apis.org/array-api/latest/API_specification/inspection.html for
7+
more details.
8+
9+
"""
10+
from cupy import (
11+
dtype,
12+
cuda,
13+
bool_ as bool,
14+
intp,
15+
int8,
16+
int16,
17+
int32,
18+
int64,
19+
uint8,
20+
uint16,
21+
uint32,
22+
uint64,
23+
float32,
24+
float64,
25+
complex64,
26+
complex128,
27+
)
28+
29+
class __array_namespace_info__:
30+
"""
31+
Get the array API inspection namespace for CuPy.
32+
33+
The array API inspection namespace defines the following functions:
34+
35+
- capabilities()
36+
- default_device()
37+
- default_dtypes()
38+
- dtypes()
39+
- devices()
40+
41+
See
42+
https://data-apis.org/array-api/latest/API_specification/inspection.html
43+
for more details.
44+
45+
Returns
46+
-------
47+
info : ModuleType
48+
The array API inspection namespace for CuPy.
49+
50+
Examples
51+
--------
52+
>>> info = np.__array_namespace_info__()
53+
>>> info.default_dtypes()
54+
{'real floating': cupy.float64,
55+
'complex floating': cupy.complex128,
56+
'integral': cupy.int64,
57+
'indexing': cupy.int64}
58+
59+
"""
60+
61+
__module__ = 'cupy'
62+
63+
def capabilities(self):
64+
"""
65+
Return a dictionary of array API library capabilities.
66+
67+
The resulting dictionary has the following keys:
68+
69+
- **"boolean indexing"**: boolean indicating whether an array library
70+
supports boolean indexing. Always ``True`` for CuPy.
71+
72+
- **"data-dependent shapes"**: boolean indicating whether an array
73+
library supports data-dependent output shapes. Always ``True`` for
74+
CuPy.
75+
76+
See
77+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
78+
for more details.
79+
80+
See Also
81+
--------
82+
__array_namespace_info__.default_device,
83+
__array_namespace_info__.default_dtypes,
84+
__array_namespace_info__.dtypes,
85+
__array_namespace_info__.devices
86+
87+
Returns
88+
-------
89+
capabilities : dict
90+
A dictionary of array API library capabilities.
91+
92+
Examples
93+
--------
94+
>>> info = xp.__array_namespace_info__()
95+
>>> info.capabilities()
96+
{'boolean indexing': True,
97+
'data-dependent shapes': True}
98+
99+
"""
100+
return {
101+
"boolean indexing": True,
102+
"data-dependent shapes": True,
103+
# 'max rank' will be part of the 2024.12 standard
104+
# "max rank": 64,
105+
}
106+
107+
def default_device(self):
108+
"""
109+
The default device used for new CuPy arrays.
110+
111+
See Also
112+
--------
113+
__array_namespace_info__.capabilities,
114+
__array_namespace_info__.default_dtypes,
115+
__array_namespace_info__.dtypes,
116+
__array_namespace_info__.devices
117+
118+
Returns
119+
-------
120+
device : str
121+
The default device used for new CuPy arrays.
122+
123+
Examples
124+
--------
125+
>>> info = xp.__array_namespace_info__()
126+
>>> info.default_device()
127+
Device(0)
128+
129+
"""
130+
return cuda.Device(0)
131+
132+
def default_dtypes(self, *, device=None):
133+
"""
134+
The default data types used for new CuPy arrays.
135+
136+
For CuPy, this always returns the following dictionary:
137+
138+
- **"real floating"**: ``cupy.float64``
139+
- **"complex floating"**: ``cupy.complex128``
140+
- **"integral"**: ``cupy.intp``
141+
- **"indexing"**: ``cupy.intp``
142+
143+
Parameters
144+
----------
145+
device : str, optional
146+
The device to get the default data types for.
147+
148+
Returns
149+
-------
150+
dtypes : dict
151+
A dictionary describing the default data types used for new CuPy
152+
arrays.
153+
154+
See Also
155+
--------
156+
__array_namespace_info__.capabilities,
157+
__array_namespace_info__.default_device,
158+
__array_namespace_info__.dtypes,
159+
__array_namespace_info__.devices
160+
161+
Examples
162+
--------
163+
>>> info = xp.__array_namespace_info__()
164+
>>> info.default_dtypes()
165+
{'real floating': cupy.float64,
166+
'complex floating': cupy.complex128,
167+
'integral': cupy.int64,
168+
'indexing': cupy.int64}
169+
170+
"""
171+
# TODO: Does this depend on device?
172+
return {
173+
"real floating": dtype(float64),
174+
"complex floating": dtype(complex128),
175+
"integral": dtype(intp),
176+
"indexing": dtype(intp),
177+
}
178+
179+
def dtypes(self, *, device=None, kind=None):
180+
"""
181+
The array API data types supported by CuPy.
182+
183+
Note that this function only returns data types that are defined by
184+
the array API.
185+
186+
Parameters
187+
----------
188+
device : str, optional
189+
The device to get the data types for.
190+
kind : str or tuple of str, optional
191+
The kind of data types to return. If ``None``, all data types are
192+
returned. If a string, only data types of that kind are returned.
193+
If a tuple, a dictionary containing the union of the given kinds
194+
is returned. The following kinds are supported:
195+
196+
- ``'bool'``: boolean data types (i.e., ``bool``).
197+
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
198+
``int16``, ``int32``, ``int64``).
199+
- ``'unsigned integer'``: unsigned integer data types (i.e.,
200+
``uint8``, ``uint16``, ``uint32``, ``uint64``).
201+
- ``'integral'``: integer data types. Shorthand for ``('signed
202+
integer', 'unsigned integer')``.
203+
- ``'real floating'``: real-valued floating-point data types
204+
(i.e., ``float32``, ``float64``).
205+
- ``'complex floating'``: complex floating-point data types (i.e.,
206+
``complex64``, ``complex128``).
207+
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
208+
'real floating', 'complex floating')``.
209+
210+
Returns
211+
-------
212+
dtypes : dict
213+
A dictionary mapping the names of data types to the corresponding
214+
CuPy data types.
215+
216+
See Also
217+
--------
218+
__array_namespace_info__.capabilities,
219+
__array_namespace_info__.default_device,
220+
__array_namespace_info__.default_dtypes,
221+
__array_namespace_info__.devices
222+
223+
Examples
224+
--------
225+
>>> info = xp.__array_namespace_info__()
226+
>>> info.dtypes(kind='signed integer')
227+
{'int8': cupy.int8,
228+
'int16': cupy.int16,
229+
'int32': cupy.int32,
230+
'int64': cupy.int64}
231+
232+
"""
233+
# TODO: Does this depend on device?
234+
if kind is None:
235+
return {
236+
"bool": dtype(bool),
237+
"int8": dtype(int8),
238+
"int16": dtype(int16),
239+
"int32": dtype(int32),
240+
"int64": dtype(int64),
241+
"uint8": dtype(uint8),
242+
"uint16": dtype(uint16),
243+
"uint32": dtype(uint32),
244+
"uint64": dtype(uint64),
245+
"float32": dtype(float32),
246+
"float64": dtype(float64),
247+
"complex64": dtype(complex64),
248+
"complex128": dtype(complex128),
249+
}
250+
if kind == "bool":
251+
return {"bool": bool}
252+
if kind == "signed integer":
253+
return {
254+
"int8": dtype(int8),
255+
"int16": dtype(int16),
256+
"int32": dtype(int32),
257+
"int64": dtype(int64),
258+
}
259+
if kind == "unsigned integer":
260+
return {
261+
"uint8": dtype(uint8),
262+
"uint16": dtype(uint16),
263+
"uint32": dtype(uint32),
264+
"uint64": dtype(uint64),
265+
}
266+
if kind == "integral":
267+
return {
268+
"int8": dtype(int8),
269+
"int16": dtype(int16),
270+
"int32": dtype(int32),
271+
"int64": dtype(int64),
272+
"uint8": dtype(uint8),
273+
"uint16": dtype(uint16),
274+
"uint32": dtype(uint32),
275+
"uint64": dtype(uint64),
276+
}
277+
if kind == "real floating":
278+
return {
279+
"float32": dtype(float32),
280+
"float64": dtype(float64),
281+
}
282+
if kind == "complex floating":
283+
return {
284+
"complex64": dtype(complex64),
285+
"complex128": dtype(complex128),
286+
}
287+
if kind == "numeric":
288+
return {
289+
"int8": dtype(int8),
290+
"int16": dtype(int16),
291+
"int32": dtype(int32),
292+
"int64": dtype(int64),
293+
"uint8": dtype(uint8),
294+
"uint16": dtype(uint16),
295+
"uint32": dtype(uint32),
296+
"uint64": dtype(uint64),
297+
"float32": dtype(float32),
298+
"float64": dtype(float64),
299+
"complex64": dtype(complex64),
300+
"complex128": dtype(complex128),
301+
}
302+
if isinstance(kind, tuple):
303+
res = {}
304+
for k in kind:
305+
res.update(self.dtypes(kind=k))
306+
return res
307+
raise ValueError(f"unsupported kind: {kind!r}")
308+
309+
def devices(self):
310+
"""
311+
The devices supported by CuPy.
312+
313+
Returns
314+
-------
315+
devices : list of str
316+
The devices supported by CuPy.
317+
318+
See Also
319+
--------
320+
__array_namespace_info__.capabilities,
321+
__array_namespace_info__.default_device,
322+
__array_namespace_info__.default_dtypes,
323+
__array_namespace_info__.dtypes
324+
325+
"""
326+
return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())]

0 commit comments

Comments
 (0)