Skip to content

Commit d9f43f4

Browse files
committed
Add 'max dimensions' to capabilities() for 2024.12
1 parent 1d111b3 commit d9f43f4

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

array_api_strict/_info.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import TYPE_CHECKING
44

5+
import numpy as np
6+
57
if TYPE_CHECKING:
68
from typing import Optional, Union, Tuple, List
79
from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info
@@ -18,9 +20,23 @@ def __array_namespace_info__() -> Info:
1820
@requires_api_version('2023.12')
1921
def capabilities() -> Capabilities:
2022
flags = get_array_api_strict_flags()
21-
return {"boolean indexing": flags['boolean_indexing'],
23+
res = {"boolean indexing": flags['boolean_indexing'],
2224
"data-dependent shapes": flags['data_dependent_shapes'],
2325
}
26+
if flags['api_version'] >= '2024.12':
27+
# maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will
28+
# drop support for NumPy 1 but for now, just compute the number
29+
# directly
30+
for i in range(1, 100):
31+
try:
32+
np.zeros((1,)*i)
33+
except ValueError:
34+
maxdims = i - 1
35+
break
36+
else:
37+
raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)")
38+
res['max dimensions'] = maxdims
39+
return res
2440

2541
@requires_api_version('2023.12')
2642
def default_device() -> device:

0 commit comments

Comments
 (0)