Implement torch.take with axis argument #34

Closed
thomasjpfan opened this issue Mar 27, 2023 · 13 comments · Fixed by #47

Comments

@thomasjpfan
Contributor

I suspect take's axis argument will be needed at some point. Can we add a simple implementation for PyTorch?

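def take(array, indices, *, axis):
    # Select everything along every axis except `axis`.
    key = [slice(None)] * array.ndim
    key[axis] = indices
    # Index with a tuple so the key is treated as one index per axis.
    return array[tuple(key)]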
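
A quick way to sanity-check the key-building approach above is to run it on a small array. The sketch below uses NumPy as a stand-in for PyTorch (an assumption, chosen so the snippet runs without torch installed) and passes the key as a tuple. `take_by_key` is a hypothetical name, not part of either library.

```python
import numpy as np

def take_by_key(array, indices, *, axis):
    # Select everything along every axis...
    key = [slice(None)] * array.ndim
    # ...except `axis`, where the integer indices pick the entries.
    key[axis] = indices
    # Use a tuple: a list key is not reliably treated as one index per axis.
    return array[tuple(key)]

x = np.arange(12).reshape(3, 4)
indices = np.array([2, 0])

rows = take_by_key(x, indices, axis=0)  # shape (2, 4)
cols = take_by_key(x, indices, axis=1)  # shape (3, 2)

# Both results should agree with NumPy's own take along the same axis.
print(np.array_equal(rows, np.take(x, indices, axis=0)))
print(np.array_equal(cols, np.take(x, indices, axis=1)))
```
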
thomasjpfan changed the title from "Implement torch.take with axis argument." to "Implement torch.take with axis argument" on Mar 28, 2023
@asmeurer
Member

Is there an upstream pytorch issue for adding this to pytorch itself? It seems to be impossible to search for in the GitHub issue tracker. @lezcano

@lezcano

lezcano commented Mar 29, 2023

@ogrisel

ogrisel commented Jul 3, 2023

Is this issue still up to date? It seems that the latest release of array-api-compat (1.3) makes xp.take accept an axis parameter but also makes it mandatory:

>>> import array_api_compat
>>> from array_api_compat import get_namespace
>>> array_api_compat.__version__
'1.3'
>>> import torch
>>> x = torch.randn(3, 3)
>>> x
tensor([[ 0.5549,  0.1046,  0.8453],
        [ 0.8627, -0.7882, -0.6627],
        [-1.4976, -0.1424,  0.4018]])
>>> indices = torch.tensor([1, 0])
>>> xp = get_namespace(x)
>>> xp.take(x, indices)
Traceback (most recent call last):
  Cell In[15], line 1
    xp.take(x, indices)
TypeError: take() missing 1 required keyword-only argument: 'axis'
>>> xp.take(x, indices, axis=0)
tensor([[ 0.8627, -0.7882, -0.6627],
        [ 0.5549,  0.1046,  0.8453]])
>>> xp.take(x, indices, axis=1)
tensor([[ 0.1046,  0.5549],
        [-0.7882,  0.8627],
        [-0.1424, -1.4976]])

The fact that it's mandatory was a bit surprising to me, but apparently the signature is in line with the specification:

>>> xp.take?
Signature: xp.take(x: 'array', indices: 'array', /, *, axis: 'int', **kwargs) -> 'array'
Docstring: <no docstring>
File:      ~/mambaforge/envs/dev/lib/python3.11/site-packages/array_api_compat/torch/_aliases.py
Type:      function
take(x: array, indices: array, /, *, axis: int) → array

@kgryte
Contributor

kgryte commented Jul 3, 2023

@ogrisel Only when the input array is one-dimensional is the axis kwarg optional. Otherwise, take is equivalent to integer indexing on a multi-dimensional array, in which one would explicitly indicate the axis to index.

@asmeurer
Member

asmeurer commented Jul 5, 2023

This looks like a bug in the spec. The text says passing it is optional, but the Python signature does not allow for that.

@kgryte
Contributor

kgryte commented Jul 5, 2023

Yes, this is unfortunate. Because axis is optional only for one-dimensional arrays, there is no great option for the signature: setting the default value to None makes it seem as if the argument is also optional for multi-dimensional arrays. My feeling is that it would be better to require the axis argument for all arrays in order to resolve this ambiguity.

@rgommers
Member

rgommers commented Jul 5, 2023

The discussion in data-apis/array-api#416 explicitly states that axis is optional, and that was supported by multiple people. So I'd consider this a bug in that PR and in the spec, and the correct resolution seems to me to be to add the missing = None to the signature. The docs are clear enough, so there isn't much of an ambiguity.

@kgryte
Contributor

kgryte commented Jul 5, 2023

The ambiguity stems from the type signature: the signature is not able to encode optionality and "requiredness" at the same time (or based on the input array shape).

We made the axis kwarg optional as a convenience; however, at the time, I omitted None as the default to reflect the general case, in which the kwarg is required for arrays with more than one dimension. Obviously, this doesn't work for the 1-D scenario where axis is optional; hence my statement about the unfortunate aspect of the signature.

@rgommers
Member

rgommers commented Jul 5, 2023

Yes, I understand what's going on. It seems clear to me that the spec has a bug that has to be resolved one way or the other. The written agreement is for optionality for 1-D arrays, so the fix is to change the signature to axis: int = None. It's the path of least resistance, and it avoids the surprise that @ogrisel expressed above. We already considered scikit-learn's needs in data-apis/array-api#416 (comment); it's mostly 1-D arrays.

@asmeurer
Member

asmeurer commented Jul 5, 2023

Yes, this is an issue with the current spec. With the current signature, take(x, indices, *, axis), axis is always required, due to the way the keyword-only syntax works.

The problem is that people (such as myself) copy the signatures from the spec directly. So if the Python signature doesn't allow axis to be optional, it won't be, because that's the exact signature I used.
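
The keyword-only behaviour described here can be checked with the standard library alone. In this sketch, `take_required` and `take_optional` are hypothetical stubs that only mirror the two signatures under discussion; they do no real indexing.

```python
import inspect

def take_required(x, indices, /, *, axis):
    """Stub with a keyword-only `axis` and no default."""
    return axis

def take_optional(x, indices, /, *, axis=None):
    """Stub with the proposed `axis=None` default."""
    return axis

def axis_is_required(func):
    # A keyword-only parameter without a default reports
    # `Parameter.empty`, so callers must always pass it.
    param = inspect.signature(func).parameters["axis"]
    return param.default is inspect.Parameter.empty

print(axis_is_required(take_required))  # True
print(axis_is_required(take_optional))  # False

# Calling the first stub without `axis` fails before the body runs.
error_message = None
try:
    take_required([1, 2, 3], [0])
except TypeError as exc:
    error_message = str(exc)
print(error_message)
```
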

The alternative signature, take(x, indices, *, axis=None) is easy to make work in the function logic:

def take(x, indices, *, axis=None):
    if x.ndim > 1 and axis is None:
        raise ValueError("axis must be provided when ndim > 1")
    ...

This is somewhat similar to arange, where it's impossible to encode the "true" signature in the Python signature, so we have to get as close as possible.

kgryte added a commit to kgryte/array-api that referenced this issue Jul 5, 2023
This commit fixes the function signature for `take`. Namely, when
an input array is one-dimensional, the `axis` kwarg is optional;
when the array has more than one dimension, the `axis` kwarg is
required. Unfortunately, the type signature cannot encode this
duality, and we must rely on the specification text to clarify
that the `axis` kwarg is required for arrays having ranks greater
than unity.

Ref: data-apis/array-api-compat#34
@kgryte
Contributor

kgryte commented Jul 5, 2023

I opened a PR correcting the signature for take: data-apis/array-api#644.

rgommers pushed a commit to data-apis/array-api that referenced this issue Jul 10, 2023
@asmeurer
Member

I updated numpy.array_api here: numpy/numpy#24187

I'm unclear if I need to do anything for the compat library for numpy/cupy. In NumPy, omitting axis for ndim > 1 does not raise an error; np.take flattens the input instead. This is generally the sort of thing I would expect the strict numpy.array_api implementation to catch, whereas in the compat library, we allow things that aren't strictly disallowed and do things like pass additional keyword arguments through.
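
The flattening behaviour mentioned here is easy to see on a small array. A minimal sketch using plain `np.take` (the variable names are illustrative only):

```python
import numpy as np

matrix = np.arange(6).reshape(2, 3)  # [[0 1 2], [3 4 5]]

# With no axis, np.take indexes into the flattened array [0 1 2 3 4 5].
flat_result = np.take(matrix, [0, 4])

# With an explicit axis, the indices select along that axis instead.
column_result = np.take(matrix, [0, 2], axis=1)

print(flat_result)    # [0 4]
print(column_result)  # rows [0 2] and [3 5]
```
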

Torch does need to be fixed though because it just wraps torch.index_select which requires the axis argument.

@rgommers
Member

> I'm unclear if I need to do anything for the compat library for numpy/cupy.

Probably not, I think it works as is.
