Skip to content

Commit 6f1610d

Browse files
authored
Merge pull request data-apis#80 from asmeurer/fix-test-failures
Fixing CI test failures
2 parents d235910 + 7d54463 commit 6f1610d

File tree

10 files changed

+142
-72
lines changed

10 files changed

+142
-72
lines changed

.github/workflows/array-api-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ on:
2424

2525

2626
env:
27-
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }}"
27+
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline"
2828

2929
jobs:
3030
tests:

array_api_compat/common/_aliases.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,12 @@ def sort(
386386
res = xp.flip(res, axis=axis)
387387
return res
388388

389+
# nonzero should error for zero-dimensional arrays
390+
def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
391+
if x.ndim == 0:
392+
raise ValueError("nonzero() does not support zero-dimensional arrays")
393+
return xp.nonzero(x, **kwargs)
394+
389395
# sum() and prod() should always upcast when dtype=None
390396
def sum(
391397
x: ndarray,
@@ -526,5 +532,5 @@ def isdtype(
526532
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
527533
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
528534
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
529-
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul',
530-
'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
535+
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
536+
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/cupy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
reshape = get_xp(cp)(_aliases.reshape)
5353
argsort = get_xp(cp)(_aliases.argsort)
5454
sort = get_xp(cp)(_aliases.sort)
55+
nonzero = get_xp(cp)(_aliases.nonzero)
5556
sum = get_xp(cp)(_aliases.sum)
5657
prod = get_xp(cp)(_aliases.prod)
5758
ceil = get_xp(cp)(_aliases.ceil)

array_api_compat/numpy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
reshape = get_xp(np)(_aliases.reshape)
5353
argsort = get_xp(np)(_aliases.argsort)
5454
sort = get_xp(np)(_aliases.sort)
55+
nonzero = get_xp(np)(_aliases.nonzero)
5556
sum = get_xp(np)(_aliases.sum)
5657
prod = get_xp(np)(_aliases.prod)
5758
ceil = get_xp(np)(_aliases.ceil)

array_api_compat/torch/_aliases.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,8 @@ def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Unio
475475
return torch.roll(x, shift, axis, **kwargs)
476476

477477
def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
478+
if x.ndim == 0:
479+
raise ValueError("nonzero() does not support zero-dimensional arrays")
478480
return torch.nonzero(x, as_tuple=True, **kwargs)
479481

480482
def where(condition: array, x1: array, x2: array, /) -> array:

array_api_compat/torch/linalg.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
if TYPE_CHECKING:
55
import torch
66
array = torch.Tensor
7+
from torch import dtype as Dtype
8+
from typing import Optional
79

810
from torch.linalg import *
911

@@ -12,9 +14,9 @@
1214
from torch import linalg as torch_linalg
1315
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
1416

15-
# These are implemented in torch but aren't in the linalg namespace
16-
from torch import outer, trace
17-
from ._aliases import _fix_promotion, matrix_transpose, tensordot
17+
# outer is implemented in torch but aren't in the linalg namespace
18+
from torch import outer
19+
from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum
1820

1921
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
2022
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
@@ -49,6 +51,11 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
4951
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
5052
return torch.linalg.solve(x1, x2, **kwargs)
5153

54+
# torch.trace doesn't support the offset argument and doesn't support stacking
55+
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
56+
# Use our wrapped sum to make sure it does upcasting correctly
57+
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
58+
5259
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot',
5360
'vecdot', 'solve']
5461

cupy-xfails.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,16 @@ array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i
175175
array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i]
176176
array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i]
177177
array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i]
178+
179+
# fft functions are not yet supported
180+
# (https://github.com/data-apis/array-api-compat/issues/67)
181+
array_api_tests/test_fft.py::test_fft
182+
array_api_tests/test_fft.py::test_ifft
183+
array_api_tests/test_fft.py::test_fftn
184+
array_api_tests/test_fft.py::test_ifftn
185+
array_api_tests/test_fft.py::test_rfft
186+
array_api_tests/test_fft.py::test_irfft
187+
array_api_tests/test_fft.py::test_rfftn
188+
array_api_tests/test_fft.py::test_irfftn
189+
array_api_tests/test_fft.py::test_hfft
190+
array_api_tests/test_fft.py::test_ihfft

numpy-1-21-xfails.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,16 @@ array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i
236236
array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i < 0) -> -0]
237237
array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0]
238238
array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0]
239+
240+
# fft functions are not yet supported
241+
# (https://github.com/data-apis/array-api-compat/issues/67)
242+
array_api_tests/test_fft.py::test_fft
243+
array_api_tests/test_fft.py::test_ifft
244+
array_api_tests/test_fft.py::test_fftn
245+
array_api_tests/test_fft.py::test_ifftn
246+
array_api_tests/test_fft.py::test_rfft
247+
array_api_tests/test_fft.py::test_irfft
248+
array_api_tests/test_fft.py::test_rfftn
249+
array_api_tests/test_fft.py::test_irfftn
250+
array_api_tests/test_fft.py::test_hfft
251+
array_api_tests/test_fft.py::test_ihfft

numpy-xfails.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,16 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
5454
# The test suite is incorrectly checking sums that have loss of significance
5555
# (https://github.com/data-apis/array-api-tests/issues/168)
5656
array_api_tests/test_statistical_functions.py::test_sum
57+
58+
# fft functions are not yet supported
59+
# (https://github.com/data-apis/array-api-compat/issues/67)
60+
array_api_tests/test_fft.py::test_fft
61+
array_api_tests/test_fft.py::test_ifft
62+
array_api_tests/test_fft.py::test_fftn
63+
array_api_tests/test_fft.py::test_ifftn
64+
array_api_tests/test_fft.py::test_rfft
65+
array_api_tests/test_fft.py::test_irfft
66+
array_api_tests/test_fft.py::test_rfftn
67+
array_api_tests/test_fft.py::test_irfftn
68+
array_api_tests/test_fft.py::test_hfft
69+
array_api_tests/test_fft.py::test_ihfft

0 commit comments

Comments
 (0)