Skip to content

Commit 6bf5dad

Browse files
committed
Add JAX support
Unlike other libraries, there is no wrapping for JAX. Actual JAX array_api support is in JAX itself in the jax.experimental.array_api submodule. This just adds JAX support to the various helper functions. This also means that we do not run array-api-tests on JAX. Closes #83.
1 parent 12b5294 commit 6bf5dad

File tree

4 files changed

+64
-10
lines changed

4 files changed

+64
-10
lines changed

README.md

+16-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
This is a small wrapper around common array libraries that is compatible with
44
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
5-
NumPy, CuPy, PyTorch, and Dask are supported. If you want support for other array
5+
NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
66
libraries, or if you encounter any issues, please [open an
77
issue](https://github.com/data-apis/array-api-compat/issues).
88

@@ -60,6 +60,12 @@ import array_api_compat.torch as torch
6060
import array_api_compat.dask as da
6161
```
6262

63+
> [!NOTE]
64+
> There is no `array_api_compat.jax` submodule. JAX support is contained
65+
> in JAX itself in the `jax.experimental.array_api` module. array-api-compat simply
66+
> wraps that submodule. The main JAX support in this module consists of
67+
> supporting it in the [helper functions](#helper-functions) defined below.
68+
6369
Each will include all the functions from the normal NumPy/CuPy/PyTorch/dask.array
6470
namespace, except that functions that are part of the array API are wrapped so
6571
that they have the correct array API behavior. In each case, the array object
@@ -104,9 +110,9 @@ part of the specification but which are useful for using the array API:
104110
object.
105111

106112
- `is_numpy_array(x)`, `is_cupy_array(x)`, `is_torch_array(x)`,
107-
`is_dask_array(x)`: return `True` if `x` is an array from the corresponding
108-
library. These functions do not import the underlying library if it has not
109-
already been imported, so they are cheap to use.
113+
`is_dask_array(x)`, `is_jax_array(x)`: return `True` if `x` is an array from
114+
the corresponding library. These functions do not import the underlying
115+
library if it has not already been imported, so they are cheap to use.
110116

111117
- `array_namespace(*xs)`: Get the corresponding array API namespace for the
112118
arrays `xs`. For example, if the arrays are NumPy arrays, the returned
@@ -228,6 +234,12 @@ version.
228234

229235
The minimum supported PyTorch version is 1.13.
230236

237+
### JAX
238+
239+
Unlike the other libraries supported here, JAX array API support is contained
240+
entirely in the JAX library. The JAX array API support is tracked at
241+
https://github.com/google/jax/issues/18353.
242+
231243
## Vendoring
232244

233245
This library supports vendoring as an installation method. To vendor the

array_api_compat/common/_helpers.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import sys
1111
import math
12+
import inspect
1213

1314
def is_numpy_array(x):
1415
# Avoid importing NumPy if it isn't already
@@ -49,6 +50,15 @@ def is_dask_array(x):
4950

5051
return isinstance(x, dask.array.Array)
5152

53+
def is_jax_array(x):
54+
# Avoid importing jax if it isn't already
55+
if 'jax' not in sys.modules:
56+
return False
57+
58+
import jax.numpy
59+
60+
return isinstance(x, jax.numpy.ndarray)
61+
5262
def is_array_api_obj(x):
5363
"""
5464
Check if x is an array API compatible array object.
@@ -57,6 +67,7 @@ def is_array_api_obj(x):
5767
or is_cupy_array(x) \
5868
or is_torch_array(x) \
5969
or is_dask_array(x) \
70+
or is_jax_array(x) \
6071
or hasattr(x, '__array_namespace__')
6172

6273
def _check_api_version(api_version):
@@ -112,6 +123,13 @@ def your_function(x, y):
112123
namespaces.add(dask_namespace)
113124
else:
114125
raise TypeError("_use_compat cannot be False if input array is a dask array!")
126+
elif is_jax_array(x):
127+
_check_api_version(api_version)
128+
# jax.numpy is already an array namespace, but requires this
129+
# side-effecting import for __array_namespace__ and some other
130+
# things to be defined.
131+
import jax.experimental.array_api as jnp
132+
namespaces.add(jnp)
115133
elif hasattr(x, '__array_namespace__'):
116134
namespaces.add(x.__array_namespace__(api_version=api_version))
117135
else:
@@ -158,6 +176,15 @@ def device(x: "Array", /) -> "Device":
158176
"""
159177
if is_numpy_array(x):
160178
return "cpu"
179+
if is_jax_array(x):
180+
# JAX has .device() as a method, but it is being deprecated so that it
181+
# can become a property, in accordance with the standard. In order for
182+
# this function to not break when JAX makes the flip, we check for
183+
# both here.
184+
if inspect.ismethod(x.device):
185+
return x.device()
186+
else:
187+
return x.device
161188
return x.device
162189

163190
# Based on cupy.array_api.Array.to_device
@@ -204,6 +231,12 @@ def _torch_to_device(x, device, /, stream=None):
204231
raise NotImplementedError
205232
return x.to(device)
206233

234+
def _jax_to_device(x, device, /, stream=None):
235+
import jax
236+
if stream is not None:
237+
raise NotImplementedError
238+
return jax.device_put(x, device)
239+
207240
def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array":
208241
"""
209242
Copy the array from the device on which it currently resides to the specified ``device``.
@@ -243,6 +276,8 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
243276
if device == 'cpu':
244277
return x
245278
raise ValueError(f"Unsupported device {device!r}")
279+
elif is_jax_array(x):
280+
return _jax_to_device(x, device, stream=stream)
246281
return x.to_device(device, stream=stream)
247282

248283
def size(x):
@@ -255,4 +290,4 @@ def size(x):
255290

256291
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device',
257292
'to_device', 'size', 'is_numpy_array', 'is_cupy_array',
258-
'is_torch_array', 'is_dask_array']
293+
'is_torch_array', 'is_dask_array', 'is_jax_array']

tests/test_helpers.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array,
2-
is_dask_array, is_array_api_obj)
2+
is_dask_array, is_jax_array, is_array_api_obj)
33

44
from ._helpers import import_
55

@@ -10,6 +10,7 @@
1010
'cupy': 'is_cupy_array',
1111
'torch': 'is_torch_array',
1212
'dask.array': 'is_dask_array',
13+
'jax.numpy': 'is_jax_array',
1314
}
1415

1516
@pytest.mark.parametrize('library', is_functions.keys())

tests/test_isdtype.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,12 @@ def isdtype_(dtype_, kind):
6464
assert type(res) is bool
6565
return res
6666

67-
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
67+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
6868
def test_isdtype_spec_dtypes(library):
69-
xp = import_('array_api_compat.' + library)
69+
if library == "jax.numpy":
70+
xp = import_('jax.experimental.array_api')
71+
else:
72+
xp = import_('array_api_compat.' + library)
7073

7174
isdtype = xp.isdtype
7275

@@ -98,10 +101,13 @@ def test_isdtype_spec_dtypes(library):
98101
'bfloat16',
99102
]
100103

101-
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
104+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
102105
@pytest.mark.parametrize("dtype_", additional_dtypes)
103106
def test_isdtype_additional_dtypes(library, dtype_):
104-
xp = import_('array_api_compat.' + library)
107+
if library == "jax.numpy":
108+
xp = import_('jax.experimental.array_api')
109+
else:
110+
xp = import_('array_api_compat.' + library)
105111

106112
isdtype = xp.isdtype
107113

0 commit comments

Comments
 (0)