@@ -49,7 +49,11 @@ def is_array_api_obj(x):
49
49
or _is_torch_array (x ) \
50
50
or hasattr (x , '__array_namespace__' )
51
51
52
- def get_namespace (* xs , _use_compat = True ):
52
+ def _check_api_version (api_version ):
53
+ if api_version is not None and api_version != '2021.12' :
54
+ raise ValueError ("Only the 2021.12 version of the array API specification is currently supported" )
55
+
56
+ def get_namespace (* xs , api_version = None , _use_compat = True ):
53
57
"""
54
58
Get the array API compatible namespace for the arrays `xs`.
55
59
@@ -61,28 +65,34 @@ def your_function(x, y):
61
65
xp = array_api_compat.get_namespace(x, y)
62
66
# Now use xp as the array library namespace
63
67
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
68
+
69
+ api_version should be the newest version of the spec that you need support
70
+ for (currently the compat library wrapped APIs only support v2021.12).
64
71
"""
65
72
namespaces = set ()
66
73
for x in xs :
67
74
if isinstance (x , (tuple , list )):
68
75
namespaces .add (get_namespace (* x , _use_compat = _use_compat ))
69
76
elif hasattr (x , '__array_namespace__' ):
70
- namespaces .add (x .__array_namespace__ ())
77
+ namespaces .add (x .__array_namespace__ (api_version = api_version ))
71
78
elif _is_numpy_array (x ):
79
+ _check_api_version (api_version )
72
80
if _use_compat :
73
81
from .. import numpy as numpy_namespace
74
82
namespaces .add (numpy_namespace )
75
83
else :
76
84
import numpy as np
77
85
namespaces .add (np )
78
86
elif _is_cupy_array (x ):
87
+ _check_api_version (api_version )
79
88
if _use_compat :
80
89
from .. import cupy as cupy_namespace
81
90
namespaces .add (cupy_namespace )
82
91
else :
83
92
import cupy as cp
84
93
namespaces .add (cp )
85
94
elif _is_torch_array (x ):
95
+ _check_api_version (api_version )
86
96
if _use_compat :
87
97
from .. import torch as torch_namespace
88
98
namespaces .add (torch_namespace )
0 commit comments