9
9
10
10
import sys
11
11
import math
12
+ import inspect
12
13
13
14
def is_numpy_array (x ):
14
15
# Avoid importing NumPy if it isn't already
@@ -49,6 +50,15 @@ def is_dask_array(x):
49
50
50
51
return isinstance (x , dask .array .Array )
51
52
53
+ def is_jax_array (x ):
54
+ # Avoid importing jax if it isn't already
55
+ if 'jax' not in sys .modules :
56
+ return False
57
+
58
+ import jax .numpy
59
+
60
+ return isinstance (x , jax .numpy .ndarray )
61
+
52
62
def is_array_api_obj (x ):
53
63
"""
54
64
Check if x is an array API compatible array object.
@@ -57,6 +67,7 @@ def is_array_api_obj(x):
57
67
or is_cupy_array (x ) \
58
68
or is_torch_array (x ) \
59
69
or is_dask_array (x ) \
70
+ or is_jax_array (x ) \
60
71
or hasattr (x , '__array_namespace__' )
61
72
62
73
def _check_api_version (api_version ):
@@ -112,6 +123,13 @@ def your_function(x, y):
112
123
namespaces .add (dask_namespace )
113
124
else :
114
125
raise TypeError ("_use_compat cannot be False if input array is a dask array!" )
126
+ elif is_jax_array (x ):
127
+ _check_api_version (api_version )
128
+ # jax.numpy is already an array namespace, but requires this
129
+ # side-effecting import for __array_namespace__ and some other
130
+ # things to be defined.
131
+ import jax .experimental .array_api as jnp
132
+ namespaces .add (jnp )
115
133
elif hasattr (x , '__array_namespace__' ):
116
134
namespaces .add (x .__array_namespace__ (api_version = api_version ))
117
135
else :
@@ -158,6 +176,15 @@ def device(x: "Array", /) -> "Device":
158
176
"""
159
177
if is_numpy_array (x ):
160
178
return "cpu"
179
+ if is_jax_array (x ):
180
+ # JAX has .device() as a method, but it is being deprecated so that it
181
+ # can become a property, in accordance with the standard. In order for
182
+ # this function to not break when JAX makes the flip, we check for
183
+ # both here.
184
+ if inspect .ismethod (x .device ):
185
+ return x .device ()
186
+ else :
187
+ return x .device
161
188
return x .device
162
189
163
190
# Based on cupy.array_api.Array.to_device
@@ -204,6 +231,12 @@ def _torch_to_device(x, device, /, stream=None):
204
231
raise NotImplementedError
205
232
return x .to (device )
206
233
234
+ def _jax_to_device (x , device , / , stream = None ):
235
+ import jax
236
+ if stream is not None :
237
+ raise NotImplementedError
238
+ return jax .device_put (x , device )
239
+
207
240
def to_device (x : "Array" , device : "Device" , / , * , stream : "Optional[Union[int, Any]]" = None ) -> "Array" :
208
241
"""
209
242
Copy the array from the device on which it currently resides to the specified ``device``.
@@ -243,6 +276,8 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
243
276
if device == 'cpu' :
244
277
return x
245
278
raise ValueError (f"Unsupported device { device !r} " )
279
+ elif is_jax_array (x ):
280
+ return _jax_to_device (x , device , stream = stream )
246
281
return x .to_device (device , stream = stream )
247
282
248
283
def size (x ):
@@ -255,4 +290,4 @@ def size(x):
255
290
256
291
__all__ = ['is_array_api_obj' , 'array_namespace' , 'get_namespace' , 'device' ,
257
292
'to_device' , 'size' , 'is_numpy_array' , 'is_cupy_array' ,
258
- 'is_torch_array' , 'is_dask_array' ]
293
+ 'is_torch_array' , 'is_dask_array' , 'is_jax_array' ]
0 commit comments