Skip to content

Commit 548f071

Browse files
committed
Make __array_namespace_info__ a class
This makes it so that it doesn't have a bunch of extra names on it, which it did as a module.
1 parent 632e895 commit 548f071

File tree

3 files changed

+130
-136
lines changed

3 files changed

+130
-136
lines changed

array_api_strict/_info.py

Lines changed: 119 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -6,143 +6,134 @@
66

77
if TYPE_CHECKING:
88
from typing import Optional, Union, Tuple, List
9-
from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info
9+
from ._typing import device, DefaultDataTypes, DataTypes, Capabilities
1010

1111
from ._array_object import ALL_DEVICES, CPU_DEVICE
1212
from ._flags import get_array_api_strict_flags, requires_api_version
1313
from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128
1414

1515
@requires_api_version('2023.12')
16-
def __array_namespace_info__() -> Info:
17-
import array_api_strict._info
18-
return array_api_strict._info
19-
20-
@requires_api_version('2023.12')
21-
def capabilities() -> Capabilities:
22-
flags = get_array_api_strict_flags()
23-
res = {"boolean indexing": flags['boolean_indexing'],
24-
"data-dependent shapes": flags['data_dependent_shapes'],
25-
}
26-
if flags['api_version'] >= '2024.12':
27-
# maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will
28-
# drop support for NumPy 1 but for now, just compute the number
29-
# directly
30-
for i in range(1, 100):
31-
try:
32-
np.zeros((1,)*i)
33-
except ValueError:
34-
maxdims = i - 1
35-
break
36-
else:
37-
raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)")
38-
res['max dimensions'] = maxdims
39-
return res
40-
41-
@requires_api_version('2023.12')
42-
def default_device() -> device:
43-
return CPU_DEVICE
16+
class __array_namespace_info__:
17+
@requires_api_version('2023.12')
18+
def capabilities(self) -> Capabilities:
19+
flags = get_array_api_strict_flags()
20+
res = {"boolean indexing": flags['boolean_indexing'],
21+
"data-dependent shapes": flags['data_dependent_shapes'],
22+
}
23+
if flags['api_version'] >= '2024.12':
24+
# maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will
25+
# drop support for NumPy 1 but for now, just compute the number
26+
# directly
27+
for i in range(1, 100):
28+
try:
29+
np.zeros((1,)*i)
30+
except ValueError:
31+
maxdims = i - 1
32+
break
33+
else:
34+
raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)")
35+
res['max dimensions'] = maxdims
36+
return res
4437

45-
@requires_api_version('2023.12')
46-
def default_dtypes(
47-
*,
48-
device: Optional[device] = None,
49-
) -> DefaultDataTypes:
50-
return {
51-
"real floating": float64,
52-
"complex floating": complex128,
53-
"integral": int64,
54-
"indexing": int64,
55-
}
38+
@requires_api_version('2023.12')
39+
def default_device(self) -> device:
40+
return CPU_DEVICE
5641

57-
@requires_api_version('2023.12')
58-
def dtypes(
59-
*,
60-
device: Optional[device] = None,
61-
kind: Optional[Union[str, Tuple[str, ...]]] = None,
62-
) -> DataTypes:
63-
if kind is None:
42+
@requires_api_version('2023.12')
43+
def default_dtypes(
44+
self,
45+
*,
46+
device: Optional[device] = None,
47+
) -> DefaultDataTypes:
6448
return {
65-
"bool": bool,
66-
"int8": int8,
67-
"int16": int16,
68-
"int32": int32,
69-
"int64": int64,
70-
"uint8": uint8,
71-
"uint16": uint16,
72-
"uint32": uint32,
73-
"uint64": uint64,
74-
"float32": float32,
75-
"float64": float64,
76-
"complex64": complex64,
77-
"complex128": complex128,
49+
"real floating": float64,
50+
"complex floating": complex128,
51+
"integral": int64,
52+
"indexing": int64,
7853
}
79-
if kind == "bool":
80-
return {"bool": bool}
81-
if kind == "signed integer":
82-
return {
83-
"int8": int8,
84-
"int16": int16,
85-
"int32": int32,
86-
"int64": int64,
87-
}
88-
if kind == "unsigned integer":
89-
return {
90-
"uint8": uint8,
91-
"uint16": uint16,
92-
"uint32": uint32,
93-
"uint64": uint64,
94-
}
95-
if kind == "integral":
96-
return {
97-
"int8": int8,
98-
"int16": int16,
99-
"int32": int32,
100-
"int64": int64,
101-
"uint8": uint8,
102-
"uint16": uint16,
103-
"uint32": uint32,
104-
"uint64": uint64,
105-
}
106-
if kind == "real floating":
107-
return {
108-
"float32": float32,
109-
"float64": float64,
110-
}
111-
if kind == "complex floating":
112-
return {
113-
"complex64": complex64,
114-
"complex128": complex128,
115-
}
116-
if kind == "numeric":
117-
return {
118-
"int8": int8,
119-
"int16": int16,
120-
"int32": int32,
121-
"int64": int64,
122-
"uint8": uint8,
123-
"uint16": uint16,
124-
"uint32": uint32,
125-
"uint64": uint64,
126-
"float32": float32,
127-
"float64": float64,
128-
"complex64": complex64,
129-
"complex128": complex128,
130-
}
131-
if isinstance(kind, tuple):
132-
res = {}
133-
for k in kind:
134-
res.update(dtypes(kind=k))
135-
return res
136-
raise ValueError(f"unsupported kind: {kind!r}")
13754

138-
@requires_api_version('2023.12')
139-
def devices() -> List[device]:
140-
return list(ALL_DEVICES)
55+
@requires_api_version('2023.12')
56+
def dtypes(
57+
self,
58+
*,
59+
device: Optional[device] = None,
60+
kind: Optional[Union[str, Tuple[str, ...]]] = None,
61+
) -> DataTypes:
62+
if kind is None:
63+
return {
64+
"bool": bool,
65+
"int8": int8,
66+
"int16": int16,
67+
"int32": int32,
68+
"int64": int64,
69+
"uint8": uint8,
70+
"uint16": uint16,
71+
"uint32": uint32,
72+
"uint64": uint64,
73+
"float32": float32,
74+
"float64": float64,
75+
"complex64": complex64,
76+
"complex128": complex128,
77+
}
78+
if kind == "bool":
79+
return {"bool": bool}
80+
if kind == "signed integer":
81+
return {
82+
"int8": int8,
83+
"int16": int16,
84+
"int32": int32,
85+
"int64": int64,
86+
}
87+
if kind == "unsigned integer":
88+
return {
89+
"uint8": uint8,
90+
"uint16": uint16,
91+
"uint32": uint32,
92+
"uint64": uint64,
93+
}
94+
if kind == "integral":
95+
return {
96+
"int8": int8,
97+
"int16": int16,
98+
"int32": int32,
99+
"int64": int64,
100+
"uint8": uint8,
101+
"uint16": uint16,
102+
"uint32": uint32,
103+
"uint64": uint64,
104+
}
105+
if kind == "real floating":
106+
return {
107+
"float32": float32,
108+
"float64": float64,
109+
}
110+
if kind == "complex floating":
111+
return {
112+
"complex64": complex64,
113+
"complex128": complex128,
114+
}
115+
if kind == "numeric":
116+
return {
117+
"int8": int8,
118+
"int16": int16,
119+
"int32": int32,
120+
"int64": int64,
121+
"uint8": uint8,
122+
"uint16": uint16,
123+
"uint32": uint32,
124+
"uint64": uint64,
125+
"float32": float32,
126+
"float64": float64,
127+
"complex64": complex64,
128+
"complex128": complex128,
129+
}
130+
if isinstance(kind, tuple):
131+
res = {}
132+
for k in kind:
133+
res.update(dtypes(kind=k))
134+
return res
135+
raise ValueError(f"unsupported kind: {kind!r}")
141136

142-
__all__ = [
143-
"capabilities",
144-
"default_device",
145-
"default_dtypes",
146-
"devices",
147-
"dtypes",
148-
]
137+
@requires_api_version('2023.12')
138+
def devices(self) -> List[device]:
139+
return list(ALL_DEVICES)

array_api_strict/_typing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from ._array_object import Array, _device
3131
from ._dtypes import _DType
32+
from ._info import __array_namespace_info__
3233

3334
_T_co = TypeVar("_T_co", covariant=True)
3435

@@ -41,7 +42,7 @@ def __len__(self, /) -> int: ...
4142

4243
Dtype = _DType
4344

44-
Info = ModuleType
45+
Info = __array_namespace_info__
4546

4647
if sys.version_info >= (3, 12):
4748
from collections.abc import Buffer as SupportsBufferProtocol

array_api_strict/tests/test_flags.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33

44
from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags,
55
reset_array_api_strict_flags)
6-
from .._info import (capabilities, default_device, default_dtypes, devices,
7-
dtypes)
6+
from .._info import __array_namespace_info__
87
from .._fft import (fft, ifft, fftn, ifftn, rfft, irfft, rfftn, irfftn, hfft,
98
ihfft, fftfreq, rfftfreq, fftshift, ifftshift)
109
from .._linalg import (cholesky, cross, det, diagonal, eigh, eigvalsh, inv,
@@ -260,14 +259,17 @@ def test_fft(func_name):
260259
set_array_api_strict_flags(enabled_extensions=('fft',))
261260
func()
262261

262+
# Test functionality even if the info object is already created
263+
_info = xp.__array_namespace_info__()
264+
263265
api_version_2023_12_examples = {
264266
'__array_namespace_info__': lambda: xp.__array_namespace_info__(),
265267
# Test these functions directly to ensure they are properly decorated
266-
'capabilities': capabilities,
267-
'default_device': default_device,
268-
'default_dtypes': default_dtypes,
269-
'devices': devices,
270-
'dtypes': dtypes,
268+
'capabilities': _info.capabilities,
269+
'default_device': _info.default_device,
270+
'default_dtypes': _info.default_dtypes,
271+
'devices': _info.devices,
272+
'dtypes': _info.dtypes,
271273
'clip': lambda: xp.clip(xp.asarray([1, 2, 3]), 1, 2),
272274
'copysign': lambda: xp.copysign(xp.asarray([1., 2., 3.]), xp.asarray([-1., -1., -1.])),
273275
'cumulative_sum': lambda: xp.cumulative_sum(xp.asarray([1, 2, 3])),

0 commit comments

Comments
 (0)