Skip to content

Commit b3a12d9

Browse files
committed
Move all the NumPy functionality into a numpy submodule
There will also be a 'cupy' submodule. Common code will be factored out into the 'common' submodule.
1 parent 1751356 commit b3a12d9

File tree

9 files changed

+70
-63
lines changed

9 files changed

+70
-63
lines changed

numpy_array_api_compat/__init__.py

-18
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,3 @@
4040
- NumPy functions which are not wrapped may not use positional-only arguments.
4141
4242
"""
43-
44-
from numpy import *
45-
46-
# These imports may overwrite names from the import * above.
47-
from ._aliases import *
48-
49-
# Don't know why, but we have to do an absolute import to import linalg. If we
50-
# instead do
51-
#
52-
# from . import linalg
53-
#
54-
# It doesn't overwrite np.linalg from above. The import is generated
55-
# dynamically so that the library can be vendored.
56-
__import__(__package__ + '.linalg')
57-
58-
from .linalg import matrix_transpose, vecdot
59-
60-
from ._helpers import *

numpy_array_api_compat/_internal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from functools import wraps
66
from inspect import signature
77

8-
from ._helpers import get_namespace
8+
from .common._helpers import get_namespace
99

1010
def get_xp(f):
1111
"""

numpy_array_api_compat/common/__init__.py

Whitespace-only changes.
+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""
2+
Various helper functions which are not part of the spec.
3+
"""
4+
def get_namespace(*xs, _use_compat=True):
5+
"""
6+
Get the array API compatible namespace for the arrays `xs`.
7+
8+
`xs` should contain one or more arrays.
9+
"""
10+
from ..numpy._helpers import _is_numpy_array
11+
12+
namespaces = set()
13+
for x in xs:
14+
if isinstance(x, (tuple, list)):
15+
namespaces.add(get_namespace(*x, _use_compat=_use_compat))
16+
elif hasattr(x, '__array_namespace__'):
17+
namespaces.add(x.__array_namespace__)
18+
elif _is_numpy_array(x):
19+
if _use_compat:
20+
from .. import numpy as numpy_namespace
21+
namespaces.add(numpy_namespace)
22+
else:
23+
import numpy as np
24+
namespaces.add(np)
25+
else:
26+
# TODO: Support Python scalars?
27+
raise ValueError("The input is not a supported array type")
28+
29+
if not namespaces:
30+
raise ValueError("Unrecognized array input")
31+
32+
if len(namespaces) != 1:
33+
raise ValueError(f"Multiple namespaces for array inputs: {namespaces}")
34+
35+
xp, = namespaces
36+
37+
return xp
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from numpy import *
2+
3+
# from numpy import * doesn't overwrite these builtin names
4+
from numpy import abs, max, min, round
5+
6+
# These imports may overwrite names from the import * above.
7+
from ._aliases import *
8+
9+
# Don't know why, but we have to do an absolute import to import linalg. If we
10+
# instead do
11+
#
12+
# from . import linalg
13+
#
14+
# It doesn't overwrite np.linalg from above. The import is generated
15+
# dynamically so that the library can be vendored.
16+
__import__(__package__ + '.linalg')
17+
18+
from .linalg import matrix_transpose, vecdot
19+
20+
from ._helpers import *

numpy_array_api_compat/_aliases.py renamed to numpy_array_api_compat/numpy/_aliases.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
from typing import Optional, Tuple, Union, List
1010
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
1111

12+
from functools import partial
1213
from typing import NamedTuple
1314
from types import ModuleType
1415

1516
from ._helpers import _is_numpy_array, get_namespace
16-
from ._internal import get_xp
17+
from .._internal import get_xp
1718

1819
# Basic renames
1920
@get_xp
@@ -194,7 +195,7 @@ def _check_device(device):
194195
raise ValueError(f"Unsupported device {device!r}")
195196

196197
# asarray also adds the copy keyword
197-
def asarray(
198+
def _asarray(
198199
obj: Union[
199200
ndarray,
200201
bool,
@@ -245,6 +246,8 @@ def asarray(
245246

246247
return xp.asarray(obj, dtype=dtype)
247248

249+
asarray_numpy = partial(_asarray, namespace='numpy')
250+
248251
@get_xp
249252
def arange(
250253
start: Union[int, float],
@@ -462,15 +465,12 @@ def trunc(x: ndarray, /, xp) -> ndarray:
462465
return x
463466
return xp.trunc(x)
464467

465-
# from numpy import * doesn't overwrite these builtin names
466-
from numpy import abs, max, min, round
467-
468468
__all__ = ['acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh',
469469
'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift',
470470
'bool', 'concat', 'pow', 'UniqueAllResult', 'UniqueCountsResult',
471471
'UniqueInverseResult', 'unique_all', 'unique_counts',
472-
'unique_inverse', 'unique_values', 'astype', 'abs', 'max', 'min',
473-
'round', 'std', 'var', 'permute_dims', 'asarray', 'arange',
474-
'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace',
475-
'ones', 'ones_like', 'zeros', 'zeros_like', 'reshape', 'argsort',
476-
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc']
472+
'unique_inverse', 'unique_values', 'astype', 'std', 'var',
473+
'permute_dims', 'asarray_numpy', 'arange', 'empty', 'empty_like',
474+
'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like',
475+
'zeros', 'zeros_like', 'reshape', 'argsort', 'sort', 'sum', 'prod',
476+
'ceil', 'floor', 'trunc']

numpy_array_api_compat/_helpers.py renamed to numpy_array_api_compat/numpy/_helpers.py

+2-34
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44

55
from __future__ import annotations
66

7-
import importlib
8-
compat_namespace = importlib.import_module(__package__)
9-
107
import numpy as np
118

9+
from ..common._helpers import get_namespace
10+
1211
def _is_numpy_array(x):
1312
# TODO: Should we reject ndarray subclasses?
1413
return isinstance(x, (np.ndarray, np.generic))
@@ -19,37 +18,6 @@ def is_array_api_obj(x):
1918
"""
2019
return _is_numpy_array(x) or hasattr(x, '__array_namespace__')
2120

22-
def get_namespace(*xs, _use_compat=True):
23-
"""
24-
Get the array API compatible namespace for the arrays `xs`.
25-
26-
`xs` should contain one or more arrays.
27-
"""
28-
namespaces = set()
29-
for x in xs:
30-
if isinstance(x, (tuple, list)):
31-
namespaces.add(get_namespace(*x, _use_compat=_use_compat))
32-
elif hasattr(x, '__array_namespace__'):
33-
namespaces.add(x.__array_namespace__)
34-
elif _is_numpy_array(x):
35-
if _use_compat:
36-
namespaces.add(compat_namespace)
37-
else:
38-
namespaces.add(np)
39-
else:
40-
# TODO: Support Python scalars?
41-
raise ValueError("The input is not a supported array type")
42-
43-
if not namespaces:
44-
raise ValueError("Unrecognized array input")
45-
46-
if len(namespaces) != 1:
47-
raise ValueError(f"Multiple namespaces for array inputs: {namespaces}")
48-
49-
xp, = namespaces
50-
51-
return xp
52-
5321
# device and to_device are not included in array object of this library
5422
# because this library just reuses ndarray without wrapping or subclassing it.
5523
# These helper functions can be used instead of the wrapper functions for

0 commit comments

Comments
 (0)