Implement torch.take with axis argument #34
Is there an upstream PyTorch issue for adding this to PyTorch itself? It seems impossible to search for in the GitHub issue tracker. @lezcano
Is this issue still up to date? It seems that the latest release of array-api-compat (1.3) already supports it:

```python
>>> import array_api_compat
>>> from array_api_compat import get_namespace
>>> array_api_compat.__version__
'1.3'
>>> import torch
>>> x = torch.randn(3, 3)
>>> x
tensor([[ 0.5549,  0.1046,  0.8453],
        [ 0.8627, -0.7882, -0.6627],
        [-1.4976, -0.1424,  0.4018]])
>>> indices = torch.tensor([1, 0])
>>> xp = get_namespace(x)
>>> xp.take(x, indices)
Traceback (most recent call last):
  Cell In[15], line 1
    xp.take(x, indices)
TypeError: take() missing 1 required keyword-only argument: 'axis'
>>> xp.take(x, indices, axis=0)
tensor([[ 0.8627, -0.7882, -0.6627],
        [ 0.5549,  0.1046,  0.8453]])
>>> xp.take(x, indices, axis=1)
tensor([[ 0.1046,  0.5549],
        [-0.7882,  0.8627],
        [-0.1424, -1.4976]])
```

The fact that `axis` is mandatory was a bit surprising to me, but apparently the signature is in line with the specification:

```
>>> xp.take?
Signature: xp.take(x: 'array', indices: 'array', /, *, axis: 'int', **kwargs) -> 'array'
Docstring: <no docstring>
File:      ~/mambaforge/envs/dev/lib/python3.11/site-packages/array_api_compat/torch/_aliases.py
Type:      function
```
@ogrisel Only when the input array is one-dimensional is the `axis` argument optional.
This looks like a bug in the spec. The text says passing `axis` is optional, but the Python signature does not allow for that.
Yes, this is unfortunate: by making it optional for one-dimensional arrays, we don't have a great option for the signature, as setting the default value to …
The discussion in data-apis/array-api#416 explicitly states that …
The ambiguity stems from the type signature: it cannot encode optionality and "requiredness" at the same time (or make either depend on the input array's shape). We made the …
Yes, I understand what's going on. It seems clear to me that the spec has a bug that has to be resolved one way or the other, and the written agreement is for optionality for 1-D arrays, hence changing the signature to …
Yes, this is an issue with the current spec. With the current signature, …

The problem is that people (such as myself) copy the signatures from the spec directly. So if the Python signature doesn't allow …

The alternative signature, …

```python
def take(x, indices, *, axis=None):
    # axis may only be omitted for one-dimensional input
    if x.ndim > 1 and axis is None:
        raise ValueError("axis must be provided when ndim > 1")
    ...
```

This is somewhat similar to …
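To make the runtime check in the alternative signature concrete, here is a minimal pure-Python sketch of `take` over rectangular nested lists. The nested-list representation and the `_ndim` helper are hypothetical stand-ins for a real array type; the point is only that `axis=None` is accepted for 1-D input and rejected at runtime for higher ranks, because the signature alone cannot express that rule.

```python
from typing import Any, Optional, Sequence


def _ndim(x: Any) -> int:
    # Hypothetical helper: nesting depth of a rectangular nested list.
    n = 0
    while isinstance(x, list):
        n += 1
        x = x[0] if x else None
    return n


def take(x: list, indices: Sequence[int], /, *, axis: Optional[int] = None) -> list:
    ndim = _ndim(x)
    if axis is None:
        # The signature allows omitting axis; the rank check happens here.
        if ndim > 1:
            raise ValueError("axis must be provided when ndim > 1")
        axis = 0
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for ndim={ndim}")
    axis %= ndim  # normalize negative axes

    def _take(sub: list, depth: int) -> list:
        # Recurse until the requested axis, then select along it.
        if depth == axis:
            return [sub[i] for i in indices]
        return [_take(s, depth + 1) for s in sub]

    return _take(x, 0)
```

For example, `take([[1, 2, 3], [4, 5, 6]], [1, 0], axis=1)` gives `[[2, 1], [5, 4]]`, while calling it without `axis` on that 2-D input raises `ValueError`.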
This commit fixes the function signature for `take`. Namely, when an input array is one-dimensional, the `axis` kwarg is optional; when the array has more than one dimension, the `axis` kwarg is required. Unfortunately, the type signature cannot encode this duality, and we must rely on the specification text to clarify that the `axis` kwarg is required for arrays having ranks greater than unity. Ref: data-apis/array-api-compat#34
I opened a PR correcting the signature for `take`.
I updated `numpy.array_api` in numpy/numpy#24187. I'm unclear whether I need to do anything in the compat library for NumPy/CuPy. In NumPy, `axis` can be omitted even for ndim > 1 (the input is flattened). This is generally the sort of thing I would expect the strict `numpy.array_api` implementation to catch, whereas in the compat library we allow things that aren't strictly disallowed and do things like pass additional keyword arguments through. Torch does need to be fixed, though, because it just wraps …
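For comparison, a small illustration of plain `numpy.take` (assuming NumPy is installed): with the default `axis=None` it indexes the flattened array for any rank, so a wrapper that simply forwards arguments will not reject a missing `axis` on 2-D input.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
idx = [1, 0]

flat = np.take(a, idx)          # axis=None: indexes the flattened array
rows = np.take(a, idx, axis=0)  # selects rows 1 and 0
cols = np.take(a, idx, axis=1)  # selects columns 1 and 0
```

Here `flat` is `[1, 0]`, `rows` is `[[3, 4, 5], [0, 1, 2]]` and `cols` is `[[1, 0], [4, 3]]`.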
Probably not; I think it works as is.
I suspect `take`'s `axis` argument will be needed at some point. Can we add a simple implementation for PyTorch?
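As a possible starting point, here is a minimal sketch of a PyTorch `take` with an `axis` argument, assuming PyTorch is installed. It follows the convention discussed above (`axis` optional only for 1-D input) and is built on `torch.index_select`, since `torch.take` always treats its input as flattened. This is an illustration, not the implementation adopted by array-api-compat.

```python
from typing import Optional

import torch


def take(x: torch.Tensor, indices: torch.Tensor, /, *, axis: Optional[int] = None) -> torch.Tensor:
    """Sketch of an array-API-style take() for PyTorch tensors."""
    if axis is None:
        # Omitting axis is only allowed for one-dimensional input.
        if x.ndim > 1:
            raise ValueError("axis must be provided when x.ndim > 1")
        axis = 0
    # index_select picks entries along one dimension using a 1-D integer index tensor.
    return torch.index_select(x, axis, indices)
```

With `x = torch.tensor([[1, 2, 3], [4, 5, 6]])` and `idx = torch.tensor([1, 0])`, `take(x, idx, axis=1)` returns `tensor([[2, 1], [5, 4]])`.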