diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index e6adc948..ee742b16 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -71,9 +71,7 @@ def your_function(x, y): """ namespaces = set() for x in xs: - if isinstance(x, (tuple, list)): - namespaces.add(array_namespace(*x, _use_compat=_use_compat)) - elif hasattr(x, '__array_namespace__'): + if hasattr(x, '__array_namespace__'): namespaces.add(x.__array_namespace__(api_version=api_version)) elif _is_numpy_array(x): _check_api_version(api_version) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 806b1192..1675377d 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -19,26 +19,25 @@ def test_array_namespace(library, api_version): else: assert namespace == getattr(array_api_compat, library) -def test_array_namespace_multiple(): - import numpy as np - - x = np.asarray([1, 2]) - assert array_namespace(x, x) == array_namespace((x, x)) == \ - array_namespace((x, x), x) == array_api_compat.numpy def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace()) import numpy as np - import torch x = np.asarray([1, 2]) + + pytest.raises(TypeError, lambda: array_namespace((x, x))) + pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) + + import torch y = torch.asarray([1, 2]) pytest.raises(TypeError, lambda: array_namespace(x, y)) pytest.raises(ValueError, lambda: array_namespace(x, api_version='2022.12')) + def test_get_namespace(): # Backwards compatible wrapper assert array_api_compat.get_namespace is array_api_compat.array_namespace