Skip to content

Commit e722706

Browse files
committed
Add the api_version keyword to get_namespace
1 parent a1bb958 commit e722706

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

array_api_compat/common/_helpers.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ def is_array_api_obj(x):
4949
or _is_torch_array(x) \
5050
or hasattr(x, '__array_namespace__')
5151

52-
def get_namespace(*xs, _use_compat=True):
52+
def _check_api_version(api_version):
53+
if api_version is not None and api_version != '2021.12':
54+
raise ValueError("Only the 2021.12 version of the array API specification is currently supported")
55+
56+
def get_namespace(*xs, api_version=None, _use_compat=True):
5357
"""
5458
Get the array API compatible namespace for the arrays `xs`.
5559
@@ -61,28 +65,34 @@ def your_function(x, y):
6165
xp = array_api_compat.get_namespace(x, y)
6266
# Now use xp as the array library namespace
6367
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
68+
69+
api_version should be the newest version of the spec that you need support
70+
for (currently the compat library wrapped APIs only support v2021.12).
6471
"""
6572
namespaces = set()
6673
for x in xs:
6774
if isinstance(x, (tuple, list)):
6875
namespaces.add(get_namespace(*x, _use_compat=_use_compat))
6976
elif hasattr(x, '__array_namespace__'):
70-
namespaces.add(x.__array_namespace__())
77+
namespaces.add(x.__array_namespace__(api_version=api_version))
7178
elif _is_numpy_array(x):
79+
_check_api_version(api_version)
7280
if _use_compat:
7381
from .. import numpy as numpy_namespace
7482
namespaces.add(numpy_namespace)
7583
else:
7684
import numpy as np
7785
namespaces.add(np)
7886
elif _is_cupy_array(x):
87+
_check_api_version(api_version)
7988
if _use_compat:
8089
from .. import cupy as cupy_namespace
8190
namespaces.add(cupy_namespace)
8291
else:
8392
import cupy as cp
8493
namespaces.add(cp)
8594
elif _is_torch_array(x):
95+
_check_api_version(api_version)
8696
if _use_compat:
8797
from .. import torch as torch_namespace
8898
namespaces.add(torch_namespace)

tests/test_get_namespace.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,17 @@
44

55

66
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
7-
def test_get_namespace(library):
7+
@pytest.mark.parametrize("api_version", [None, '2021.12'])
8+
def test_get_namespace(library, api_version):
89
lib = pytest.importorskip(library)
910

1011
array = lib.asarray([1.0, 2.0, 3.0])
11-
namespace = array_api_compat.get_namespace(array)
12+
namespace = array_api_compat.get_namespace(array, api_version=api_version)
1213

13-
expected_namespace = getattr(array_api_compat, library)
14-
assert namespace is expected_namespace
15-
16-
17-
@pytest.mark.parametrize("array_namespace", ["cupy.array_api", "numpy.array_api"])
18-
def test_get_namespace_returns_actual_namespace(array_namespace):
19-
xp = pytest.importorskip(array_namespace)
20-
X = xp.asarray([1, 2, 3])
21-
xp_ = get_namespace(X)
22-
assert xp_ is xp
14+
if 'array_api' in library:
15+
assert namespace == lib
16+
else:
17+
assert namespace == getattr(array_api_compat, library)
2318

2419
def test_get_namespace_multiple():
2520
import numpy as np
@@ -38,3 +33,5 @@ def test_get_namespace_errors():
3833
y = torch.asarray([1, 2])
3934

4035
pytest.raises(TypeError, lambda: get_namespace(x, y))
36+
37+
pytest.raises(ValueError, lambda: get_namespace(x, api_version='2022.12'))

0 commit comments

Comments
 (0)