RFC: expand_dims for tuple of axes #105

Closed
izaid opened this issue Mar 10, 2024 · 2 comments

@izaid
izaid commented Mar 10, 2024

Hello all! I've just spent some time converting a reasonably large code base to use array-api-compat. I was pleasantly surprised that nearly everything worked, but one function didn't map smoothly: expand_dims. The issue isn't necessarily array-api-compat's problem, but I thought I'd start here.

In the array API standard, expand_dims supports only a single axis (https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html) rather than a tuple of axes. This differs from NumPy, CuPy, and JAX, which accept a tuple of axes; PyTorch, however, accepts only a single axis. I don't know why the standard chose a single axis over a tuple, but as a result, code that passes a tuple of axes to expand_dims no longer works in many places.
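Within the current standard, a tuple of axes can be emulated by calling the single-axis expand_dims repeatedly. A minimal sketch, assuming `xp` is an array API namespace (for example one returned by `array_api_compat.array_namespace(x)`); `single_axis_positions` and `expand_dims_tuple` are hypothetical helper names, not part of the standard or array-api-compat. Negative axes are normalized against the rank of the output, following the convention of NumPy's expand_dims.

```python
def single_axis_positions(ndim, axis):
    """Normalize `axis` (int or tuple of ints) to sorted, non-negative output positions.

    Hypothetical helper. Negative values count from the end of the *output*
    array, whose rank is ndim + len(axis).
    """
    if isinstance(axis, int):
        axis = (axis,)
    out_ndim = ndim + len(axis)
    positions = []
    for a in axis:
        if not -out_ndim <= a < out_ndim:
            raise ValueError(f"axis {a} is out of bounds for an output of rank {out_ndim}")
        positions.append(a % out_ndim)
    if len(set(positions)) != len(positions):
        raise ValueError("repeated axis in `axis`")
    return sorted(positions)


def expand_dims_tuple(x, axis, xp):
    """Hypothetical helper: emulate a tuple `axis` with single-axis xp.expand_dims calls.

    Inserting in ascending order means earlier insertions never shift the
    positions of later ones, so each normalized position is passed unchanged.
    """
    for p in single_axis_positions(x.ndim, axis):
        x = xp.expand_dims(x, axis=p)  # the standard's expand_dims takes a single int
    return x
```

The design trade-off is one call per inserted axis instead of a single reshape, in exchange for relying only on the single-axis expand_dims that the standard already specifies.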

In practice, expand_dims is just a thin wrapper around reshape; see https://github.com/numpy/numpy/blob/3b246c6488cf246d488bbe5726ca58dc26b6ea74/numpy/lib/_shape_base_impl.py#L594. But I'd really rather not write my own version of expand_dims in every library. Would array_api_compat be willing to provide a non-strict version of expand_dims that still supports a tuple of axes? Or has there been a clear discussion and decision that expand_dims will support only a single axis going forward, effectively requiring every user of expand_dims to copy and paste the NumPy implementation?
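The reshape-based approach can also be written in a library-agnostic way: compute the output shape with a size-1 dimension at each requested position, then call `xp.reshape`, which the standard defines. This is a sketch that mirrors the approach of the linked NumPy code; `expanded_shape` and `expand_dims_multi` are hypothetical names, and `xp` is assumed to be an array API namespace such as one returned by `array_api_compat.array_namespace(x)`.

```python
def expanded_shape(shape, axis):
    """Return `shape` with a size-1 dimension inserted at each position in `axis`.

    Hypothetical helper. `axis` may be an int or a tuple of ints; negative
    values count from the end of the *output* shape, as in np.expand_dims.
    """
    if isinstance(axis, int):
        axis = (axis,)
    out_ndim = len(shape) + len(axis)
    positions = set()
    for a in axis:
        if not -out_ndim <= a < out_ndim:
            raise ValueError(f"axis {a} is out of bounds for an output of rank {out_ndim}")
        positions.add(a % out_ndim)
    if len(positions) != len(axis):
        raise ValueError("repeated axis in `axis`")
    # New positions get size 1; every other position consumes the next input dimension.
    dims = iter(shape)
    return tuple(1 if i in positions else next(dims) for i in range(out_ndim))


def expand_dims_multi(x, axis, xp):
    """Hypothetical helper: expand_dims with tuple support, built only on xp.reshape."""
    return xp.reshape(x, expanded_shape(x.shape, axis))
```

For example, with NumPy as the namespace, `expand_dims_multi(np.zeros((2, 3)), (0, -1), np).shape` should be `(1, 2, 3, 1)`.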

Many thanks!

@lucascolley
Member

It looks like only NumPy supported providing a tuple when this was added to the spec ~4 years ago: data-apis/array-api#42.

Now that CuPy and JAX both support it, maybe you should bring this up on the array API repo.

@asmeurer
Member

This was requested in the standard at data-apis/array-api#760.

Based on https://data-apis.org/array-api-compat/#scope, I'm going to close this. If this proposal gets included in a draft version of the standard, we should implement it here, but until then, it's better to request upstream support for this feature in the standard and/or PyTorch.

@asmeurer closed this as not planned on Mar 15, 2024

3 participants