Skip to content

Commit 9088c68

Browse files
committed
v3
1 parent c7024c4 commit 9088c68

File tree

7 files changed

+49
-53
lines changed

7 files changed

+49
-53
lines changed

array_api_compat/cupy/fft.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
__all__ = fft_all + _fft.__all__
3232

33-
del get_xp
34-
del cp
35-
del fft_all
36-
del _fft
33+
def __dir__() -> list[str]:
34+
return __all__
35+

array_api_compat/cupy/linalg.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,5 @@
4343

4444
__all__ = linalg_all + _linalg.__all__
4545

46-
del get_xp
47-
del cp
48-
del linalg_all
49-
del _linalg
46+
def __dir__() -> list[str]:
47+
return __all__
Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,25 @@
11
from typing import Final
22

3-
import dask.array as da
43
from dask.array import * # noqa: F403
54

5+
# The above is missing a wealth of stuff
6+
import dask.array as da
7+
__all__ = [n for n in dir(da) if not n.startswith("_")]
8+
globals().update({n: getattr(da, n) for n in __all__})
9+
del da
10+
611
# These imports may overwrite names from the import * above.
712
from . import _aliases
813
from ._aliases import * # noqa: F403
914
from ._info import __array_namespace_info__ # noqa: F401
1015

1116
__array_api_version__: Final = "2024.12"
17+
del Final
1218

1319
# See the comment in the numpy __init__.py
1420
__import__(__package__ + '.linalg')
1521
__import__(__package__ + '.fft')
1622

17-
def _make_all(base):
18-
return sorted(
19-
set(base)
20-
| set(_aliases.__all__)
21-
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
22-
)
23-
24-
__all__ = _make_all(da.__all__)
25-
26-
def __dir__() -> list[str]:
27-
return _make_all(dir(da))
23+
__all__ += _aliases.__all__
24+
__all__ += ["__array_api_version__", "__array_namespace_info__", "linalg", "fft"]
25+
__all__ = sorted(set(__all__))

array_api_compat/numpy/fft.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22
from numpy.fft import * # noqa: F403
33

4+
__all__ = [n for n in dir(np.fft) if not n.startswith("_")]
5+
globals().update({n: getattr(np.fft, n) for n in __all__})
6+
47
from .._internal import get_xp
58
from ..common import _fft
69

@@ -20,7 +23,8 @@
2023
ifftshift = get_xp(np)(_fft.ifftshift)
2124

2225

23-
__all__ = sorted(set(np.fft.__all__) | set(_fft.__all__))
26+
__all__ = sorted(set(__all__) | set(_fft.__all__))
2427

2528
def __dir__() -> list[str]:
26-
return sorted(set(dir(np.fft)) | set(_fft.__all__))
29+
return __all__
30+

array_api_compat/numpy/linalg.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
1717
from ._typing import Array
1818

19+
__all__ = [n for n in dir(np.linalg) if not n.startswith("_")]
20+
globals().update({n: getattr(np.linalg, n) for n in __all__})
21+
1922
cross = get_xp(np)(_linalg.cross)
2023
outer = get_xp(np)(_linalg.outer)
2124
EighResult = _linalg.EighResult
@@ -122,7 +125,7 @@ def solve(x1: Array, x2: Array, /) -> Array:
122125
"tensorsolve",
123126
"vector_norm",
124127
]
125-
__all__ = sorted(set(np.linalg.__all__) | set(_linalg.__all__) | set(_all))
128+
__all__ = sorted(set(__all__) | set(_linalg.__all__) | set(_all))
126129

127130
def __dir__() -> list[str]:
128-
return sorted(set(dir(np.linalg)) | set(_linalg.__all__) | set(_all))
131+
return __all__

array_api_compat/torch/fft.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88

99
from ._typing import Array
1010

11+
# The above is missing a wealth of stuff
12+
__all__ = [n for n in dir(torch.fft) if not n.startswith("_")]
13+
globals().update({n: getattr(torch.fft, n) for n in __all__})
14+
1115
# Several torch fft functions do not map axes to dim
1216

1317
def fftn(
@@ -73,15 +77,10 @@ def ifftshift(
7377
return torch.fft.ifftshift(x, dim=axes, **kwargs)
7478

7579

76-
_all = {
77-
"fftn",
78-
"ifftn",
79-
"rfftn",
80-
"irfftn",
81-
"fftshift",
82-
"ifftshift",
83-
}
84-
__all__ = sorted(set(torch.fft.__all__) |_all)
80+
__all__ = sorted(
81+
set(__all__)
82+
| {"fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"}
83+
)
8584

8685
def __dir__() -> list[str]:
87-
return sorted(set(dir(torch.fft)) | _all)
86+
return __all__

tests/test_all.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"nan",
2222
"newaxis",
2323
"pi",
24-
# Creation functions
24+
# Creation Functions
2525
"arange",
2626
"asarray",
2727
"empty",
@@ -157,7 +157,7 @@
157157
"nonzero",
158158
"searchsorted",
159159
"where",
160-
# Set functions
160+
# Set Functions
161161
"unique_all",
162162
"unique_counts",
163163
"unique_inverse",
@@ -239,15 +239,13 @@
239239

240240

241241
def all_names(mod):
242-
"""Return all names imported by `from mod import *`.
243-
This is typically `__all__` but, if not defined, Python
244-
implements automated fallbacks.
245-
"""
246-
# Note: this method also makes the test trip if a name is
247-
# in __all__ but doesn't actually appear in the module.
242+
"""Return all names available in a module."""
248243
objs = {}
249244
exec(f"from {mod.__name__} import *", objs)
250-
return list(objs)
245+
for n in dir(mod):
246+
if not n.startswith("_") and hasattr(mod, n):
247+
objs[n] = getattr(mod, n)
248+
return set(objs)
251249

252250

253251
def get_mod(library, module, *, compat):
@@ -257,49 +255,46 @@ def get_mod(library, module, *, compat):
257255
return getattr(xp, module) if module else xp
258256

259257

260-
@pytest.mark.parametrize("func", [all_names, dir])
261258
@pytest.mark.parametrize("module", list(NAMES))
262259
@pytest.mark.parametrize("library", wrapped_libraries)
263-
def test_array_api_names(library, module, func):
264-
"""Test that __all__ and dir() aren't missing any exports
260+
def test_array_api_names(library, module):
261+
"""Test that __all__ isn't missing any exports
265262
dictated by the Standard.
266263
"""
267264
mod = get_mod(library, module, compat=True)
268-
missing = set(NAMES[module]) - set(func(mod))
265+
missing = set(NAMES[module]) - all_names(mod)
269266
xfail = set(XFAILS.get((library, module), []))
270267
xpass = xfail - missing
271268
fails = missing - xfail
272269
assert not xpass, f"Names in XFAILS are defined: {xpass}"
273270
assert not fails, f"Missing exports: {fails}"
274271

275272

276-
@pytest.mark.parametrize("func", [all_names, dir])
277273
@pytest.mark.parametrize("module", list(NAMES))
278274
@pytest.mark.parametrize("library", wrapped_libraries)
279-
def test_compat_doesnt_hide_names(library, module, func):
275+
def test_compat_doesnt_hide_names(library, module):
280276
"""The base namespace can have more names than the ones explicitly exported
281277
by array-api-compat. Test that we're not suppressing them.
282278
"""
283279
bare_mod = get_mod(library, module, compat=False)
284280
compat_mod = get_mod(library, module, compat=True)
285281

286-
missing = set(func(bare_mod)) - set(func(compat_mod))
282+
missing = all_names(bare_mod) - all_names(compat_mod)
287283
missing = {name for name in missing if not name.startswith("_")}
288284
assert not missing, f"Non-Array API names have been hidden: {missing}"
289285

290286

291-
@pytest.mark.parametrize("func", [all_names, dir])
292287
@pytest.mark.parametrize("module", list(NAMES))
293288
@pytest.mark.parametrize("library", wrapped_libraries)
294-
def test_compat_doesnt_add_names(library, module, func):
289+
def test_compat_doesnt_add_names(library, module):
295290
"""Test that array-api-compat isn't adding names to the namespace
296291
besides those defined by the Array API Standard.
297292
"""
298293
bare_mod = get_mod(library, module, compat=False)
299294
compat_mod = get_mod(library, module, compat=True)
300295

301296
aapi_names = set(NAMES[module])
302-
spurious = set(func(compat_mod)) - set(func(bare_mod)) - aapi_names - {"__all__"}
297+
spurious = all_names(compat_mod) - all_names(bare_mod) - aapi_names
303298
# Quietly ignore *Result dataclasses
304299
spurious = {name for name in spurious if not name.endswith("Result")}
305300
assert not spurious, (

0 commit comments

Comments
 (0)