JAX support #83
Comments
@jakevdp has been working on adding array API support to JAX. I think his plan was to have full support in a JAX submodule so that array-api-compat library support would be unnecessary, but maybe he can clarify if that's still the plan and what the latest status of that is. |
Hi - thanks for the tag! We are tracking JAX Array API support here: jax-ml/jax#18353 You can experimentally enable it by importing
import jax.numpy as jnp
# This is a side-effecting import that among other
# things adds __array_namespace__ to JAX arrays.
import jax.experimental.array_api as xp
print(xp is jnp.array(0).__array_namespace__())  # Standard access to array_api
# True
print(xp.arange(10))  # most Array API functions are already implemented at HEAD.
# [0 1 2 3 4 5 6 7 8 9]
There are still some fundamental questions, though: for example, JAX arrays are immutable so they cannot support pieces of the API standard that require mutability, and JAX arrays had a pre-existing |
So with
>>> import jax.experimental.array_api as xp
>>> import array_api_compat
>>> array_api_compat.array_namespace(xp.arange(10)) is xp
True
I think the main thing to do here is to add jax support to the I don't think it's necessary to add an
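For readers unfamiliar with the protocol, here is a rough sketch of the kind of lookup the snippet above relies on: a helper asks each array for its namespace through `__array_namespace__` and checks that they all agree. This is not array-api-compat's actual implementation; `toy_xp`, `ToyArray`, and `get_namespace` are hypothetical names used only for illustration, and the sketch deliberately avoids importing JAX so it runs anywhere.

```python
import types

# Hypothetical stand-in for an Array API namespace module (not a real library).
toy_xp = types.ModuleType("toy_xp")


class ToyArray:
    """Hypothetical array type that implements the standard protocol method."""

    def __init__(self, data):
        self.data = list(data)

    def __array_namespace__(self, api_version=None):
        # The Array API standard uses this method to expose the namespace.
        return toy_xp


def get_namespace(*arrays):
    """Return the one namespace shared by all arrays (illustrative helper)."""
    if not arrays:
        raise ValueError("at least one array is required")
    found = []
    for a in arrays:
        method = getattr(a, "__array_namespace__", None)
        if method is None:
            raise TypeError(
                f"{type(a).__name__} does not implement __array_namespace__"
            )
        ns = method()
        # Compare by identity so each distinct namespace is recorded once.
        if not any(ns is seen for seen in found):
            found.append(ns)
    if len(found) != 1:
        raise TypeError("arrays come from multiple namespaces")
    return found[0]


xp = get_namespace(ToyArray([1, 2]), ToyArray([3]))
print(xp is toy_xp)
```

A consumer library would then call functions on the returned namespace instead of on a hard-coded module, which is why an array type that exposes `__array_namespace__` itself may need little or no special-casing in a helper like this.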
What are the other things it does? I would suggest not using a side-effecting import, but rather just add Otherwise, whether jax.numpy arrays will work with a library like scipy will depend on whether the user (or someone else) has already imported |
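To make the concern about side-effecting imports concrete, here is a toy illustration of how behavior can depend on whether some other code has already run the patching step. Everything here is hypothetical: `ToyArray`, `toy_xp`, and `enable_experimental_support` merely stand in for an array type and an import that patches it, and this is not how JAX itself is implemented.

```python
import types

# Hypothetical namespace object, standing in for an Array API module.
toy_xp = types.ModuleType("toy_xp")


class ToyArray:
    """Hypothetical array type that ships without the protocol method."""


def enable_experimental_support():
    """Stand-in for a side-effecting import: patches ToyArray in place."""
    ToyArray.__array_namespace__ = lambda self, api_version=None: toy_xp


def supports_standard(x):
    """What a consumer library might check before dispatching."""
    return hasattr(x, "__array_namespace__")


x = ToyArray()
before = supports_standard(x)   # nobody has triggered the patch yet
enable_experimental_support()   # e.g. some other package did the import
after = supports_standard(x)    # the same object now passes the check
print(before, after)
# False True
```

The same array object is rejected or accepted depending on global state that the consumer library does not control, which is the ordering dependence raised in the comment above.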
The only other patching currently is adding the
This is the intent right now: JAX is not yet ready to non-experimentally support the array API. But in the future when we've gotten to the point that we can advertise non-experimental support, the import will no longer be necessary. |
array-api-compat also has a
So should we go ahead and fix the array-api-compat |
Interesting - thanks. Realistically it will be 1-2 months more before we can do a JAX release with compatible |
IMO it would make sense to just wait the 1-2 months to avoid extra work here. Support from consumer libraries is still quite sparse and probably won't move forward a huge amount in the next 2 months (and being able to test with JAX a month or so earlier probably won't accelerate things). |
Either way. It really isn't hard to add it here. |
As a consumer library maintainer, I am accessing the Array API through |
OK, I've added basic JAX support to the helper functions at #84. |
Thank you!! |
Okay, well here it is. I don't know if I'd have time to add it myself in the near future, but it would be great to have.