Skip to content

Commit 6737695

Browse files
committed
Properly include __array_namespace_info__ in the stubs
This doesn't yet add signature tests for the info namespace functions themselves.
1 parent b7065de commit 6737695

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

array_api_tests/stubs.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
category_to_funcs: Dict[str, List[FunctionType]] = {}
4646
for name, mod in name_to_mod.items():
47-
if name.endswith("_functions") or name == "info": # info functions file just named info.py
47+
if name.endswith("_functions"):
4848
category = name.replace("_functions", "")
4949
objects = [getattr(mod, name) for name in mod.__all__]
5050
assert all(isinstance(o, FunctionType) for o in objects) # sanity check
@@ -55,6 +55,23 @@
5555
all_funcs.extend(funcs)
5656
name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}
5757

58+
info_funcs = []
59+
if api_version >= "2023.12":
60+
# The info functions in the stubs are in info.py, but this is not a name
61+
# in the standard.
62+
info_mod = name_to_mod["info"]
63+
64+
# Note that __array_namespace_info__ is in info.__all__ but it is in the
65+
# top-level namespace, not the info namespace.
66+
info_funcs = [getattr(info_mod, name) for name in info_mod.__all__
67+
if name != '__array_namespace_info__']
68+
assert all(isinstance(f, FunctionType) for f in info_funcs)
69+
name_to_func.update({f.__name__: f for f in info_funcs})
70+
71+
all_funcs.append(info_mod.__array_namespace_info__)
72+
name_to_func['__array_namespace_info__'] = info_mod.__array_namespace_info__
73+
category_to_funcs['info'] = [info_mod.__array_namespace_info__]
74+
5875
EXTENSIONS: List[str] = ["linalg"]
5976
if api_version >= "2022.12":
6077
EXTENSIONS.append("fft")

0 commit comments

Comments
 (0)