Add functions such as take to existing Array API spec if not implemented yet? #23

Closed
thomasjpfan opened this issue Mar 3, 2023 · 11 comments

Comments

@thomasjpfan
Contributor

Should array-api-compat "update" the namespace for existing Array API arrays to the most recent spec? For example:

# Assume that `xp.take` is not implemented in the installed CuPy version
import cupy.array_api as xp
import array_api_compat

X = xp.asarray([1.0, 2.0])
xp = array_api_compat.get_namespace(X)

# Should this always be true?
assert hasattr(xp, "take")
@rgommers
Member

rgommers commented Mar 5, 2023

I'd say yes, that should be a goal. Taking over the implementation from cupy/cupy#7432 would be good I'd think.

@asmeurer
Member

asmeurer commented Mar 6, 2023

I'd say changes like this are definitely in scope. We will add it eventually once we add the 2023 version of the spec, but we can add it sooner if it would help.

And as Ralf pointed out, we also still need to update numpy.array_api and cupy.array_api with 2023 functionality.

@rgommers
Member

rgommers commented Mar 6, 2023

Let's add it now indeed. scikit-learn needed take, which is why we prioritized adding it to the standard. By the way, it isn't a draft; it was already in the 2022 standard: https://data-apis.org/array-api/latest/changelog.html#v2022-12

@asmeurer
Member

asmeurer commented Mar 6, 2023

Yes, sorry I meant to type 2022 above, not 2023.

@asmeurer
Member

asmeurer commented Mar 6, 2023

The main challenge here is there aren't any tests for take yet in the test suite. @honno if you could work on adding take to data-apis/array-api-tests#165 (or a new PR) that would help. I noticed that torch.take will need wrapping because it doesn't have an axis keyword (if you don't use the axis keyword the existing torch wrapper should already have a take function that will work).
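A minimal sketch of what such a torch wrapper could look like, assuming torch.index_select is an acceptable way to supply the axis behaviour. This is an illustration only, not necessarily how this library implements it, and the function below is hypothetical:

```python
def take(x, indices, /, *, axis=None):
    """Array-API-style take for torch tensors (illustrative sketch)."""
    # Imported lazily so that defining the function does not require torch.
    import torch

    if axis is None:
        # The standard only allows omitting axis for one-dimensional input.
        if x.ndim != 1:
            raise ValueError("axis must be specified when x is not 1-D")
        axis = 0
    # index_select gathers whole slices along one dimension, which is the
    # behaviour take(x, indices, axis=axis) asks for. It expects `indices`
    # to be a 1-D integer tensor.
    return torch.index_select(x, axis, indices)
```

For a 2-D tensor x, take(x, torch.tensor([0, 2]), axis=1) would select columns 0 and 2.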

NumPy and CuPy already have take. Adding it to numpy.array_api and cupy.array_api is a different story. We would need to upstream it there.

Presently this compat library doesn't wrap or otherwise modify existing array-API-compatible libraries like numpy.array_api. It just returns them as-is. So there are a few options here:

  • We will eventually upstream 2022 spec support to numpy.array_api (and cupy.array_api will presumably follow suit). However, this isn't being worked on yet, and even once it happens, it will require a numpy release to be usable.
  • If you need it sooner, we can do one of two things in this compat library:
    • Make it so that from array_api_compat import take works. This would be a function that works on numpy, cupy, torch, numpy.array_api, and cupy.array_api (and any other library if it already has take)
    • Make it so that get_namespace returns a wrapped module for numpy.array_api. That way we can make it so that xp.take always works. Presently, xp = get_namespace(numpy.array_api.asarray(...)) returns the original numpy.array_api module unwrapped, meaning it doesn't include any functionality from this compat layer at all.

The difference for these two is really between using array_api_compat.take vs. xp.take.
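To make the second option concrete, here is a minimal sketch of a wrapped namespace that delegates to the underlying module and fills in names it lacks. Everything in it is hypothetical: WrappedNamespace and _fallback_take are invented for illustration, and a SimpleNamespace operating on plain lists stands in for an older array API module that has no take.

```python
from types import SimpleNamespace


def _fallback_take(x, indices, /, *, axis=None):
    """Pure-Python stand-in for take on 1-D sequences (illustration only)."""
    if axis not in (None, 0, -1):
        raise ValueError("this sketch only handles 1-D input")
    return [x[i] for i in indices]


class WrappedNamespace:
    """Hypothetical wrapper: delegates to `base`, adding missing functions."""

    def __init__(self, base, extras):
        self._base = base
        self._extras = extras

    def __getattr__(self, name):
        # Prefer the wrapped library's own implementation when it exists.
        try:
            return getattr(self._base, name)
        except AttributeError:
            pass
        # Otherwise fall back to the compat implementation, if there is one.
        try:
            return self._extras[name]
        except KeyError:
            raise AttributeError(name) from None


# Stand-in for an older array API namespace that does not provide take.
old_xp = SimpleNamespace(asarray=list)
xp = WrappedNamespace(old_xp, {"take": _fallback_take})

X = xp.asarray([1.0, 2.0, 3.0])
selected = xp.take(X, [2, 0])
```

The first option would amount to exporting a function like _fallback_take directly, so callers write array_api_compat.take(X, ...) instead of xp.take(X, ...).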

@rgommers
Member

rgommers commented Mar 7, 2023

> We will eventually upstream 2022 spec support to numpy.array_api (and cupy.array_api will presumably follow suit). However, this isn't being worked on yet, and even once it happens, it will require a numpy release to be usable.

This is already being done for take:

> Should array-api-compat "update" the namespace for existing Array API arrays to the most recent spec?

I'm not clear on the need for this request for xp.take though. The point of array-api-compat is to extend the main namespaces and avoid the need to use the separate submodules (numpy.array_api and cupy.array_api). @thomasjpfan can you comment on this?

@thomasjpfan
Contributor Author

thomasjpfan commented Mar 7, 2023

> I'm not clear on the need for this request for xp.take though.

If the goal is to remove numpy.array_api and cupy.array_api, then take does not need to be added to existing Array API namespaces. Libraries that consume the Array API and need take will then require the underlying array library to be updated to the v2022 spec.

It does raise the question about how versioning works here. Consider:

# custom_array_library implements v2021 spec, but not v2022.
import custom_array_library.array_api as xp
import array_api_compat

X = xp.asarray([1.0, 2.0])
xp = array_api_compat.get_namespace(X)

# Which spec is `xp` supporting? v2021 or v2022?
xp

The easiest solution is to error on array_api_compat.get_namespace because the latest spec (v2022) is not supported by the array library. The less appealing answer is to update the namespace to the v2022 spec (which is this feature request).

@thomasjpfan
Contributor Author

From the weekly call, we think it's best to error in array_api_compat.get_namespace when the input array does not support the latest spec.

@asmeurer
Member

asmeurer commented Mar 8, 2023

So for now (i.e., in #25) I am not going to do anything about take, since it is already present in numpy, cupy, and torch. The torch take does not have the axis keyword argument, but @thomasjpfan confirmed today that this is not an issue for him. Going forward we will:

  • Add support for all of the 2022 spec to this library (not just take).
  • Add support for all of the 2022 spec to numpy.array_api.
  • Support passing a version to get_namespace. As mentioned, the default will be the latest version generally supported, which for now is still 2021 but once we implement 2022 support it will be 2022.

Regarding the version, I think its primary benefit is a better error message when using an array API namespace that is too old. Virtually every change in the spec is additive, meaning there's no reason for something like get_namespace(np.array(...), version='2021.12') to return anything different from get_namespace(np.array(...), version='2022.12'). A 2022-compliant namespace will also be 2021-compliant.
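The additive argument can be sketched as follows: validate the requested version to produce a clear error, but return the same namespace for every supported version. This is an illustration under stated assumptions, not this library's code; SUPPORTED_VERSIONS and resolve_namespace are hypothetical names.

```python
# Assumption for illustration: the versions a compat layer knows about,
# ordered oldest to newest.
SUPPORTED_VERSIONS = ("2021.12", "2022.12")


def resolve_namespace(namespace, api_version=None):
    """Validate `api_version`, then hand back `namespace` unchanged."""
    if api_version is None:
        # Default to the newest supported version.
        api_version = SUPPORTED_VERSIONS[-1]
    if api_version not in SUPPORTED_VERSIONS:
        raise ValueError(
            f"unsupported api_version {api_version!r}; "
            f"expected one of {SUPPORTED_VERSIONS}"
        )
    # Spec changes are additive, so a namespace compliant with the newest
    # version is also compliant with every older supported one.
    return namespace
```

The version argument here only changes whether an error is raised, never which namespace comes back.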

The bigger question here is about numpy.array_api. Due to its strictness, we might want to make it so that you can get a 2021 strictly compliant version of it that doesn't include any 2022 stuff. As I said, no work has really happened yet on adding 2022 support, except for a few offshoot PRs like the one mentioned above adding take. I think it would be nice to have this behavior there, but I'll have to think about how easy it is to implement. But this really should be discussed on the NumPy repo, not here.

For this compat library, we are taking a much more pragmatic approach.

@asmeurer
Member

asmeurer commented Mar 8, 2023

Implemented api_version (currently only supporting '2021.12') to get_namespace() in #25.

@asmeurer
Member

Based on the discussion, I don't think there's anything to do here: take is already implemented in numpy, cupy, and torch, and has been implemented in the git (development) versions of the numpy and cupy array_api submodules (with full 2022.12 support coming later). I've added the version flag to get_namespace, and it only accepts 2021.12 for now.
