Skip to content

Commit f52b3d5

Browse files
authored
Merge branch 'main' into add-dask
2 parents 69cc93b + 916a84b commit f52b3d5

22 files changed

+426
-199
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: Array API Tests (NumPy dev)
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
array-api-tests-numpy-dev:
7+
uses: ./.github/workflows/array-api-tests.yml
8+
with:
9+
package-name: numpy
10+
extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple'
11+
xfails-file-extra: '-dev'

.github/workflows/array-api-tests.yml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ on:
3030

3131

3232
env:
33-
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }}"
33+
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline"
3434

3535
jobs:
3636
tests:
@@ -51,19 +51,20 @@ jobs:
5151
submodules: 'true'
5252
path: array-api-tests
5353
- name: Set up Python ${{ matrix.python-version }}
54-
uses: actions/setup-python@v4
54+
uses: actions/setup-python@v5
5555
with:
5656
python-version: ${{ matrix.python-version }}
5757
- name: Install dependencies
58-
# NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way
59-
# to put this in the numpy 1.21 config file.
60-
if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
58+
# NumPy 1.21 doesn't support Python 3.11. NumPy 2.0 doesn't support
59+
# Python 3.8. There doesn't seem to be a way to put this in the numpy
60+
# 1.21 config file.
61+
if: "! ((matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21')) || (matrix.python-version == '3.8' && inputs.package-name == 'numpy' && contains(inputs.xfails-file-extra, 'dev')))"
6162
run: |
6263
python -m pip install --upgrade pip
6364
python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }}
6465
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
6566
- name: Run the array API testsuite (${{ inputs.package-name }})
66-
if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
67+
if: "! ((matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21')) || (matrix.python-version == '3.8' && inputs.package-name == 'numpy' && contains(inputs.xfails-file-extra, 'dev')))"
6768
env:
6869
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }}
6970
# This enables the NEP 50 type promotion behavior (without it a lot of

.github/workflows/publish-package.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
fetch-depth: 0
3636

3737
- name: Set up Python
38-
uses: actions/setup-python@v4
38+
uses: actions/setup-python@v5
3939
with:
4040
python-version: '3.x'
4141

@@ -59,7 +59,7 @@ jobs:
5959
run: python -m zipfile --list dist/array_api_compat-*.whl
6060

6161
- name: Upload distribution artifact
62-
uses: actions/upload-artifact@v3
62+
uses: actions/upload-artifact@v4
6363
with:
6464
name: dist-artifact
6565
path: dist
@@ -80,7 +80,7 @@ jobs:
8080

8181
steps:
8282
- name: Download distribution artifact
83-
uses: actions/download-artifact@v3
83+
uses: actions/download-artifact@v4
8484
with:
8585
name: dist-artifact
8686
path: dist

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ jobs:
99
fail-fast: true
1010
steps:
1111
- uses: actions/checkout@v4
12-
- uses: actions/setup-python@v4
12+
- uses: actions/setup-python@v5
1313
with:
1414
python-version: ${{ matrix.python-version }}
1515
- name: Install Dependencies

CHANGELOG.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
# 1.4.1 (2024-01-18)
2+
3+
## Minor Changes
4+
5+
- Add support for the upcoming NumPy 2.0 release.
6+
7+
- Added a torch wrapper for `trace` (`torch.trace` doesn't support the
8+
`offset` argument or stacking)
9+
10+
- Wrap numpy, cupy, and torch `nonzero` to raise an error for zero-dimensional
11+
input arrays.
12+
13+
- Add torch wrapper for `newaxis`.
14+
15+
- Improve error message for `array_namespace`
16+
17+
- Fix linalg.cholesky returning the conjugate of the expected upper
18+
decomposition for numpy and cupy.
19+
120
# 1.4 (2023-09-13)
221

322
## Major Changes

README.md

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,54 @@ corresponding document does not yet exist for PyTorch, but you can examine the
300300
various comments in the
301301
[implementation](https://github.com/data-apis/array-api-compat/blob/main/array_api_compat/torch/_aliases.py)
302302
to see what functions and behaviors have been wrapped.
303+
304+
305+
## Releasing
306+
307+
To release, first note that CuPy must be tested manually (it isn't tested on
308+
CI). Use the script
309+
310+
```
311+
./test_cupy.sh
312+
```
313+
314+
on a machine with a CUDA GPU.
315+
316+
Once you are ready to release, create a PR with a release branch, so that you
317+
can verify that CI is passing. You must edit
318+
319+
```
320+
array_api_compat/__init__.py
321+
```
322+
323+
and update the version (the version is not computed from the tag because that
324+
would break vendorability). You should also edit
325+
326+
```
327+
CHANGELOG.md
328+
```
329+
330+
with the changes for the release.
331+
332+
Then create a tag
333+
334+
```
335+
git tag -a <version>
336+
```
337+
338+
and push it to GitHub
339+
340+
```
341+
git push origin <version>
342+
```
343+
344+
Check that the `publish distributions` action works. Note that this action
345+
will run even if the other CI fails, so you must make sure that CI is passing
346+
*before* tagging.
347+
348+
This does mean you can ignore CI failures, but ideally you should fix any
349+
failures or update the `*-xfails.txt` files before tagging, so that CI and the
350+
cupy tests pass. Otherwise it will be hard to tell what things are breaking in
351+
the future. It's also a good idea to remove any xpasses from those files (but
352+
be aware that some xfails are from flaky failures, so unless you know the
353+
underlying issue has been fixed, a xpass test is probably still xfail).

array_api_compat/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@
1717
this implementation for the default when working with NumPy arrays.
1818
1919
"""
20-
__version__ = '1.4'
20+
__version__ = '1.4.1'
2121

2222
from .common import *

array_api_compat/common/_aliases.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,12 @@ def sort(
396396
res = xp.flip(res, axis=axis)
397397
return res
398398

399+
# nonzero should error for zero-dimensional arrays
400+
def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
401+
if x.ndim == 0:
402+
raise ValueError("nonzero() does not support zero-dimensional arrays")
403+
return xp.nonzero(x, **kwargs)
404+
399405
# sum() and prod() should always upcast when dtype=None
400406
def sum(
401407
x: ndarray,
@@ -536,5 +542,5 @@ def isdtype(
536542
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
537543
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
538544
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
539-
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul',
540-
'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
545+
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
546+
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/common/_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ def your_function(x, y):
8181
"""
8282
namespaces = set()
8383
for x in xs:
84-
if hasattr(x, '__array_namespace__'):
85-
namespaces.add(x.__array_namespace__(api_version=api_version))
86-
elif _is_numpy_array(x):
84+
if _is_numpy_array(x):
8785
_check_api_version(api_version)
8886
if _use_compat:
8987
from .. import numpy as numpy_namespace
@@ -114,9 +112,11 @@ def your_function(x, y):
114112
namespaces.add(dask_namespace)
115113
else:
116114
raise TypeError("_use_compat cannot be False if input array is a dask array!")
115+
elif hasattr(x, '__array_namespace__'):
116+
namespaces.add(x.__array_namespace__(api_version=api_version))
117117
else:
118118
# TODO: Support Python scalars?
119-
raise TypeError("The input is not a supported array type")
119+
raise TypeError(f"{type(x).__name__} is not a supported array type")
120120

121121
if not namespaces:
122122
raise TypeError("Unrecognized array input")

array_api_compat/common/_linalg.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
else:
1212
from numpy.core.numeric import normalize_axis_tuple
1313

14-
from ._aliases import matmul, matrix_transpose, tensordot, vecdot
14+
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
1515
from .._internal import get_xp
1616

1717
# These are in the main NumPy namespace but not in numpy.linalg
@@ -59,7 +59,10 @@ def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult
5959
def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray:
6060
L = xp.linalg.cholesky(x, **kwargs)
6161
if upper:
62-
return get_xp(xp)(matrix_transpose)(L)
62+
U = get_xp(xp)(matrix_transpose)(L)
63+
if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
64+
U = xp.conj(U)
65+
return U
6366
return L
6467

6568
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.

array_api_compat/cupy/_aliases.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
reshape = get_xp(cp)(_aliases.reshape)
5353
argsort = get_xp(cp)(_aliases.argsort)
5454
sort = get_xp(cp)(_aliases.sort)
55+
nonzero = get_xp(cp)(_aliases.nonzero)
5556
sum = get_xp(cp)(_aliases.sum)
5657
prod = get_xp(cp)(_aliases.prod)
5758
ceil = get_xp(cp)(_aliases.ceil)
@@ -60,8 +61,17 @@
6061
matmul = get_xp(cp)(_aliases.matmul)
6162
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
6263
tensordot = get_xp(cp)(_aliases.tensordot)
63-
vecdot = get_xp(cp)(_aliases.vecdot)
64-
isdtype = get_xp(cp)(_aliases.isdtype)
64+
65+
# These functions are completely new here. If the library already has them
66+
# (i.e., numpy 2.0), use the library version instead of our wrapper.
67+
if hasattr(cp, 'vecdot'):
68+
vecdot = cp.vecdot
69+
else:
70+
vecdot = get_xp(cp)(_aliases.vecdot)
71+
if hasattr(cp, 'isdtype'):
72+
isdtype = cp.isdtype
73+
else:
74+
isdtype = get_xp(cp)(_aliases.isdtype)
6575

6676
__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
6777
'acosh', 'asin', 'asinh', 'atan', 'atan2',

array_api_compat/cupy/linalg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,16 @@
2929
pinv = get_xp(cp)(_linalg.pinv)
3030
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
3131
svdvals = get_xp(cp)(_linalg.svdvals)
32-
vector_norm = get_xp(cp)(_linalg.vector_norm)
3332
diagonal = get_xp(cp)(_linalg.diagonal)
3433
trace = get_xp(cp)(_linalg.trace)
3534

35+
# These functions are completely new here. If the library already has them
36+
# (i.e., numpy 2.0), use the library version instead of our wrapper.
37+
if hasattr(cp.linalg, 'vector_norm'):
38+
vector_norm = cp.linalg.vector_norm
39+
else:
40+
vector_norm = get_xp(cp)(_linalg.vector_norm)
41+
3642
__all__ = linalg_all + _linalg.__all__
3743

3844
del get_xp

array_api_compat/numpy/_aliases.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
reshape = get_xp(np)(_aliases.reshape)
5353
argsort = get_xp(np)(_aliases.argsort)
5454
sort = get_xp(np)(_aliases.sort)
55+
nonzero = get_xp(np)(_aliases.nonzero)
5556
sum = get_xp(np)(_aliases.sum)
5657
prod = get_xp(np)(_aliases.prod)
5758
ceil = get_xp(np)(_aliases.ceil)
@@ -60,8 +61,17 @@
6061
matmul = get_xp(np)(_aliases.matmul)
6162
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
6263
tensordot = get_xp(np)(_aliases.tensordot)
63-
vecdot = get_xp(np)(_aliases.vecdot)
64-
isdtype = get_xp(np)(_aliases.isdtype)
64+
65+
# These functions are completely new here. If the library already has them
66+
# (i.e., numpy 2.0), use the library version instead of our wrapper.
67+
if hasattr(np, 'vecdot'):
68+
vecdot = np.vecdot
69+
else:
70+
vecdot = get_xp(np)(_aliases.vecdot)
71+
if hasattr(np, 'isdtype'):
72+
isdtype = np.isdtype
73+
else:
74+
isdtype = get_xp(np)(_aliases.isdtype)
6575

6676
__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
6777
'acosh', 'asin', 'asinh', 'atan', 'atan2',

array_api_compat/numpy/linalg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,16 @@
2222
pinv = get_xp(np)(_linalg.pinv)
2323
matrix_norm = get_xp(np)(_linalg.matrix_norm)
2424
svdvals = get_xp(np)(_linalg.svdvals)
25-
vector_norm = get_xp(np)(_linalg.vector_norm)
2625
diagonal = get_xp(np)(_linalg.diagonal)
2726
trace = get_xp(np)(_linalg.trace)
2827

28+
# These functions are completely new here. If the library already has them
29+
# (i.e., numpy 2.0), use the library version instead of our wrapper.
30+
if hasattr(np.linalg, 'vector_norm'):
31+
vector_norm = np.linalg.vector_norm
32+
else:
33+
vector_norm = get_xp(np)(_linalg.vector_norm)
34+
2935
__all__ = linalg_all + _linalg.__all__
3036

3137
del get_xp

array_api_compat/torch/_aliases.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,8 @@ def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Unio
475475
return torch.roll(x, shift, axis, **kwargs)
476476

477477
def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
478+
if x.ndim == 0:
479+
raise ValueError("nonzero() does not support zero-dimensional arrays")
478480
return torch.nonzero(x, as_tuple=True, **kwargs)
479481

480482
def where(condition: array, x1: array, x2: array, /) -> array:

array_api_compat/torch/linalg.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
if TYPE_CHECKING:
55
import torch
66
array = torch.Tensor
7+
from torch import dtype as Dtype
8+
from typing import Optional
79

810
from torch.linalg import *
911

@@ -12,9 +14,9 @@
1214
from torch import linalg as torch_linalg
1315
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
1416

15-
# These are implemented in torch but aren't in the linalg namespace
16-
from torch import outer, trace
17-
from ._aliases import _fix_promotion, matrix_transpose, tensordot
17+
# outer is implemented in torch but aren't in the linalg namespace
18+
from torch import outer
19+
from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum
1820

1921
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
2022
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
@@ -49,6 +51,11 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
4951
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
5052
return torch.linalg.solve(x1, x2, **kwargs)
5153

54+
# torch.trace doesn't support the offset argument and doesn't support stacking
55+
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
56+
# Use our wrapped sum to make sure it does upcasting correctly
57+
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
58+
5259
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot',
5360
'vecdot', 'solve']
5461

0 commit comments

Comments
 (0)