Skip to content

Commit 34e9d0c

Browse files
authored
Merge pull request #38 from asmeurer/2022.12
2022.12 support
2 parents 5eea1c7 + 6bd5e43 commit 34e9d0c

14 files changed

+93
-31
lines changed

CHANGELOG.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# 1.3 (2023-06-20)
2+
3+
## Major Changes
4+
5+
- Add [2022.12](https://data-apis.org/array-api/2022.12/) standard support.
6+
This includes things like adding complex dtype support, adding the new
7+
`take` function, and various minor changes in the specification.
8+
9+
## Minor Changes
10+
11+
- Support `"cpu"` in CuPy `to_device()`.
12+
13+
- Return a new array in NumPy/CuPy `reshape(copy=False)`.
14+
15+
- Fix signatures for PyTorch `broadcast_to` and `permute_dims`.
16+
117
# 1.2 (2023-04-03)
218

319
## Major Changes

README.md

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,8 @@ each array library itself fully compatible with the array API, but this
1212
requires making backwards incompatible changes in many cases, so this will
1313
take some time.
1414

15-
Currently all libraries here are implemented against the 2021.12 version of
16-
the standard. Support for the [2022.12
17-
version](https://data-apis.org/array-api/2022.12/changelog.html), which adds
18-
complex number support as well as several additional functions, will be added
19-
later this year.
15+
Currently all libraries here are implemented against the [2022.22
16+
version](https://data-apis.org/array-api/2022.22/) of the standard.
2017

2118
## Usage
2219

@@ -177,8 +174,6 @@ version.
177174
in the spec. Use the `size(x)` helper function as a portable workaround (see
178175
above).
179176

180-
- The `linalg` extension is not yet implemented.
181-
182177
- PyTorch does not have unsigned integer types other than `uint8`, and no
183178
attempt is made to implement them here.
184179

array_api_compat/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@
1717
this implementation for the default when working with NumPy arrays.
1818
1919
"""
20-
__version__ = '1.2'
20+
__version__ = '1.3'
2121

2222
from .common import *

array_api_compat/common/_aliases.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,12 @@ def sum(
397397
keepdims: bool = False,
398398
**kwargs,
399399
) -> ndarray:
400-
# `xp.sum` already upcasts integers, but not floats
401-
if dtype is None and x.dtype == xp.float32:
402-
dtype = xp.float64
400+
# `xp.sum` already upcasts integers, but not floats or complexes
401+
if dtype is None:
402+
if x.dtype == xp.float32:
403+
dtype = xp.float64
404+
elif x.dtype == xp.complex64:
405+
dtype = xp.complex128
403406
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
404407

405408
def prod(
@@ -412,8 +415,11 @@ def prod(
412415
keepdims: bool = False,
413416
**kwargs,
414417
) -> ndarray:
415-
if dtype is None and x.dtype == xp.float32:
416-
dtype = xp.float64
418+
if dtype is None:
419+
if x.dtype == xp.float32:
420+
dtype = xp.float64
421+
elif x.dtype == xp.complex64:
422+
dtype = xp.complex128
417423
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
418424

419425
# ceil, floor, and trunc return integers for integer inputs

array_api_compat/common/_linalg.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,13 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]
136136
def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
137137
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
138138

139-
def trace(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
140-
return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1, **kwargs))
139+
def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
140+
if dtype is None:
141+
if x.dtype == xp.float32:
142+
dtype = xp.float64
143+
elif x.dtype == xp.complex64:
144+
dtype = xp.complex128
145+
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
141146

142147
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
143148
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',

array_api_compat/cupy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313

1414
from ..common._helpers import *
1515

16-
__array_api_version__ = '2021.12'
16+
__array_api_version__ = '2022.12'

array_api_compat/numpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919

2020
from ..common._helpers import *
2121

22-
__array_api_version__ = '2021.12'
22+
__array_api_version__ = '2022.12'

array_api_compat/torch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919

2020
from ..common._helpers import *
2121

22-
__array_api_version__ = '2021.12'
22+
__array_api_version__ = '2022.12'

array_api_compat/torch/_aliases.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
*_int_dtypes,
3333
torch.float32,
3434
torch.float64,
35+
torch.complex64,
36+
torch.complex128,
3537
}
3638

3739
_promotion_table = {
@@ -70,6 +72,16 @@
7072
(torch.float32, torch.float64): torch.float64,
7173
(torch.float64, torch.float32): torch.float64,
7274
(torch.float64, torch.float64): torch.float64,
75+
# complexes
76+
(torch.complex64, torch.complex64): torch.complex64,
77+
(torch.complex64, torch.complex128): torch.complex128,
78+
(torch.complex128, torch.complex64): torch.complex128,
79+
(torch.complex128, torch.complex128): torch.complex128,
80+
# Mixed float and complex
81+
(torch.float32, torch.complex64): torch.complex64,
82+
(torch.float32, torch.complex128): torch.complex128,
83+
(torch.float64, torch.complex64): torch.complex128,
84+
(torch.float64, torch.complex128): torch.complex128,
7385
}
7486

7587

@@ -129,7 +141,6 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
129141
return torch.can_cast(from_, to)
130142

131143
# Basic renames
132-
permute_dims = torch.permute
133144
bitwise_invert = torch.bitwise_not
134145

135146
# Two-arg elementwise functions
@@ -439,18 +450,26 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
439450
x = torch.squeeze(x, a)
440451
return x
441452

453+
# torch.broadcast_to uses size instead of shape
454+
def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:
455+
return torch.broadcast_to(x, shape, **kwargs)
456+
457+
# torch.permute uses dims instead of axes
458+
def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
459+
return torch.permute(x, axes)
460+
442461
# The axis parameter doesn't work for flip() and roll()
443462
# https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't
444463
# accept axis=None
445-
def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
464+
def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
446465
if axis is None:
447466
axis = tuple(range(x.ndim))
448467
# torch.flip doesn't accept dim as an int but the method does
449468
# https://github.com/pytorch/pytorch/issues/18095
450-
return x.flip(axis)
469+
return x.flip(axis, **kwargs)
451470

452-
def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
453-
return torch.roll(x, shift, axis)
471+
def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
472+
return torch.roll(x, shift, axis, **kwargs)
454473

455474
def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
456475
return torch.nonzero(x, as_tuple=True, **kwargs)
@@ -662,15 +681,18 @@ def isdtype(
662681
else:
663682
return dtype == kind
664683

684+
def take(x: array, indices: array, /, *, axis: int, **kwargs) -> array:
685+
return torch.index_select(x, axis, indices, **kwargs)
686+
665687
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',
666688
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
667689
'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal',
668690
'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal',
669691
'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',
670692
'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
671-
'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll',
693+
'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll',
672694
'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full',
673695
'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
674696
'broadcast_arrays', 'unique_all', 'unique_counts',
675697
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
676-
'vecdot', 'tensordot', 'isdtype']
698+
'vecdot', 'tensordot', 'isdtype', 'take']

cupy-xfails.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mo
4141
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)]
4242
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)]
4343
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)]
44+
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)]
4445
# floating point inaccuracy
4546
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]
4647

@@ -55,6 +56,10 @@ array_api_tests/test_statistical_functions.py::test_max
5556
# (https://github.com/data-apis/array-api-tests/issues/171)
5657
array_api_tests/test_signatures.py::test_func_signature[meshgrid]
5758

59+
# testsuite issue with test_square
60+
# https://github.com/data-apis/array-api-tests/issues/190
61+
array_api_tests/test_operators_and_elementwise_functions.py::test_square
62+
5863
# We cannot add array attributes
5964
array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
6065
array_api_tests/test_signatures.py::test_array_method_signature[__index__]

numpy-1-21-xfails.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x
4747
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
4848
array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
4949

50+
# testsuite issue with test_square
51+
# https://github.com/data-apis/array-api-tests/issues/190
52+
array_api_tests/test_operators_and_elementwise_functions.py::test_square
53+
5054
# NumPy 1.21 specific XFAILS
5155
############################
5256

numpy-xfails.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity
4040
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
4141
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
4242

43+
# testsuite issue with test_square
44+
# https://github.com/data-apis/array-api-tests/issues/190
45+
array_api_tests/test_operators_and_elementwise_functions.py::test_square
46+
4347
# https://github.com/numpy/numpy/issues/21213
4448
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
4549
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]

test_cupy.sh

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,6 @@ cd $tmpdir
1818
git clone https://github.com/data-apis/array-api-tests
1919
cd array-api-tests
2020

21-
# Remove this once https://github.com/data-apis/array-api-tests/pull/157 is
22-
# merged
23-
git remote add asmeurer https://github.com/asmeurer/array-api-tests
24-
git fetch asmeurer
25-
git checkout asmeurer/xfails-file
26-
2721
git submodule update --init
2822

2923
# store the hypothesis examples database in this directory, so that failures

torch-xfails.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and
170170
array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0]
171171
array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0]
172172

173+
# testsuite issue with test_square
174+
# https://github.com/data-apis/array-api-tests/issues/190
175+
array_api_tests/test_operators_and_elementwise_functions.py::test_square
176+
173177
# Float correction is not supported by pytorch
174178
# (https://github.com/data-apis/array-api-tests/issues/168)
175179
array_api_tests/test_special_cases.py::test_empty_arrays[std]
@@ -182,3 +186,10 @@ array_api_tests/test_statistical_functions.py::test_var
182186
# The test suite is incorrectly checking sums that have loss of significance
183187
# (https://github.com/data-apis/array-api-tests/issues/168)
184188
array_api_tests/test_statistical_functions.py::test_sum
189+
190+
# These functions do not yet support complex numbers
191+
array_api_tests/test_operators_and_elementwise_functions.py::test_sign
192+
array_api_tests/test_operators_and_elementwise_functions.py::test_expm1
193+
array_api_tests/test_operators_and_elementwise_functions.py::test_round
194+
array_api_tests/test_set_functions.py::test_unique_counts
195+
array_api_tests/test_set_functions.py::test_unique_values

0 commit comments

Comments
 (0)