Skip to content

2022.12 support #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 20, 2023
Merged
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,6 @@ version.
in the spec. Use the `size(x)` helper function as a portable workaround (see
above).

- The `linalg` extension is not yet implemented.

- PyTorch does not have unsigned integer types other than `uint8`, and no
attempt is made to implement them here.

Expand Down
16 changes: 11 additions & 5 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,12 @@ def sum(
keepdims: bool = False,
**kwargs,
) -> ndarray:
# `xp.sum` already upcasts integers, but not floats
if dtype is None and x.dtype == xp.float32:
dtype = xp.float64
# `xp.sum` already upcasts integers, but not floats or complexes
if dtype is None:
if x.dtype == xp.float32:
dtype = xp.float64
elif x.dtype == xp.complex64:
dtype = xp.complex128
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)

def prod(
Expand All @@ -412,8 +415,11 @@ def prod(
keepdims: bool = False,
**kwargs,
) -> ndarray:
if dtype is None and x.dtype == xp.float32:
dtype = xp.float64
if dtype is None:
if x.dtype == xp.float32:
dtype = xp.float64
elif x.dtype == xp.complex64:
dtype = xp.complex128
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)

# ceil, floor, and trunc return integers for integer inputs
Expand Down
9 changes: 7 additions & 2 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,13 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]
def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
    """Return the diagonals of the last two axes of ``x``.

    ``axis1``/``axis2`` are pinned to ``-2``/``-1`` so the function always
    operates on the trailing two dimensions regardless of ``x``'s rank.
    """
    return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)

def trace(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1, **kwargs))
def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
    """Sum along the diagonals of the last two axes of ``x``.

    When ``dtype`` is not given, 32-bit floating inputs are upcast to the
    corresponding 64-bit dtype (float32 -> float64, complex64 -> complex128)
    before accumulating; other dtypes are left to ``xp.trace``'s defaults.
    """
    if dtype is None:
        if x.dtype == xp.complex64:
            dtype = xp.complex128
        elif x.dtype == xp.float32:
            dtype = xp.float64
    result = xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
    # Wrap in asarray so a 0-d result is still an array, not a scalar.
    return xp.asarray(result)

__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
Expand Down
2 changes: 1 addition & 1 deletion array_api_compat/cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@

from ..common._helpers import *

__array_api_version__ = '2021.12'
__array_api_version__ = '2022.12'
2 changes: 1 addition & 1 deletion array_api_compat/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@

from ..common._helpers import *

__array_api_version__ = '2021.12'
__array_api_version__ = '2022.12'
2 changes: 1 addition & 1 deletion array_api_compat/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@

from ..common._helpers import *

__array_api_version__ = '2021.12'
__array_api_version__ = '2022.12'
36 changes: 29 additions & 7 deletions array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
*_int_dtypes,
torch.float32,
torch.float64,
torch.complex64,
torch.complex128,
}

_promotion_table = {
Expand Down Expand Up @@ -70,6 +72,16 @@
(torch.float32, torch.float64): torch.float64,
(torch.float64, torch.float32): torch.float64,
(torch.float64, torch.float64): torch.float64,
# complexes
(torch.complex64, torch.complex64): torch.complex64,
(torch.complex64, torch.complex128): torch.complex128,
(torch.complex128, torch.complex64): torch.complex128,
(torch.complex128, torch.complex128): torch.complex128,
# Mixed float and complex
(torch.float32, torch.complex64): torch.complex64,
(torch.float32, torch.complex128): torch.complex128,
(torch.float64, torch.complex64): torch.complex128,
(torch.float64, torch.complex128): torch.complex128,
}


Expand Down Expand Up @@ -129,7 +141,6 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
return torch.can_cast(from_, to)

# Basic renames
permute_dims = torch.permute
bitwise_invert = torch.bitwise_not

# Two-arg elementwise functions
Expand Down Expand Up @@ -439,18 +450,26 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
x = torch.squeeze(x, a)
return x

# torch.broadcast_to names its second argument ``size``; the standard
# calls it ``shape``, so wrap it to present the standard keyword name.
def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:
    """Broadcast ``x`` to the given ``shape``."""
    return torch.broadcast_to(x, shape, **kwargs)

# torch.permute names its second argument ``dims``; the standard calls it
# ``axes``, so wrap it to present the standard keyword name.
def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
    """Reorder the axes of ``x`` into the order given by ``axes``."""
    return torch.permute(x, axes)

# The axis parameter doesn't work for flip() and roll()
# https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't
# accept axis=None
def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
    """Reverse the order of elements of ``x`` along ``axis``.

    When ``axis`` is None (the standard's default), every axis is flipped.
    """
    if axis is None:
        axis = tuple(range(x.ndim))
    # torch.flip doesn't accept dim as an int but the method does
    # https://github.com/pytorch/pytorch/issues/18095
    return x.flip(axis, **kwargs)

def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
    """Roll the elements of ``x`` by ``shift`` positions along ``axis``.

    ``shift`` and ``axis`` are forwarded to ``torch.roll`` (whose parameters
    are named ``shifts``/``dims``); ``axis=None`` is passed through as-is.
    """
    return torch.roll(x, shift, axis, **kwargs)

def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
    """Return the indices of nonzero elements as one index array per axis."""
    # The standard's nonzero() returns a tuple of arrays; torch defaults to
    # a single 2-D tensor, so request the tuple form explicitly.
    return torch.nonzero(x, as_tuple=True, **kwargs)
Expand Down Expand Up @@ -662,15 +681,18 @@ def isdtype(
else:
return dtype == kind

def take(x: array, indices: array, /, *, axis: int, **kwargs) -> array:
    """Select elements of ``x`` along a single ``axis`` at ``indices``."""
    # The standard's single-axis take() maps onto torch.index_select, which
    # takes its arguments in (input, dim, index) order.
    return torch.index_select(x, axis, indices, **kwargs)

# Public API of this module. The pre-image diff lines left a duplicated
# entry row and a dangling list tail after the closing bracket, which is a
# syntax error; this is the resolved post-image list including the newly
# added 'broadcast_to' and 'take'.
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',
           'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
           'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal',
           'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal',
           'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',
           'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
           'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip',
           'roll', 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace',
           'full', 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims',
           'astype', 'broadcast_arrays', 'unique_all', 'unique_counts',
           'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
           'vecdot', 'tensordot', 'isdtype', 'take']
5 changes: 5 additions & 0 deletions cupy-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mo
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)]
# floating point inaccuracy
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]

Expand All @@ -55,6 +56,10 @@ array_api_tests/test_statistical_functions.py::test_max
# (https://github.com/data-apis/array-api-tests/issues/171)
array_api_tests/test_signatures.py::test_func_signature[meshgrid]

# testsuite issue with test_square
# https://github.com/data-apis/array-api-tests/issues/190
array_api_tests/test_operators_and_elementwise_functions.py::test_square

# We cannot add array attributes
array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
array_api_tests/test_signatures.py::test_array_method_signature[__index__]
Expand Down
4 changes: 4 additions & 0 deletions numpy-1-21-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices

# testsuite issue with test_square
# https://github.com/data-apis/array-api-tests/issues/190
array_api_tests/test_operators_and_elementwise_functions.py::test_square

# NumPy 1.21 specific XFAILS
############################

Expand Down
4 changes: 4 additions & 0 deletions numpy-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]

# testsuite issue with test_square
# https://github.com/data-apis/array-api-tests/issues/190
array_api_tests/test_operators_and_elementwise_functions.py::test_square

# https://github.com/numpy/numpy/issues/21213
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
Expand Down
6 changes: 0 additions & 6 deletions test_cupy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@ cd $tmpdir
git clone https://github.com/data-apis/array-api-tests
cd array-api-tests

# Remove this once https://github.com/data-apis/array-api-tests/pull/157 is
# merged
git remote add asmeurer https://github.com/asmeurer/array-api-tests
git fetch asmeurer
git checkout asmeurer/xfails-file

git submodule update --init

# store the hypothesis examples database in this directory, so that failures
Expand Down
11 changes: 11 additions & 0 deletions torch-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and
array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0]
array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0]

# testsuite issue with test_square
# https://github.com/data-apis/array-api-tests/issues/190
array_api_tests/test_operators_and_elementwise_functions.py::test_square

# Float correction is not supported by pytorch
# (https://github.com/data-apis/array-api-tests/issues/168)
array_api_tests/test_special_cases.py::test_empty_arrays[std]
Expand All @@ -182,3 +186,10 @@ array_api_tests/test_statistical_functions.py::test_var
# The test suite is incorrectly checking sums that have loss of significance
# (https://github.com/data-apis/array-api-tests/issues/168)
array_api_tests/test_statistical_functions.py::test_sum

# These functions do not yet support complex numbers
array_api_tests/test_operators_and_elementwise_functions.py::test_sign
array_api_tests/test_operators_and_elementwise_functions.py::test_expm1
array_api_tests/test_operators_and_elementwise_functions.py::test_round
array_api_tests/test_set_functions.py::test_unique_counts
array_api_tests/test_set_functions.py::test_unique_values