RFC: add support for a tuple of axes in expand_dims
#760
Seems tuple support was omitted because torch doesn't support it (#42). I found a few feature requests for it for

I agree this ambiguity is a potential concern. If we standardize this, we should somehow require only a subset of behavior that avoids this ambiguity, e.g., by leaving the mixing of negative and nonnegative indices unspecified. Consider, for example:

>>> np.expand_dims(np.empty((2,)), (1, -1)).shape
(2, 1, 1)

The resulting shape has

But also consider:

>>> np.expand_dims(np.empty((2, 3, 4, 5)), (3, -3)).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/aaronmeurer/miniconda3/envs/array-apis/lib/python3.11/site-packages/numpy/lib/shape_base.py", line 597, in expand_dims
    axis = normalize_axis_tuple(axis, out_ndim)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmeurer/miniconda3/envs/array-apis/lib/python3.11/site-packages/numpy/core/numeric.py", line 1385, in normalize_axis_tuple
    raise ValueError('repeated axis')
ValueError: repeated axis

There's no way to insert

Here's a small proof. There's no list, of any length, from which you can remove indices 3 and -3 and end up with a list of length 4:

>>> def remove_indices(n, idxes):
...     """Return range(n) with `idxes` indices removed"""
...     x = list(range(n))
...     vals = [x[i] for i in idxes]
...     for v in vals:
...         try:
...             x.remove(v)
...         except ValueError:  # Already removed
...             pass
...     return x
...
>>> [remove_indices(n, (-3, 3)) for n in range(4, 10)]
[[0, 2], [0, 1, 4], [0, 1, 2, 4, 5], [0, 1, 2, 5, 6], [0, 1, 2, 4, 6, 7], [0, 1, 2, 4, 5, 7, 8]]
>>> [len(remove_indices(n, (-3, 3))) for n in range(4, 10)]
[2, 3, 5, 5, 6, 7]

At the same time, if the goal of
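The collision comes from NumPy resolving negative axes against the *output* rank, x.ndim + len(axis), as the normalize_axis_tuple(axis, out_ndim) call in the traceback shows. A minimal sketch of that normalization step; the helper name normalize_axes is hypothetical, not a NumPy function:

```python
def normalize_axes(axes, out_ndim):
    """Map each axis into range(out_ndim), mirroring NumPy's expand_dims.

    Negative axes count from the end of the *output* shape, so -1 means
    out_ndim - 1. Duplicates after normalization are rejected.
    """
    normalized = []
    for a in axes:
        if not -out_ndim <= a < out_ndim:
            raise ValueError(f"axis {a} is out of bounds for {out_ndim} dimensions")
        normalized.append(a % out_ndim)
    if len(set(normalized)) != len(normalized):
        raise ValueError("repeated axis")
    return tuple(normalized)


# 1-D input, two new axes: out_ndim = 1 + 2 = 3, so (1, -1) -> (1, 2).
print(normalize_axes((1, -1), 3))  # (1, 2)

# 4-D input, two new axes: out_ndim = 4 + 2 = 6, so both 3 and -3 map to 3.
try:
    normalize_axes((3, -3), 6)
except ValueError as e:
    print(e)  # repeated axis
```

Whether two axes collide therefore depends on the input rank, which is why mixed-sign tuples are hard to specify cleanly.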
Regarding removing ambiguity, wouldn't it suffice to impose an order in which to expand the dims? For example, if we specify that "negative indices get resolved first", then, borrowing your example above, the call could be resolved as

x = np.empty((2, 3, 4, 5))
xp.expand_dims(x, (3, -3)) == np.expand_dims(np.expand_dims(x, -3), 3)

so that the final output shape is (2, 3, 1, 1, 4, 5). Still, I'm not sure it is worth it, since users could do this as a two-step expansion in the first place (albeit with some more thought), and the resolution order (positive or negative indices first?) is rather arbitrary.
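One possible reading of the "negative indices get resolved first" rule, sketched with NumPy: insert all negative axes in one call, then all nonnegative axes in a second call on the result. Each call then sees axes of a single sign, so NumPy's existing tuple handling applies unchanged. The function name and this grouping are assumptions for illustration, not part of the comment above.

```python
import numpy as np


def expand_dims_negative_first(x, axes):
    """Hypothetical rule: resolve negative axes first, then nonnegative ones."""
    negative = tuple(a for a in axes if a < 0)
    nonnegative = tuple(a for a in axes if a >= 0)
    if negative:
        x = np.expand_dims(x, negative)
    if nonnegative:
        x = np.expand_dims(x, nonnegative)
    return x


x = np.empty((2, 3, 4, 5))
print(expand_dims_negative_first(x, (3, -3)).shape)  # (2, 3, 1, 1, 4, 5)
```

For (1, -1) on a 1-D array this gives (2, 1, 1), matching NumPy. Resolving nonnegative axes first instead would turn (3, -3) into (2, 3, 4, 1, 1, 5), which illustrates how arbitrary the choice of order is.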
When you do repeated expand_dims calls, the inserted dimensions in the final shape won't necessarily be at the indices you initially specified (that's the whole point of this feature request: you need a way to insert them all at once).
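A small NumPy illustration of this point: inserting singletons one axis at a time only lands them at the requested indices if the calls happen in a suitable order, whereas the tuple form places them all at once.

```python
import numpy as np

x = np.empty((2, 3))

# Tuple form: singletons end up at indices 0 and 2 of the result.
print(np.expand_dims(x, (0, 2)).shape)  # (1, 2, 1, 3)

# Same axes, ascending order: matches the tuple form.
print(np.expand_dims(np.expand_dims(x, 0), 2).shape)  # (1, 2, 1, 3)

# Same axes, descending order: the first singleton is pushed to index 3.
print(np.expand_dims(np.expand_dims(x, 2), 0).shape)  # (1, 2, 3, 1)
```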
In case it affects, either way, whether this makes it into v2024: this is now available as array_api_extra.expand_dims (https://data-apis.org/array-api-extra/generated/array_api_extra.expand_dims.html).
Given the ambiguity of supporting a tuple of axes in

def spread_dims(x: array, ndims: int, axes: Tuple[int, ...]) -> array

which expands the shape of an input array

This essentially flips the problem into one in which you specify where you want the non-singleton dimensions, rather than where you want to insert the singleton dimensions.
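A minimal NumPy sketch of how spread_dims might behave, assuming that axes lists, in order, the output positions of the input's existing dimensions and that every other output dimension has size 1. Requiring strictly increasing positions (so the operation is a pure reshape, never a transpose) is also an assumption, not part of the proposal.

```python
import numpy as np


def spread_dims(x, ndims, axes):
    """Sketch: place x's dimensions at `axes` in an ndims-dimensional result.

    Hypothetical helper from the proposal; these semantics are assumed.
    """
    if len(axes) != x.ndim:
        raise ValueError("axes must give one output position per input dimension")
    # Negative axes count from the end of the output shape.
    positions = []
    for a in axes:
        if not -ndims <= a < ndims:
            raise ValueError(f"axis {a} is out of bounds for {ndims} dimensions")
        positions.append(a % ndims)
    # Strictly increasing positions: a reshape, never a transpose (assumed).
    if any(p >= q for p, q in zip(positions, positions[1:])):
        raise ValueError("axes must be strictly increasing after normalization")
    shape = [1] * ndims
    for size, p in zip(x.shape, positions):
        shape[p] = size
    return np.reshape(x, tuple(shape))


x = np.empty((2, 3, 4, 5))
# Keep the first two and last two dimensions; fill the middle with singletons.
print(spread_dims(x, 6, (0, 1, -2, -1)).shape)  # (2, 3, 1, 1, 4, 5)
```

Because each input dimension is pinned to an explicit output position, a mixed-sign collision such as (3, -3) with ndims=6 is rejected by the increasing-order check instead of being silently reinterpreted.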
That sounds like a good idea if anybody takes issue with the interpretation chosen for
It seems to me the behavior of
This basically describes the existing behavior of NumPy, and handles all the ambiguities mentioned above:
This behavior is semantically equivalent to calling
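For reference, NumPy's tuple handling can be checked against a sequence of single-axis calls: normalize every axis against the output rank, sort the positions, and insert one at a time in ascending order. The helper below is a sketch for comparison only (the name expand_dims_sequential is hypothetical); each individual call uses a single integer axis, which the current standard already allows.

```python
import numpy as np


def expand_dims_sequential(x, axes):
    """Insert singletons one axis at a time, in ascending normalized order."""
    out_ndim = x.ndim + len(axes)
    for a in axes:
        if not -out_ndim <= a < out_ndim:
            raise ValueError(f"axis {a} is out of bounds for {out_ndim} dimensions")
    positions = sorted(a % out_ndim for a in axes)
    if len(set(positions)) != len(positions):
        raise ValueError("repeated axis")
    for p in positions:
        # Smaller positions are inserted first, so they never shift a later
        # target, and p never exceeds the current ndim, so each call is valid.
        x = np.expand_dims(x, p)
    return x


x = np.empty((2, 3, 4))
for axes in [(0,), (1, -1), (0, 2, -1), (-5, 3)]:
    assert expand_dims_sequential(x, axes).shape == np.expand_dims(x, axes).shape
```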
Hello all! I raised this issue on array-api-compat earlier (data-apis/array-api-compat#105), but I think it might be more properly directed here.
In the array API, expand_dims supports only a single axis (https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html) rather than a tuple of axes. This differs from NumPy, CuPy, and JAX, which support a tuple of axes. PyTorch, however, supports only a single axis. I don't know why the array API supports only a single axis rather than a tuple, but it means that existing calls passing a tuple of axes to expand_dims break in many places when code is ported to the array API.

In practice, expand_dims is just a light wrapper for reshape; see https://github.com/numpy/numpy/blob/3b246c6488cf246d488bbe5726ca58dc26b6ea74/numpy/lib/_shape_base_impl.py#L594. But it's not great to force users to write their own version of expand_dims in every library. Is the array API willing to update expand_dims to support a tuple of axes? If not, and if expand_dims will only support a single axis going forward, that effectively forces every user of expand_dims to copy and paste the NumPy implementation.

@lucascolley pointed out to me that when expand_dims was added to the array API, only NumPy supported a tuple of axes (see #42). That was 4 years ago, and the situation has changed, as described above.
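Since expand_dims is a thin wrapper over reshape, a namespace-generic version can be written using only x.ndim, x.shape, and xp.reshape from the standard. A minimal sketch, assuming xp is any array API namespace and that x.shape is fully known (no None entries); the helper name expand_dims_tuple is hypothetical, and this is not the array-api-extra implementation. It follows NumPy's convention of resolving negative axes against the output rank.

```python
import numpy as np  # used only for the usage example below


def expand_dims_tuple(x, axis, xp):
    """Insert size-1 dimensions at every position in `axis`.

    Negative axes are resolved against the output rank x.ndim + len(axis),
    and duplicates after normalization are an error (NumPy's convention).
    """
    if isinstance(axis, int):
        axis = (axis,)
    out_ndim = x.ndim + len(axis)
    positions = set()
    for a in axis:
        if not -out_ndim <= a < out_ndim:
            raise ValueError(f"axis {a} is out of bounds for {out_ndim} dimensions")
        p = a % out_ndim
        if p in positions:
            raise ValueError("repeated axis")
        positions.add(p)
    # Walk the output positions, taking sizes from x.shape in order.
    sizes = iter(x.shape)
    shape = tuple(1 if i in positions else next(sizes) for i in range(out_ndim))
    return xp.reshape(x, shape)


x = np.empty((2, 3, 4))
print(expand_dims_tuple(x, (0, -1), np).shape)  # (1, 2, 3, 4, 1)
```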