Skip to content

Commit 11cb6ef

Browse files
committed
Add NumPy inspection namespace
1 parent 284bd99 commit 11cb6ef

File tree

2 files changed

+353
-4
lines changed

2 files changed

+353
-4
lines changed

array_api_compat/numpy/_aliases.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from .._internal import get_xp
66

7+
from ._info import __array_namespace_info__
8+
79
from typing import TYPE_CHECKING
810
if TYPE_CHECKING:
911
from typing import Optional, Union
@@ -128,9 +130,10 @@ def asarray(
128130
else:
129131
unstack = get_xp(np)(_aliases.unstack)
130132

131-
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
132-
'acosh', 'asin', 'asinh', 'atan', 'atan2',
133-
'atanh', 'bitwise_left_shift', 'bitwise_invert',
134-
'bitwise_right_shift', 'concat', 'pow']
133+
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
134+
'acos', 'acosh', 'asin', 'asinh', 'atan',
135+
'atan2', 'atanh', 'bitwise_left_shift',
136+
'bitwise_invert', 'bitwise_right_shift',
137+
'concat', 'pow']
135138

136139
_all_ignore = ['np', 'get_xp']

array_api_compat/numpy/_info.py

+346
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
"""
2+
Array API Inspection namespace
3+
4+
This is the namespace for inspection functions as defined by the array API
5+
standard. See
6+
https://data-apis.org/array-api/latest/API_specification/inspection.html for
7+
more details.
8+
9+
"""
10+
from numpy import (
11+
dtype,
12+
bool_ as bool,
13+
intp,
14+
int8,
15+
int16,
16+
int32,
17+
int64,
18+
uint8,
19+
uint16,
20+
uint32,
21+
uint64,
22+
float32,
23+
float64,
24+
complex64,
25+
complex128,
26+
)
27+
28+
29+
class __array_namespace_info__:
30+
"""
31+
Get the array API inspection namespace for NumPy.
32+
33+
The array API inspection namespace defines the following functions:
34+
35+
- capabilities()
36+
- default_device()
37+
- default_dtypes()
38+
- dtypes()
39+
- devices()
40+
41+
See
42+
https://data-apis.org/array-api/latest/API_specification/inspection.html
43+
for more details.
44+
45+
Returns
46+
-------
47+
info : ModuleType
48+
The array API inspection namespace for NumPy.
49+
50+
Examples
51+
--------
52+
>>> info = np.__array_namespace_info__()
53+
>>> info.default_dtypes()
54+
{'real floating': numpy.float64,
55+
'complex floating': numpy.complex128,
56+
'integral': numpy.int64,
57+
'indexing': numpy.int64}
58+
59+
"""
60+
61+
__module__ = 'numpy'
62+
63+
def capabilities(self):
64+
"""
65+
Return a dictionary of array API library capabilities.
66+
67+
The resulting dictionary has the following keys:
68+
69+
- **"boolean indexing"**: boolean indicating whether an array library
70+
supports boolean indexing. Always ``True`` for NumPy.
71+
72+
- **"data-dependent shapes"**: boolean indicating whether an array
73+
library supports data-dependent output shapes. Always ``True`` for
74+
NumPy.
75+
76+
See
77+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
78+
for more details.
79+
80+
See Also
81+
--------
82+
__array_namespace_info__.default_device,
83+
__array_namespace_info__.default_dtypes,
84+
__array_namespace_info__.dtypes,
85+
__array_namespace_info__.devices
86+
87+
Returns
88+
-------
89+
capabilities : dict
90+
A dictionary of array API library capabilities.
91+
92+
Examples
93+
--------
94+
>>> info = np.__array_namespace_info__()
95+
>>> info.capabilities()
96+
{'boolean indexing': True,
97+
'data-dependent shapes': True}
98+
99+
"""
100+
return {
101+
"boolean indexing": True,
102+
"data-dependent shapes": True,
103+
# 'max rank' will be part of the 2024.12 standard
104+
# "max rank": 64,
105+
}
106+
107+
def default_device(self):
108+
"""
109+
The default device used for new NumPy arrays.
110+
111+
For NumPy, this always returns ``'cpu'``.
112+
113+
See Also
114+
--------
115+
__array_namespace_info__.capabilities,
116+
__array_namespace_info__.default_dtypes,
117+
__array_namespace_info__.dtypes,
118+
__array_namespace_info__.devices
119+
120+
Returns
121+
-------
122+
device : str
123+
The default device used for new NumPy arrays.
124+
125+
Examples
126+
--------
127+
>>> info = np.__array_namespace_info__()
128+
>>> info.default_device()
129+
'cpu'
130+
131+
"""
132+
return "cpu"
133+
134+
def default_dtypes(self, *, device=None):
135+
"""
136+
The default data types used for new NumPy arrays.
137+
138+
For NumPy, this always returns the following dictionary:
139+
140+
- **"real floating"**: ``numpy.float64``
141+
- **"complex floating"**: ``numpy.complex128``
142+
- **"integral"**: ``numpy.intp``
143+
- **"indexing"**: ``numpy.intp``
144+
145+
Parameters
146+
----------
147+
device : str, optional
148+
The device to get the default data types for. For NumPy, only
149+
``'cpu'`` is allowed.
150+
151+
Returns
152+
-------
153+
dtypes : dict
154+
A dictionary describing the default data types used for new NumPy
155+
arrays.
156+
157+
See Also
158+
--------
159+
__array_namespace_info__.capabilities,
160+
__array_namespace_info__.default_device,
161+
__array_namespace_info__.dtypes,
162+
__array_namespace_info__.devices
163+
164+
Examples
165+
--------
166+
>>> info = np.__array_namespace_info__()
167+
>>> info.default_dtypes()
168+
{'real floating': numpy.float64,
169+
'complex floating': numpy.complex128,
170+
'integral': numpy.int64,
171+
'indexing': numpy.int64}
172+
173+
"""
174+
if device not in ["cpu", None]:
175+
raise ValueError(
176+
'Device not understood. Only "cpu" is allowed, but received:'
177+
f' {device}'
178+
)
179+
return {
180+
"real floating": dtype(float64),
181+
"complex floating": dtype(complex128),
182+
"integral": dtype(intp),
183+
"indexing": dtype(intp),
184+
}
185+
186+
def dtypes(self, *, device=None, kind=None):
187+
"""
188+
The array API data types supported by NumPy.
189+
190+
Note that this function only returns data types that are defined by
191+
the array API.
192+
193+
Parameters
194+
----------
195+
device : str, optional
196+
The device to get the data types for. For NumPy, only ``'cpu'`` is
197+
allowed.
198+
kind : str or tuple of str, optional
199+
The kind of data types to return. If ``None``, all data types are
200+
returned. If a string, only data types of that kind are returned.
201+
If a tuple, a dictionary containing the union of the given kinds
202+
is returned. The following kinds are supported:
203+
204+
- ``'bool'``: boolean data types (i.e., ``bool``).
205+
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
206+
``int16``, ``int32``, ``int64``).
207+
- ``'unsigned integer'``: unsigned integer data types (i.e.,
208+
``uint8``, ``uint16``, ``uint32``, ``uint64``).
209+
- ``'integral'``: integer data types. Shorthand for ``('signed
210+
integer', 'unsigned integer')``.
211+
- ``'real floating'``: real-valued floating-point data types
212+
(i.e., ``float32``, ``float64``).
213+
- ``'complex floating'``: complex floating-point data types (i.e.,
214+
``complex64``, ``complex128``).
215+
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
216+
'real floating', 'complex floating')``.
217+
218+
Returns
219+
-------
220+
dtypes : dict
221+
A dictionary mapping the names of data types to the corresponding
222+
NumPy data types.
223+
224+
See Also
225+
--------
226+
__array_namespace_info__.capabilities,
227+
__array_namespace_info__.default_device,
228+
__array_namespace_info__.default_dtypes,
229+
__array_namespace_info__.devices
230+
231+
Examples
232+
--------
233+
>>> info = np.__array_namespace_info__()
234+
>>> info.dtypes(kind='signed integer')
235+
{'int8': numpy.int8,
236+
'int16': numpy.int16,
237+
'int32': numpy.int32,
238+
'int64': numpy.int64}
239+
240+
"""
241+
if device not in ["cpu", None]:
242+
raise ValueError(
243+
'Device not understood. Only "cpu" is allowed, but received:'
244+
f' {device}'
245+
)
246+
if kind is None:
247+
return {
248+
"bool": dtype(bool),
249+
"int8": dtype(int8),
250+
"int16": dtype(int16),
251+
"int32": dtype(int32),
252+
"int64": dtype(int64),
253+
"uint8": dtype(uint8),
254+
"uint16": dtype(uint16),
255+
"uint32": dtype(uint32),
256+
"uint64": dtype(uint64),
257+
"float32": dtype(float32),
258+
"float64": dtype(float64),
259+
"complex64": dtype(complex64),
260+
"complex128": dtype(complex128),
261+
}
262+
if kind == "bool":
263+
return {"bool": bool}
264+
if kind == "signed integer":
265+
return {
266+
"int8": dtype(int8),
267+
"int16": dtype(int16),
268+
"int32": dtype(int32),
269+
"int64": dtype(int64),
270+
}
271+
if kind == "unsigned integer":
272+
return {
273+
"uint8": dtype(uint8),
274+
"uint16": dtype(uint16),
275+
"uint32": dtype(uint32),
276+
"uint64": dtype(uint64),
277+
}
278+
if kind == "integral":
279+
return {
280+
"int8": dtype(int8),
281+
"int16": dtype(int16),
282+
"int32": dtype(int32),
283+
"int64": dtype(int64),
284+
"uint8": dtype(uint8),
285+
"uint16": dtype(uint16),
286+
"uint32": dtype(uint32),
287+
"uint64": dtype(uint64),
288+
}
289+
if kind == "real floating":
290+
return {
291+
"float32": dtype(float32),
292+
"float64": dtype(float64),
293+
}
294+
if kind == "complex floating":
295+
return {
296+
"complex64": dtype(complex64),
297+
"complex128": dtype(complex128),
298+
}
299+
if kind == "numeric":
300+
return {
301+
"int8": dtype(int8),
302+
"int16": dtype(int16),
303+
"int32": dtype(int32),
304+
"int64": dtype(int64),
305+
"uint8": dtype(uint8),
306+
"uint16": dtype(uint16),
307+
"uint32": dtype(uint32),
308+
"uint64": dtype(uint64),
309+
"float32": dtype(float32),
310+
"float64": dtype(float64),
311+
"complex64": dtype(complex64),
312+
"complex128": dtype(complex128),
313+
}
314+
if isinstance(kind, tuple):
315+
res = {}
316+
for k in kind:
317+
res.update(self.dtypes(kind=k))
318+
return res
319+
raise ValueError(f"unsupported kind: {kind!r}")
320+
321+
def devices(self):
322+
"""
323+
The devices supported by NumPy.
324+
325+
For NumPy, this always returns ``['cpu']``.
326+
327+
Returns
328+
-------
329+
devices : list of str
330+
The devices supported by NumPy.
331+
332+
See Also
333+
--------
334+
__array_namespace_info__.capabilities,
335+
__array_namespace_info__.default_device,
336+
__array_namespace_info__.default_dtypes,
337+
__array_namespace_info__.dtypes
338+
339+
Examples
340+
--------
341+
>>> info = np.__array_namespace_info__()
342+
>>> info.devices()
343+
['cpu']
344+
345+
"""
346+
return ["cpu"]

0 commit comments

Comments
 (0)