Skip to content

Commit 0acc2b0

Browse files
committed
ENH: add conversion to Array API compatible namespace
1 parent 3b90238 commit 0acc2b0

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

array_api_compat/common/_helpers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,17 @@ def your_function(x, y):
6969
api_version should be the newest version of the spec that you need support
7070
for (currently the compat library wrapped APIs only support v2021.12).
7171
"""
72+
# convert fallback_namespace
73+
if fallback_namespace is not None:
74+
try:
75+
x_ = fallback_namespace.asarray(1)
76+
fallback_namespace = array_namespace(
77+
x_, _use_compat=_use_compat
78+
)
79+
except AttributeError as exc:
80+
msg = "'fallback_namespace' must be an Array API compatible namespace"
81+
raise TypeError(msg) from exc
82+
7283
namespaces = set()
7384
for x in xs:
7485
if isinstance(x, (tuple, list)):

tests/test_array_namespace.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def test_array_namespace_multiple():
2828

2929
def test_fallback_namespace():
3030
import numpy as np
31+
import numpy.array_api
3132
import array_api_compat.numpy
3233

3334
xp = array_api_compat.numpy
@@ -37,9 +38,18 @@ def test_fallback_namespace():
3738
xp_ = array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=xp)
3839
assert xp_ == xp
3940

41+
# convert to Array API compatible namespace
42+
xp = array_api_compat.numpy
43+
xp_ = array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=np)
44+
assert xp_ == xp
45+
4046
msg = 'Multiple namespaces'
4147
with pytest.raises(TypeError, match=msg):
42-
array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=np)
48+
array_namespace([1, 2], numpy.array_api.asarray([1, 2]), fallback_namespace=np)
49+
50+
msg = "'fallback_namespace' must be an Array API"
51+
with pytest.raises(TypeError, match=msg):
52+
array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace="hop")
4353

4454

4555
def test_array_namespace_errors():

0 commit comments

Comments
 (0)