Skip to content

Commit cddc9ef

Browse files
authored
ENH: Review exported symbols; redesign test_all (#315)
Review and discussion at #315
1 parent 2b559e6 commit cddc9ef

23 files changed

+435
-197
lines changed

array_api_compat/_internal.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Internal helpers
33
"""
44

5+
import importlib
56
from collections.abc import Callable
67
from functools import wraps
78
from inspect import signature
@@ -52,8 +53,25 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
5253
return inner
5354

5455

55-
__all__ = ["get_xp"]
56+
def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]:
57+
"""Import everything from module, updating globals().
58+
Returns __all__.
59+
"""
60+
mod = importlib.import_module(mod_name)
61+
# Neither of these two methods is sufficient by itself,
62+
# depending on various idiosyncrasies of the libraries we're wrapping.
63+
objs = {}
64+
exec(f"from {mod.__name__} import *", objs)
65+
66+
for n in dir(mod):
67+
if not n.startswith("_") and hasattr(mod, n):
68+
objs[n] = getattr(mod, n)
69+
70+
globals_.update(objs)
71+
return list(objs)
72+
5673

74+
__all__ = ["get_xp", "clone_module"]
5775

5876
def __dir__() -> list[str]:
5977
return __all__

array_api_compat/common/_aliases.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -721,8 +721,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
721721
"finfo",
722722
"iinfo",
723723
]
724-
_all_ignore = ["inspect", "array_namespace", "NamedTuple"]
725-
726724

727725
def __dir__() -> list[str]:
728726
return __all__

array_api_compat/common/_helpers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,5 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
10621062
"to_device",
10631063
]
10641064

1065-
_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings']
1066-
10671065
def __dir__() -> list[str]:
10681066
return __all__

array_api_compat/common/_linalg.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,6 @@ def trace(
225225
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
226226
'trace']
227227

228-
_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']
229-
230228

231229
def __dir__() -> list[str]:
232230
return __all__

array_api_compat/cupy/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
1+
from typing import Final
12
from cupy import * # noqa: F403
23

34
# from cupy import * doesn't overwrite these builtin names
45
from cupy import abs, max, min, round # noqa: F401
56

67
# These imports may overwrite names from the import * above.
78
from ._aliases import * # noqa: F403
9+
from ._info import __array_namespace_info__ # noqa: F401
810

911
# See the comment in the numpy __init__.py
1012
__import__(__package__ + '.linalg')
1113
__import__(__package__ + '.fft')
1214

13-
__array_api_version__ = '2024.12'
15+
__array_api_version__: Final = '2024.12'
16+
17+
__all__ = sorted(
18+
{name for name in globals() if not name.startswith("__")}
19+
- {"Final", "_aliases", "_info", "_typing"}
20+
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
21+
)
22+
23+
def __dir__() -> list[str]:
24+
return __all__

array_api_compat/cupy/_aliases.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from ..common import _aliases, _helpers
88
from ..common._typing import NestedSequence, SupportsBufferProtocol
99
from .._internal import get_xp
10-
from ._info import __array_namespace_info__
1110
from ._typing import Array, Device, DType
1211

1312
bool = cp.bool_
@@ -141,7 +140,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
141140
else:
142141
unstack = get_xp(cp)(_aliases.unstack)
143142

144-
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
143+
__all__ = _aliases.__all__ + ['asarray', 'astype',
145144
'acos', 'acosh', 'asin', 'asinh', 'atan',
146145
'atan2', 'atanh', 'bitwise_left_shift',
147146
'bitwise_invert', 'bitwise_right_shift',

array_api_compat/cupy/_typing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
__all__ = ["Array", "DType", "Device"]
4-
_all_ignore = ["cp"]
54

65
from typing import TYPE_CHECKING
76

array_api_compat/cupy/fft.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
__all__ = fft_all + _fft.__all__
3333

34-
del get_xp
35-
del cp
36-
del fft_all
37-
del _fft
34+
def __dir__() -> list[str]:
35+
return __all__
36+

array_api_compat/cupy/linalg.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,5 @@
4343

4444
__all__ = linalg_all + _linalg.__all__
4545

46-
del get_xp
47-
del cp
48-
del linalg_all
49-
del _linalg
46+
def __dir__() -> list[str]:
47+
return __all__
Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
11
from typing import Final
22

3-
from dask.array import * # noqa: F403
3+
from ..._internal import clone_module
4+
5+
__all__ = clone_module("dask.array", globals())
46

57
# These imports may overwrite names from the import * above.
8+
from . import _aliases
69
from ._aliases import * # type: ignore[assignment] # noqa: F403
10+
from ._info import __array_namespace_info__ # noqa: F401
711

812
__array_api_version__: Final = "2024.12"
13+
del Final
914

1015
# See the comment in the numpy __init__.py
1116
__import__(__package__ + '.linalg')
1217
__import__(__package__ + '.fft')
18+
19+
__all__ = sorted(
20+
set(__all__)
21+
| set(_aliases.__all__)
22+
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
23+
)
24+
25+
def __dir__() -> list[str]:
26+
return __all__

array_api_compat/dask/array/_aliases.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
NestedSequence,
4242
SupportsBufferProtocol,
4343
)
44-
from ._info import __array_namespace_info__
4544

4645
isdtype = get_xp(np)(_aliases.isdtype)
4746
unstack = get_xp(da)(_aliases.unstack)
@@ -355,7 +354,6 @@ def count_nonzero(
355354

356355

357356
__all__ = [
358-
"__array_namespace_info__",
359357
"count_nonzero",
360358
"bool",
361359
"int8", "int16", "int32", "int64",
@@ -369,8 +367,6 @@ def count_nonzero(
369367
"bitwise_left_shift", "bitwise_right_shift", "bitwise_invert",
370368
] # fmt: skip
371369
__all__ += _aliases.__all__
372-
_all_ignore = ["array_namespace", "get_xp", "da", "np"]
373-
374370

375371
def __dir__() -> list[str]:
376372
return __all__

array_api_compat/dask/array/fft.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
1-
from dask.array.fft import * # noqa: F403
2-
# dask.array.fft doesn't have __all__. If it is added, replace this with
3-
#
4-
# from dask.array.fft import __all__ as linalg_all
5-
_n: dict[str, object] = {}
6-
exec('from dask.array.fft import *', _n)
7-
for k in ("__builtins__", "Sequence", "annotations", "warnings"):
8-
_n.pop(k, None)
9-
fft_all = list(_n)
10-
del _n, k
1+
from ..._internal import clone_module
2+
3+
__all__ = clone_module("dask.array.fft", globals())
114

125
from ...common import _fft
136
from ..._internal import get_xp
@@ -17,5 +10,7 @@
1710
fftfreq = get_xp(da)(_fft.fftfreq)
1811
rfftfreq = get_xp(da)(_fft.rfftfreq)
1912

20-
__all__ = fft_all + ["fftfreq", "rfftfreq"]
21-
_all_ignore = ["da", "fft_all", "get_xp", "warnings"]
13+
__all__ += ["fftfreq", "rfftfreq"]
14+
15+
def __dir__() -> list[str]:
16+
return __all__

array_api_compat/dask/array/linalg.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,17 @@
44

55
import dask.array as da
66

7-
# Exports
8-
from dask.array.linalg import * # noqa: F403
9-
from dask.array import outer
107
# The `matmul` and `tensordot` functions are in both the main and linalg namespaces
11-
from dask.array import matmul, tensordot
12-
8+
from dask.array import matmul, outer, tensordot
139

14-
from ..._internal import get_xp
10+
# Exports
11+
from ..._internal import clone_module, get_xp
1512
from ...common import _linalg
1613
from ...common._typing import Array
17-
from ._aliases import matrix_transpose, vecdot
1814

19-
# dask.array.linalg doesn't have __all__. If it is added, replace this with
20-
#
21-
# from dask.array.linalg import __all__ as linalg_all
22-
_n: dict[str, object] = {}
23-
exec('from dask.array.linalg import *', _n)
24-
for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'):
25-
_n.pop(k, None)
26-
linalg_all = list(_n)
27-
del _n, k
15+
__all__ = clone_module("dask.array.linalg", globals())
16+
17+
from ._aliases import matrix_transpose, vecdot
2818

2919
EighResult = _linalg.EighResult
3020
QRResult = _linalg.QRResult
@@ -64,10 +54,11 @@ def svdvals(x: Array) -> Array:
6454
vector_norm = get_xp(da)(_linalg.vector_norm)
6555
diagonal = get_xp(da)(_linalg.diagonal)
6656

67-
__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot",
68-
"matrix_transpose", "vecdot", "EighResult",
69-
"QRResult", "SlogdetResult", "SVDResult", "qr",
70-
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
71-
"vector_norm", "diagonal"]
57+
__all__ += ["trace", "outer", "matmul", "tensordot",
58+
"matrix_transpose", "vecdot", "EighResult",
59+
"QRResult", "SlogdetResult", "SVDResult", "qr",
60+
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
61+
"vector_norm", "diagonal"]
7262

73-
_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings']
63+
def __dir__() -> list[str]:
64+
return __all__

array_api_compat/numpy/__init__.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
# ruff: noqa: PLC0414
22
from typing import Final
33

4-
from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary]
4+
from .._internal import clone_module
55

6-
# from numpy import * doesn't overwrite these builtin names
7-
from numpy import abs as abs
8-
from numpy import max as max
9-
from numpy import min as min
10-
from numpy import round as round
6+
# This needs to be loaded explicitly before cloning
7+
import numpy.typing # noqa: F401
8+
9+
__all__ = clone_module("numpy", globals())
1110

1211
# These imports may overwrite names from the import * above.
12+
from . import _aliases
1313
from ._aliases import * # type: ignore[assignment,no-redef] # noqa: F403
14+
from ._info import __array_namespace_info__ # noqa: F401
1415

1516
# Don't know why, but we have to do an absolute import to import linalg. If we
1617
# instead do
@@ -26,3 +27,12 @@
2627
from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401
2728

2829
__array_api_version__: Final = "2024.12"
30+
31+
__all__ = sorted(
32+
set(__all__)
33+
| set(_aliases.__all__)
34+
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
35+
)
36+
37+
def __dir__() -> list[str]:
38+
return __all__

array_api_compat/numpy/_aliases.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from .._internal import get_xp
1010
from ..common import _aliases, _helpers
1111
from ..common._typing import NestedSequence, SupportsBufferProtocol
12-
from ._info import __array_namespace_info__
1312
from ._typing import Array, Device, DType
1413

1514
bool = np.bool_
@@ -147,8 +146,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
147146
else:
148147
unstack = get_xp(np)(_aliases.unstack)
149148

150-
__all__ = [
151-
"__array_namespace_info__",
149+
__all__ = _aliases.__all__ + [
152150
"asarray",
153151
"astype",
154152
"acos",
@@ -167,8 +165,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
167165
"pow",
168166
"take_along_axis"
169167
]
170-
__all__ += _aliases.__all__
171-
_all_ignore = ["np", "get_xp"]
172168

173169

174170
def __dir__() -> list[str]:

array_api_compat/numpy/_typing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
Array: TypeAlias = np.ndarray
2424

2525
__all__ = ["Array", "DType", "Device"]
26-
_all_ignore = ["np"]
2726

2827

2928
def __dir__() -> list[str]:

array_api_compat/numpy/fft.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
2-
from numpy.fft import __all__ as fft_all
3-
from numpy.fft import fft2, ifft2, irfft2, rfft2
2+
3+
from .._internal import clone_module
4+
5+
__all__ = clone_module("numpy.fft", globals())
46

57
from .._internal import get_xp
68
from ..common import _fft
@@ -21,15 +23,8 @@
2123
ifftshift = get_xp(np)(_fft.ifftshift)
2224

2325

24-
__all__ = ["rfft2", "irfft2", "fft2", "ifft2"]
25-
__all__ += _fft.__all__
26-
26+
__all__ = sorted(set(__all__) | set(_fft.__all__))
2727

2828
def __dir__() -> list[str]:
2929
return __all__
3030

31-
32-
del get_xp
33-
del np
34-
del fft_all
35-
del _fft

0 commit comments

Comments
 (0)