|
6 | 6 |
|
7 | 7 | if TYPE_CHECKING:
|
8 | 8 | from typing import Optional, Union, Tuple, List
|
9 |
| - from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info |
| 9 | + from ._typing import device, DefaultDataTypes, DataTypes, Capabilities |
10 | 10 |
|
11 | 11 | from ._array_object import ALL_DEVICES, CPU_DEVICE
|
12 | 12 | from ._flags import get_array_api_strict_flags, requires_api_version
|
13 | 13 | from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128
|
14 | 14 |
|
15 | 15 | @requires_api_version('2023.12')
|
16 |
| -def __array_namespace_info__() -> Info: |
17 |
| - import array_api_strict._info |
18 |
| - return array_api_strict._info |
19 |
| - |
20 |
| -@requires_api_version('2023.12') |
21 |
| -def capabilities() -> Capabilities: |
22 |
| - flags = get_array_api_strict_flags() |
23 |
| - res = {"boolean indexing": flags['boolean_indexing'], |
24 |
| - "data-dependent shapes": flags['data_dependent_shapes'], |
25 |
| - } |
26 |
| - if flags['api_version'] >= '2024.12': |
27 |
| - # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will |
28 |
| - # drop support for NumPy 1 but for now, just compute the number |
29 |
| - # directly |
30 |
| - for i in range(1, 100): |
31 |
| - try: |
32 |
| - np.zeros((1,)*i) |
33 |
| - except ValueError: |
34 |
| - maxdims = i - 1 |
35 |
| - break |
36 |
| - else: |
37 |
| - raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)") |
38 |
| - res['max dimensions'] = maxdims |
39 |
| - return res |
40 |
| - |
41 |
| -@requires_api_version('2023.12') |
42 |
| -def default_device() -> device: |
43 |
| - return CPU_DEVICE |
| 16 | +class __array_namespace_info__: |
| 17 | + @requires_api_version('2023.12') |
| 18 | + def capabilities(self) -> Capabilities: |
| 19 | + flags = get_array_api_strict_flags() |
| 20 | + res = {"boolean indexing": flags['boolean_indexing'], |
| 21 | + "data-dependent shapes": flags['data_dependent_shapes'], |
| 22 | + } |
| 23 | + if flags['api_version'] >= '2024.12': |
| 24 | + # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will |
| 25 | + # drop support for NumPy 1 but for now, just compute the number |
| 26 | + # directly |
| 27 | + for i in range(1, 100): |
| 28 | + try: |
| 29 | + np.zeros((1,)*i) |
| 30 | + except ValueError: |
| 31 | + maxdims = i - 1 |
| 32 | + break |
| 33 | + else: |
| 34 | + raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)") |
| 35 | + res['max dimensions'] = maxdims |
| 36 | + return res |
44 | 37 |
|
45 |
| -@requires_api_version('2023.12') |
46 |
| -def default_dtypes( |
47 |
| - *, |
48 |
| - device: Optional[device] = None, |
49 |
| -) -> DefaultDataTypes: |
50 |
| - return { |
51 |
| - "real floating": float64, |
52 |
| - "complex floating": complex128, |
53 |
| - "integral": int64, |
54 |
| - "indexing": int64, |
55 |
| - } |
| 38 | + @requires_api_version('2023.12') |
| 39 | + def default_device(self) -> device: |
| 40 | + return CPU_DEVICE |
56 | 41 |
|
57 |
| -@requires_api_version('2023.12') |
58 |
| -def dtypes( |
59 |
| - *, |
60 |
| - device: Optional[device] = None, |
61 |
| - kind: Optional[Union[str, Tuple[str, ...]]] = None, |
62 |
| -) -> DataTypes: |
63 |
| - if kind is None: |
| 42 | + @requires_api_version('2023.12') |
| 43 | + def default_dtypes( |
| 44 | + self, |
| 45 | + *, |
| 46 | + device: Optional[device] = None, |
| 47 | + ) -> DefaultDataTypes: |
64 | 48 | return {
|
65 |
| - "bool": bool, |
66 |
| - "int8": int8, |
67 |
| - "int16": int16, |
68 |
| - "int32": int32, |
69 |
| - "int64": int64, |
70 |
| - "uint8": uint8, |
71 |
| - "uint16": uint16, |
72 |
| - "uint32": uint32, |
73 |
| - "uint64": uint64, |
74 |
| - "float32": float32, |
75 |
| - "float64": float64, |
76 |
| - "complex64": complex64, |
77 |
| - "complex128": complex128, |
| 49 | + "real floating": float64, |
| 50 | + "complex floating": complex128, |
| 51 | + "integral": int64, |
| 52 | + "indexing": int64, |
78 | 53 | }
|
79 |
| - if kind == "bool": |
80 |
| - return {"bool": bool} |
81 |
| - if kind == "signed integer": |
82 |
| - return { |
83 |
| - "int8": int8, |
84 |
| - "int16": int16, |
85 |
| - "int32": int32, |
86 |
| - "int64": int64, |
87 |
| - } |
88 |
| - if kind == "unsigned integer": |
89 |
| - return { |
90 |
| - "uint8": uint8, |
91 |
| - "uint16": uint16, |
92 |
| - "uint32": uint32, |
93 |
| - "uint64": uint64, |
94 |
| - } |
95 |
| - if kind == "integral": |
96 |
| - return { |
97 |
| - "int8": int8, |
98 |
| - "int16": int16, |
99 |
| - "int32": int32, |
100 |
| - "int64": int64, |
101 |
| - "uint8": uint8, |
102 |
| - "uint16": uint16, |
103 |
| - "uint32": uint32, |
104 |
| - "uint64": uint64, |
105 |
| - } |
106 |
| - if kind == "real floating": |
107 |
| - return { |
108 |
| - "float32": float32, |
109 |
| - "float64": float64, |
110 |
| - } |
111 |
| - if kind == "complex floating": |
112 |
| - return { |
113 |
| - "complex64": complex64, |
114 |
| - "complex128": complex128, |
115 |
| - } |
116 |
| - if kind == "numeric": |
117 |
| - return { |
118 |
| - "int8": int8, |
119 |
| - "int16": int16, |
120 |
| - "int32": int32, |
121 |
| - "int64": int64, |
122 |
| - "uint8": uint8, |
123 |
| - "uint16": uint16, |
124 |
| - "uint32": uint32, |
125 |
| - "uint64": uint64, |
126 |
| - "float32": float32, |
127 |
| - "float64": float64, |
128 |
| - "complex64": complex64, |
129 |
| - "complex128": complex128, |
130 |
| - } |
131 |
| - if isinstance(kind, tuple): |
132 |
| - res = {} |
133 |
| - for k in kind: |
134 |
| - res.update(dtypes(kind=k)) |
135 |
| - return res |
136 |
| - raise ValueError(f"unsupported kind: {kind!r}") |
137 | 54 |
|
138 |
| -@requires_api_version('2023.12') |
139 |
| -def devices() -> List[device]: |
140 |
| - return list(ALL_DEVICES) |
| 55 | + @requires_api_version('2023.12') |
| 56 | + def dtypes( |
| 57 | + self, |
| 58 | + *, |
| 59 | + device: Optional[device] = None, |
| 60 | + kind: Optional[Union[str, Tuple[str, ...]]] = None, |
| 61 | + ) -> DataTypes: |
| 62 | + if kind is None: |
| 63 | + return { |
| 64 | + "bool": bool, |
| 65 | + "int8": int8, |
| 66 | + "int16": int16, |
| 67 | + "int32": int32, |
| 68 | + "int64": int64, |
| 69 | + "uint8": uint8, |
| 70 | + "uint16": uint16, |
| 71 | + "uint32": uint32, |
| 72 | + "uint64": uint64, |
| 73 | + "float32": float32, |
| 74 | + "float64": float64, |
| 75 | + "complex64": complex64, |
| 76 | + "complex128": complex128, |
| 77 | + } |
| 78 | + if kind == "bool": |
| 79 | + return {"bool": bool} |
| 80 | + if kind == "signed integer": |
| 81 | + return { |
| 82 | + "int8": int8, |
| 83 | + "int16": int16, |
| 84 | + "int32": int32, |
| 85 | + "int64": int64, |
| 86 | + } |
| 87 | + if kind == "unsigned integer": |
| 88 | + return { |
| 89 | + "uint8": uint8, |
| 90 | + "uint16": uint16, |
| 91 | + "uint32": uint32, |
| 92 | + "uint64": uint64, |
| 93 | + } |
| 94 | + if kind == "integral": |
| 95 | + return { |
| 96 | + "int8": int8, |
| 97 | + "int16": int16, |
| 98 | + "int32": int32, |
| 99 | + "int64": int64, |
| 100 | + "uint8": uint8, |
| 101 | + "uint16": uint16, |
| 102 | + "uint32": uint32, |
| 103 | + "uint64": uint64, |
| 104 | + } |
| 105 | + if kind == "real floating": |
| 106 | + return { |
| 107 | + "float32": float32, |
| 108 | + "float64": float64, |
| 109 | + } |
| 110 | + if kind == "complex floating": |
| 111 | + return { |
| 112 | + "complex64": complex64, |
| 113 | + "complex128": complex128, |
| 114 | + } |
| 115 | + if kind == "numeric": |
| 116 | + return { |
| 117 | + "int8": int8, |
| 118 | + "int16": int16, |
| 119 | + "int32": int32, |
| 120 | + "int64": int64, |
| 121 | + "uint8": uint8, |
| 122 | + "uint16": uint16, |
| 123 | + "uint32": uint32, |
| 124 | + "uint64": uint64, |
| 125 | + "float32": float32, |
| 126 | + "float64": float64, |
| 127 | + "complex64": complex64, |
| 128 | + "complex128": complex128, |
| 129 | + } |
| 130 | + if isinstance(kind, tuple): |
| 131 | + res = {} |
| 132 | + for k in kind: |
| 133 | + res.update(dtypes(kind=k)) |
| 134 | + return res |
| 135 | + raise ValueError(f"unsupported kind: {kind!r}") |
141 | 136 |
|
142 |
| -__all__ = [ |
143 |
| - "capabilities", |
144 |
| - "default_device", |
145 |
| - "default_dtypes", |
146 |
| - "devices", |
147 |
| - "dtypes", |
148 |
| -] |
| 137 | + @requires_api_version('2023.12') |
| 138 | + def devices(self) -> List[device]: |
| 139 | + return list(ALL_DEVICES) |
0 commit comments