Skip to content

Commit 86dc4a0

Browse files
committed
v4
1 parent 9088c68 commit 86dc4a0

File tree

10 files changed

+59
-87
lines changed

10 files changed

+59
-87
lines changed

array_api_compat/_internal.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Internal helpers
33
"""
44

5+
import importlib
56
from collections.abc import Callable
67
from functools import wraps
78
from inspect import signature
@@ -52,8 +53,20 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
5253
return inner
5354

5455

55-
__all__ = ["get_xp"]
56+
def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]:
57+
"""Import everything from module, updating globals().
58+
Returns __all__.
59+
"""
60+
mod = importlib.import_module(mod_name)
61+
all_ = []
62+
for n in dir(mod):
63+
if not n.startswith("_") and hasattr(mod, n):
64+
all_.append(n)
65+
globals_[n] = getattr(mod, n)
66+
return all_
67+
5668

69+
__all__ = ["get_xp", "clone_module"]
5770

5871
def __dir__() -> list[str]:
5972
return __all__
Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
from typing import Final
22

3-
from dask.array import * # noqa: F403
3+
from ..._internal import clone_module
44

5-
# The above is missing a wealth of stuff
6-
import dask.array as da
7-
__all__ = [n for n in dir(da) if not n.startswith("_")]
8-
globals().update({n: getattr(da, n) for n in __all__})
9-
del da
5+
__all__ = clone_module("dask.array", globals())
106

117
# These imports may overwrite names from the import * above.
128
from . import _aliases
@@ -20,6 +16,11 @@
2016
__import__(__package__ + '.linalg')
2117
__import__(__package__ + '.fft')
2218

23-
__all__ += _aliases.__all__
24-
__all__ += ["__array_api_version__", "__array_namespace_info__", "linalg", "fft"]
25-
__all__ = sorted(set(__all__))
19+
__all__ = sorted(
20+
set(__all__)
21+
| set(_aliases.__all__)
22+
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
23+
)
24+
25+
def __dir__() -> list[str]:
26+
return __all__

array_api_compat/dask/array/fft.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
from dask.array.fft import * # noqa: F403
2-
# dask.array.fft doesn't have __all__. If it is added, replace this with
3-
# from dask.array.fft import __all__ as fft_all
4-
_n: dict[str, object] = {}
5-
exec('from dask.array.fft import *', _n)
6-
fft_all = list(_n)
7-
del _n
1+
from ..._internal import clone_module
2+
3+
__all__ = clone_module("dask.array.fft", globals())
84

95
from ...common import _fft
106
from ..._internal import get_xp
@@ -14,7 +10,7 @@
1410
fftfreq = get_xp(da)(_fft.fftfreq)
1511
rfftfreq = get_xp(da)(_fft.rfftfreq)
1612

17-
__all__ = fft_all + ["fftfreq", "rfftfreq"]
13+
__all__ += ["fftfreq", "rfftfreq"]
1814

1915
def __dir__() -> list[str]:
2016
return __all__

array_api_compat/dask/array/linalg.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,13 @@
88
from dask.array import matmul, outer, tensordot
99

1010
# Exports
11-
from dask.array.linalg import * # noqa: F403
12-
13-
from ..._internal import get_xp
11+
from ..._internal import clone_module, get_xp
1412
from ...common import _linalg
1513
from ...common._typing import Array as _Array
16-
from ._aliases import matrix_transpose, vecdot
1714

18-
# dask.array.linalg doesn't have __all__. If it is added, replace this with
19-
#
20-
# from dask.array.linalg import __all__ as linalg_all
21-
_n = {}
22-
exec('from dask.array.linalg import *', _n)
23-
linalg_all = list(_n)
24-
del _n
15+
__all__ = clone_module("dask.array.linalg", globals())
16+
17+
from ._aliases import matrix_transpose, vecdot
2518

2619
EighResult = _linalg.EighResult
2720
QRResult = _linalg.QRResult
@@ -61,11 +54,11 @@ def svdvals(x: _Array) -> _Array:
6154
vector_norm = get_xp(da)(_linalg.vector_norm)
6255
diagonal = get_xp(da)(_linalg.diagonal)
6356

64-
__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot",
65-
"matrix_transpose", "vecdot", "EighResult",
66-
"QRResult", "SlogdetResult", "SVDResult", "qr",
67-
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
68-
"vector_norm", "diagonal"]
57+
__all__ += ["trace", "outer", "matmul", "tensordot",
58+
"matrix_transpose", "vecdot", "EighResult",
59+
"QRResult", "SlogdetResult", "SVDResult", "qr",
60+
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
61+
"vector_norm", "diagonal"]
6962

7063
def __dir__() -> list[str]:
7164
return __all__

array_api_compat/numpy/__init__.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
# ruff: noqa: PLC0414
22
from typing import Final
33

4-
import numpy as np
5-
from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary]
4+
from .._internal import clone_module
65

7-
# from numpy import * doesn't overwrite these builtin names
8-
from numpy import abs as abs
9-
from numpy import max as max
10-
from numpy import min as min
11-
from numpy import round as round
6+
__all__ = clone_module("numpy", globals())
127

138
# These imports may overwrite names from the import * above.
149
from . import _aliases
@@ -31,7 +26,7 @@
3126
__array_api_version__: Final = "2024.12"
3227

3328
__all__ = sorted(
34-
set(np.__all__)
29+
set(__all__)
3530
| set(_aliases.__all__)
3631
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
3732
)

array_api_compat/numpy/fft.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import numpy as np
2-
from numpy.fft import * # noqa: F403
32

4-
__all__ = [n for n in dir(np.fft) if not n.startswith("_")]
5-
globals().update({n: getattr(np.fft, n) for n in __all__})
3+
from .._internal import clone_module
4+
5+
__all__ = clone_module("numpy.fft", globals())
66

77
from .._internal import get_xp
88
from ..common import _fft

array_api_compat/numpy/linalg.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,17 @@
77

88
import numpy as np
99

10-
from numpy.linalg import * # noqa: F403
11-
12-
from .._internal import get_xp
10+
from .._internal import clone_module, get_xp
1311
from ..common import _linalg
1412

13+
from .._internal import clone_module
14+
15+
__all__ = clone_module("numpy.linalg", globals())
16+
1517
# These functions are in both the main and linalg namespaces
1618
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
1719
from ._typing import Array
1820

19-
__all__ = [n for n in dir(np.linalg) if not n.startswith("_")]
20-
globals().update({n: getattr(np.linalg, n) for n in __all__})
21-
2221
cross = get_xp(np)(_linalg.cross)
2322
outer = get_xp(np)(_linalg.outer)
2423
EighResult = _linalg.EighResult

array_api_compat/torch/__init__.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,8 @@
11
from typing import Final
22

3-
from torch import * # noqa: F403
3+
from .._internal import clone_module
44

5-
# Several names are not included in the above import *
6-
_torch_dir = set()
7-
import torch
8-
for n in dir(torch):
9-
if (n.startswith('_')
10-
or n.endswith('_')
11-
or 'backward' in n):
12-
continue
13-
exec(f"{n} = torch.{n}")
14-
_torch_dir.add(n)
15-
del n
16-
17-
# torch.__all__ is wildly incorrect
18-
_n: dict[str, object] = {}
19-
exec('from torch import *', _n)
20-
_torch_all = set(_n)
21-
del _n
5+
__all__ = clone_module("torch", globals())
226

237
# These imports may overwrite names from the import * above.
248
from . import _aliases
@@ -32,11 +16,10 @@
3216
__array_api_version__: Final = '2024.12'
3317

3418
__all__ = sorted(
35-
set(_torch_all)
19+
set(__all__)
3620
| set(_aliases.__all__)
3721
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
38-
| {"from_dlpack"}
3922
)
4023

4124
def __dir__() -> list[str]:
42-
return sorted(set(__all__) | set(_torch_dir))
25+
return __all__

array_api_compat/torch/fft.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44

55
import torch
66
import torch.fft
7-
from torch.fft import * # noqa: F403
87

98
from ._typing import Array
9+
from .._internal import clone_module
1010

11-
# The above is missing a wealth of stuff
12-
__all__ = [n for n in dir(torch.fft) if not n.startswith("_")]
13-
globals().update({n: getattr(torch.fft, n) for n in __all__})
11+
__all__ = clone_module("torch.fft", globals())
1412

1513
# Several torch fft functions do not map axes to dim
1614

@@ -77,10 +75,7 @@ def ifftshift(
7775
return torch.fft.ifftshift(x, dim=axes, **kwargs)
7876

7977

80-
__all__ = sorted(
81-
set(__all__)
82-
| {"fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"}
83-
)
78+
__all__ += ["fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"]
8479

8580
def __dir__() -> list[str]:
8681
return __all__

array_api_compat/torch/linalg.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,9 @@
33
import torch
44
from typing import Optional, Union, Tuple
55

6-
from torch.linalg import * # noqa: F403
6+
from .._internal import clone_module
77

8-
# torch.linalg doesn't define __all__
9-
# from torch.linalg import __all__ as linalg_all
10-
from torch import linalg as torch_linalg
11-
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
8+
__all__ = clone_module("torch.linalg", globals())
129

1310
# outer is implemented in torch but aren't in the linalg namespace
1411
from torch import outer
@@ -110,8 +107,8 @@ def vector_norm(
110107
return out
111108
return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
112109

113-
__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
114-
'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
110+
__all__ += ['outer', 'matmul', 'matrix_transpose', 'tensordot',
111+
'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
115112

116113
def __dir__() -> list[str]:
117114
return __all__

0 commit comments

Comments
 (0)