Hello all! I've just spent time converting a reasonably large code base to use array-api-compat. I was pleasantly surprised that nearly everything worked, but one function didn't map smoothly: `expand_dims`. This isn't necessarily array-api-compat's problem, but I thought I'd start here.
In the array API, `expand_dims` supports only a single axis (https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html) rather than a tuple of axes. This differs from NumPy, CuPy, and JAX, which accept a tuple of axes. PyTorch, however, supports only a single axis. I don't know why the array API restricts `axis` to a single integer, but it means that existing calls to `expand_dims` that pass a tuple of axes no longer work.
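Until the standard accepts a tuple, one workaround that stays within the single-axis API is to call `expand_dims` once per axis. This is a minimal sketch, assuming `xp` is an array API namespace (for example the one returned by `array_api_compat.array_namespace(x)`); `expand_dims_multi` is a hypothetical helper, not part of array-api-compat, and NumPy is used only for the demo.

```python
import operator

import numpy as np  # used only for the demo at the bottom


def expand_dims_multi(x, axis, xp):
    """Insert a size-1 dimension at every position in `axis`.

    Hypothetical helper. `xp` is any namespace whose `expand_dims`
    accepts a single integer axis, as the array API standard requires.
    """
    # Accept either a single integer or an iterable of integers.
    try:
        axes = (operator.index(axis),)
    except TypeError:
        axes = tuple(operator.index(a) for a in axis)

    out_ndim = x.ndim + len(axes)
    # Negative axes count from the end of the *output* shape, as in NumPy.
    normalized = sorted(a + out_ndim if a < 0 else a for a in axes)
    if any(not 0 <= a < out_ndim for a in normalized):
        raise ValueError(f"axis {axis} is out of range for output ndim {out_ndim}")
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"repeated axis in {axis}")

    # Inserting in ascending order means each index already names its
    # final position in the output, so later insertions never shift it.
    for a in normalized:
        x = xp.expand_dims(x, axis=a)
    return x


x = np.zeros((3, 4))
print(expand_dims_multi(x, (0, 2), np).shape)  # (1, 3, 1, 4)
print(expand_dims_multi(x, -1, np).shape)      # (3, 4, 1)
```

Sorting the normalized axes is what makes the loop correct: each single-axis call only ever inserts at or after the positions already filled.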
In practice, `expand_dims` is just a thin wrapper around `reshape` (see https://github.com/numpy/numpy/blob/3b246c6488cf246d488bbe5726ca58dc26b6ea74/numpy/lib/_shape_base_impl.py#L594). But I'd rather not have to write my own version of `expand_dims` in every library. Would `array_api_compat` be willing to provide a non-strict version of `expand_dims` that still supports a tuple of axes? Or has there been a clear discussion and decision that `expand_dims` will only ever support a single axis, effectively forcing every user of `expand_dims` to copy and paste the NumPy implementation?
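The reshape-based approach described above can be written against the array API's `reshape` alone. This is a minimal sketch, assuming every dimension in `x.shape` is known (the standard allows `None` entries for lazily evaluated arrays, which this version does not handle); `expand_dims_reshape` is a hypothetical name, and NumPy is used only for the demo.

```python
import operator

import numpy as np  # used only for the demo at the bottom


def expand_dims_reshape(x, axis, xp):
    """Tuple-of-axes expand_dims built on xp.reshape (hypothetical helper).

    Follows the same idea as NumPy's implementation: build the output
    shape with 1s at the requested positions, then reshape.
    """
    # Accept either a single integer or an iterable of integers.
    try:
        axes = (operator.index(axis),)
    except TypeError:
        axes = tuple(operator.index(a) for a in axis)

    out_ndim = x.ndim + len(axes)
    normalized = set()
    for a in axes:
        if not -out_ndim <= a < out_ndim:
            raise ValueError(f"axis {a} is out of range for output ndim {out_ndim}")
        # Map negative axes onto [0, out_ndim) relative to the output shape.
        normalized.add(a % out_ndim)
    if len(normalized) != len(axes):
        raise ValueError(f"repeated axis in {axis}")

    # Fill the new shape: 1 at each inserted axis, otherwise the next
    # original dimension in order.
    dims = iter(x.shape)
    new_shape = tuple(1 if i in normalized else next(dims) for i in range(out_ndim))
    return xp.reshape(x, new_shape)


x = np.arange(12).reshape(3, 4)
print(expand_dims_reshape(x, (0, 2), np).shape)  # (1, 3, 1, 4)
```

A single `reshape` avoids one call per axis, at the cost of requiring a fully known input shape.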
Many thanks!
Based on https://data-apis.org/array-api-compat/#scope, I'm going to close this. If this proposal gets included in a draft version of the standard, we should implement it here, but until then, it's better to request upstream support for this feature in the standard and/or PyTorch.