Skip to content

Commit 874c2ff

Browse files
authored
Merge pull request #42 from tupui/only_arrays
ENH: disallow passing array-like to array_namespace
2 parents 51d7832 + 9507a42 commit 874c2ff

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

array_api_compat/common/_helpers.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ def your_function(x, y):
7171
"""
7272
namespaces = set()
7373
for x in xs:
74-
if isinstance(x, (tuple, list)):
75-
namespaces.add(array_namespace(*x, _use_compat=_use_compat))
76-
elif hasattr(x, '__array_namespace__'):
74+
if hasattr(x, '__array_namespace__'):
7775
namespaces.add(x.__array_namespace__(api_version=api_version))
7876
elif _is_numpy_array(x):
7977
_check_api_version(api_version)

tests/test_array_namespace.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,25 @@ def test_array_namespace(library, api_version):
1919
else:
2020
assert namespace == getattr(array_api_compat, library)
2121

22-
def test_array_namespace_multiple():
23-
import numpy as np
24-
25-
x = np.asarray([1, 2])
26-
assert array_namespace(x, x) == array_namespace((x, x)) == \
27-
array_namespace((x, x), x) == array_api_compat.numpy
2822

2923
def test_array_namespace_errors():
3024
pytest.raises(TypeError, lambda: array_namespace([1]))
3125
pytest.raises(TypeError, lambda: array_namespace())
3226

3327
import numpy as np
34-
import torch
3528
x = np.asarray([1, 2])
29+
30+
pytest.raises(TypeError, lambda: array_namespace((x, x)))
31+
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
32+
33+
import torch
3634
y = torch.asarray([1, 2])
3735

3836
pytest.raises(TypeError, lambda: array_namespace(x, y))
3937

4038
pytest.raises(ValueError, lambda: array_namespace(x, api_version='2022.12'))
4139

40+
4241
def test_get_namespace():
4342
# Backwards compatible wrapper
4443
assert array_api_compat.get_namespace is array_api_compat.array_namespace

0 commit comments

Comments
 (0)