Skip to content

Commit c7024c4

Browse files
committed
fixes
1 parent 9345e5f commit c7024c4

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

array_api_compat/torch/__init__.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,25 @@
33
from torch import * # noqa: F403
44

55
# Several names are not included in the above import *
6-
_torch_all = set()
6+
_torch_dir = set()
77
import torch
88
for n in dir(torch):
99
if (n.startswith('_')
1010
or n.endswith('_')
11-
or 'cuda' in n
12-
or 'cpu' in n
1311
or 'backward' in n):
1412
continue
1513
exec(f"{n} = torch.{n}")
16-
_torch_all.add(n)
14+
_torch_dir.add(n)
1715
del n
1816

17+
# torch.__all__ is wildly incorrect
18+
_n: dict[str, object] = {}
19+
exec('from torch import *', _n)
20+
_torch_all = set(_n)
21+
del _n
22+
1923
# These imports may overwrite names from the import * above.
20-
import _aliases
24+
from . import _aliases
2125
from ._aliases import * # noqa: F403
2226
from ._info import __array_namespace_info__ # noqa: F401
2327

@@ -31,7 +35,8 @@
3135
set(_torch_all)
3236
| set(_aliases.__all__)
3337
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
38+
| {"from_dlpack"}
3439
)
3540

3641
def __dir__() -> list[str]:
37-
return __all__
42+
return sorted(set(__all__) | set(_torch_dir))

array_api_compat/torch/fft.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,15 @@ def ifftshift(
7373
return torch.fft.ifftshift(x, dim=axes, **kwargs)
7474

7575

76-
__all__ = torch.fft.__all__ + [
76+
_all = {
7777
"fftn",
7878
"ifftn",
7979
"rfftn",
8080
"irfftn",
8181
"fftshift",
8282
"ifftshift",
83-
]
83+
}
84+
__all__ = sorted(set(torch.fft.__all__) |_all)
85+
86+
def __dir__() -> list[str]:
87+
return sorted(set(dir(torch.fft)) | _all)

tests/test_all.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
# Inspection
1313
"__array_api_version__",
1414
"__array_namespace_info__",
15+
# Submodules
16+
"fft",
17+
"linalg",
1518
# Constants
1619
"e",
1720
"inf",
@@ -240,6 +243,8 @@ def all_names(mod):
240243
This is typically `__all__` but, if not defined, Python
241244
implements automated fallbacks.
242245
"""
246+
# Note: this method also makes the test trip if a name is
247+
# in __all__ but doesn't actually appear in the module.
243248
objs = {}
244249
exec(f"from {mod.__name__} import *", objs)
245250
return list(objs)

0 commit comments

Comments
 (0)