Skip to content

Commit b9854a7

Browse files
committed
Merge branch 'main' into signbit-nan
2 parents 86a402b + f3145b0 commit b9854a7

28 files changed

+333
-108
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: Array API Tests (NumPy 1.26)
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
array-api-tests-numpy-latest:
7+
uses: ./.github/workflows/array-api-tests.yml
8+
with:
9+
package-name: numpy
10+
package-version: '== 1.26.*'
11+
xfails-file-extra: '-1-26'

.github/workflows/docs-deploy.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
steps:
1414
- uses: actions/checkout@v4
1515
- name: Download Artifact
16-
uses: dawidd6/action-download-artifact@v3
16+
uses: dawidd6/action-download-artifact@v6
1717
with:
1818
workflow: docs-build.yml
1919
name: docs-build

.github/workflows/publish-package.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ jobs:
9494
if: >-
9595
(github.event_name == 'push' && startsWith(github.ref, 'refs/tags'))
9696
|| (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true')
97-
uses: pypa/gh-action-pypi-publish@v1.8.14
97+
uses: pypa/gh-action-pypi-publish@v1.9.0
9898
with:
9999
repository-url: https://test.pypi.org/legacy/
100100
print-hash: true
@@ -107,6 +107,6 @@ jobs:
107107

108108
- name: Publish distribution 📦 to PyPI
109109
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
110-
uses: pypa/gh-action-pypi-publish@v1.8.14
110+
uses: pypa/gh-action-pypi-publish@v1.9.0
111111
with:
112112
print-hash: true

.github/workflows/tests.yml

+6-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ jobs:
66
strategy:
77
matrix:
88
python-version: ['3.9', '3.10', '3.11', '3.12']
9-
numpy-version: ['1.21', '1.26', 'dev']
9+
numpy-version: ['1.21', '1.26', '2.0', 'dev']
1010
exclude:
1111
- python-version: '3.11'
1212
numpy-version: '1.21'
@@ -28,6 +28,11 @@ jobs:
2828
else
2929
PIP_EXTRA='numpy==1.26.*'
3030
fi
31+
32+
if [ "${{ matrix.python-version }}" == "3.9" ]; then
33+
sed -i '/^ndonnx/d' requirements-dev.txt
34+
fi
35+
3136
python -m pip install -r requirements-dev.txt $PIP_EXTRA
3237
3338
- name: Run Tests

README.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
This is a small wrapper around common array libraries that is compatible with
44
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
5-
NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support
6-
for other array libraries, or if you encounter any issues, please [open an
7-
issue](https://github.com/data-apis/array-api-compat/issues).
5+
NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
6+
support for other array libraries, or if you encounter any issues, please [open
7+
an issue](https://github.com/data-apis/array-api-compat/issues).
88

99
See the documentation for more details https://data-apis.org/array-api-compat/

array_api_compat/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@
1717
this implementation for the default when working with NumPy arrays.
1818
1919
"""
20-
__version__ = '1.6'
20+
__version__ = '1.8'
2121

2222
from .common import * # noqa: F401, F403

array_api_compat/common/_aliases.py

+62-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import NamedTuple
1313
import inspect
1414

15-
from ._helpers import _check_device
15+
from ._helpers import array_namespace, _check_device
1616

1717
# These functions are modified from the NumPy versions.
1818

@@ -264,6 +264,66 @@ def var(
264264
) -> ndarray:
265265
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
266266

267+
268+
# The min and max argument names in clip are different and not optional in numpy, and type
269+
# promotion behavior is different.
270+
def clip(
271+
x: ndarray,
272+
/,
273+
min: Optional[Union[int, float, ndarray]] = None,
274+
max: Optional[Union[int, float, ndarray]] = None,
275+
*,
276+
xp,
277+
# TODO: np.clip has other ufunc kwargs
278+
out: Optional[ndarray] = None,
279+
) -> ndarray:
280+
def _isscalar(a):
281+
return isinstance(a, (int, float, type(None)))
282+
min_shape = () if _isscalar(min) else min.shape
283+
max_shape = () if _isscalar(max) else max.shape
284+
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
285+
286+
wrapped_xp = array_namespace(x)
287+
288+
# np.clip does type promotion but the array API clip requires that the
289+
# output have the same dtype as x. We do this instead of just downcasting
290+
# the result of xp.clip() to handle some corner cases better (e.g.,
291+
# avoiding uint64 -> float64 promotion).
292+
293+
# Note: cases where min or max overflow (integer) or round (float) in the
294+
# wrong direction when downcasting to x.dtype are unspecified. This code
295+
# just does whatever NumPy does when it downcasts in the assignment, but
296+
# other behavior could be preferred, especially for integers. For example,
297+
# this code produces:
298+
299+
# >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
300+
# -128
301+
302+
# but an answer of 0 might be preferred. See
303+
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
304+
305+
306+
# At least handle the case of Python integers correctly (see
307+
# https://github.com/numpy/numpy/pull/26892).
308+
if type(min) is int and min <= xp.iinfo(x.dtype).min:
309+
min = None
310+
if type(max) is int and max >= xp.iinfo(x.dtype).max:
311+
max = None
312+
313+
if out is None:
314+
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
315+
if min is not None:
316+
a = xp.broadcast_to(xp.asarray(min), result_shape)
317+
ia = (out < a) | xp.isnan(a)
318+
# torch requires an explicit cast here
319+
out[ia] = wrapped_xp.astype(a[ia], out.dtype)
320+
if max is not None:
321+
b = xp.broadcast_to(xp.asarray(max), result_shape)
322+
ib = (out > b) | xp.isnan(b)
323+
out[ib] = wrapped_xp.astype(b[ib], out.dtype)
324+
# Return a scalar for 0-D
325+
return out[()]
326+
267327
# Unlike transpose(), the axes argument to permute_dims() is required.
268328
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
269329
return xp.transpose(x, axes)
@@ -465,6 +525,6 @@ def isdtype(
465525
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
466526
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
467527
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
468-
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
528+
'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
469529
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
470530
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/common/_helpers.py

+59-23
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ def is_numpy_array(x):
4848
is_array_api_obj
4949
is_cupy_array
5050
is_torch_array
51+
is_ndonnx_array
5152
is_dask_array
5253
is_jax_array
53-
is_pydata_sparse
54+
is_pydata_sparse_array
5455
"""
5556
# Avoid importing NumPy if it isn't already
5657
if 'numpy' not in sys.modules:
@@ -78,11 +79,12 @@ def is_cupy_array(x):
7879
is_array_api_obj
7980
is_numpy_array
8081
is_torch_array
82+
is_ndonnx_array
8183
is_dask_array
8284
is_jax_array
83-
is_pydata_sparse
85+
is_pydata_sparse_array
8486
"""
85-
# Avoid importing NumPy if it isn't already
87+
# Avoid importing CuPy if it isn't already
8688
if 'cupy' not in sys.modules:
8789
return False
8890

@@ -107,7 +109,7 @@ def is_torch_array(x):
107109
is_cupy_array
108110
is_dask_array
109111
is_jax_array
110-
is_pydata_sparse
112+
is_pydata_sparse_array
111113
"""
112114
# Avoid importing torch if it isn't already
113115
if 'torch' not in sys.modules:
@@ -118,6 +120,33 @@ def is_torch_array(x):
118120
# TODO: Should we reject ndarray subclasses?
119121
return isinstance(x, torch.Tensor)
120122

123+
def is_ndonnx_array(x):
124+
"""
125+
Return True if `x` is a ndonnx Array.
126+
127+
This function does not import ndonnx if it has not already been imported
128+
and is therefore cheap to use.
129+
130+
See Also
131+
--------
132+
133+
array_namespace
134+
is_array_api_obj
135+
is_numpy_array
136+
is_cupy_array
137+
    is_torch_array
138+
is_dask_array
139+
is_jax_array
140+
is_pydata_sparse_array
141+
"""
142+
    # Avoid importing ndonnx if it isn't already
143+
if 'ndonnx' not in sys.modules:
144+
return False
145+
146+
import ndonnx as ndx
147+
148+
return isinstance(x, ndx.Array)
149+
121150
def is_dask_array(x):
122151
"""
123152
Return True if `x` is a dask.array Array.
@@ -133,8 +162,9 @@ def is_dask_array(x):
133162
is_numpy_array
134163
is_cupy_array
135164
is_torch_array
165+
is_ndonnx_array
136166
is_jax_array
137-
is_pydata_sparse
167+
is_pydata_sparse_array
138168
"""
139169
# Avoid importing dask if it isn't already
140170
if 'dask.array' not in sys.modules:
@@ -160,8 +190,9 @@ def is_jax_array(x):
160190
is_numpy_array
161191
is_cupy_array
162192
is_torch_array
193+
is_ndonnx_array
163194
is_dask_array
164-
is_pydata_sparse
195+
is_pydata_sparse_array
165196
"""
166197
# Avoid importing jax if it isn't already
167198
if 'jax' not in sys.modules:
@@ -172,7 +203,7 @@ def is_jax_array(x):
172203
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
173204

174205

175-
def is_pydata_sparse(x) -> bool:
206+
def is_pydata_sparse_array(x) -> bool:
176207
"""
177208
Return True if `x` is an array from the `sparse` package.
178209
@@ -188,6 +219,7 @@ def is_pydata_sparse(x) -> bool:
188219
is_numpy_array
189220
is_cupy_array
190221
is_torch_array
222+
is_ndonnx_array
191223
is_dask_array
192224
is_jax_array
193225
"""
@@ -211,6 +243,7 @@ def is_array_api_obj(x):
211243
is_numpy_array
212244
is_cupy_array
213245
is_torch_array
246+
is_ndonnx_array
214247
is_dask_array
215248
is_jax_array
216249
"""
@@ -219,7 +252,7 @@ def is_array_api_obj(x):
219252
or is_torch_array(x) \
220253
or is_dask_array(x) \
221254
or is_jax_array(x) \
222-
or is_pydata_sparse(x) \
255+
or is_pydata_sparse_array(x) \
223256
or hasattr(x, '__array_namespace__')
224257

225258
def _check_api_version(api_version):
@@ -288,7 +321,7 @@ def your_function(x, y):
288321
is_torch_array
289322
is_dask_array
290323
is_jax_array
291-
is_pydata_sparse
324+
is_pydata_sparse_array
292325
293326
"""
294327
if use_compat not in [None, True, False]:
@@ -307,12 +340,9 @@ def your_function(x, y):
307340
elif use_compat is False:
308341
namespaces.add(np)
309342
else:
310-
# numpy 2.0 has __array_namespace__ and is fully array API
343+
# numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
311344
# compatible.
312-
if hasattr(x, '__array_namespace__'):
313-
namespaces.add(x.__array_namespace__(api_version=api_version))
314-
else:
315-
namespaces.add(numpy_namespace)
345+
namespaces.add(numpy_namespace)
316346
elif is_cupy_array(x):
317347
if _use_compat:
318348
_check_api_version(api_version)
@@ -344,11 +374,15 @@ def your_function(x, y):
344374
elif use_compat is False:
345375
import jax.numpy as jnp
346376
else:
347-
# jax.experimental.array_api is already an array namespace. We do
348-
# not have a wrapper submodule for it.
349-
import jax.experimental.array_api as jnp
377+
# JAX v0.4.32 and newer implements the array API directly in jax.numpy.
378+
# For older JAX versions, it is available via jax.experimental.array_api.
379+
import jax.numpy
380+
if hasattr(jax.numpy, "__array_api_version__"):
381+
jnp = jax.numpy
382+
else:
383+
import jax.experimental.array_api as jnp
350384
namespaces.add(jnp)
351-
elif is_pydata_sparse(x):
385+
elif is_pydata_sparse_array(x):
352386
if use_compat is True:
353387
_check_api_version(api_version)
354388
raise ValueError("`sparse` does not have an array-api-compat wrapper")
@@ -451,7 +485,7 @@ def device(x: Array, /) -> Device:
451485
return x.device()
452486
else:
453487
return x.device
454-
elif is_pydata_sparse(x):
488+
elif is_pydata_sparse_array(x):
455489
# `sparse` will gain `.device`, so check for this first.
456490
x_device = getattr(x, 'device', None)
457491
if x_device is not None:
@@ -580,10 +614,11 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
580614
return x
581615
raise ValueError(f"Unsupported device {device!r}")
582616
elif is_jax_array(x):
583-
# This import adds to_device to x
584-
import jax.experimental.array_api # noqa: F401
617+
if not hasattr(x, "__array_namespace__"):
618+
# In JAX v0.4.31 and older, this import adds to_device method to x.
619+
import jax.experimental.array_api # noqa: F401
585620
return x.to_device(device, stream=stream)
586-
elif is_pydata_sparse(x) and device == _device(x):
621+
elif is_pydata_sparse_array(x) and device == _device(x):
587622
# Perform trivial check to return the same array if
588623
# device is same instead of err-ing.
589624
return x
@@ -613,7 +648,8 @@ def size(x):
613648
"is_jax_array",
614649
"is_numpy_array",
615650
"is_torch_array",
616-
"is_pydata_sparse",
651+
"is_ndonnx_array",
652+
"is_pydata_sparse_array",
617653
"size",
618654
"to_device",
619655
]

array_api_compat/cupy/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
astype = _aliases.astype
4848
std = get_xp(cp)(_aliases.std)
4949
var = get_xp(cp)(_aliases.var)
50+
clip = get_xp(cp)(_aliases.clip)
5051
permute_dims = get_xp(cp)(_aliases.permute_dims)
5152
reshape = get_xp(cp)(_aliases.reshape)
5253
argsort = get_xp(cp)(_aliases.argsort)

array_api_compat/dask/array/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def _dask_arange(
8888
permute_dims = get_xp(da)(_aliases.permute_dims)
8989
std = get_xp(da)(_aliases.std)
9090
var = get_xp(da)(_aliases.var)
91+
clip = get_xp(da)(_aliases.clip)
9192
empty = get_xp(da)(_aliases.empty)
9293
empty_like = get_xp(da)(_aliases.empty_like)
9394
full = get_xp(da)(_aliases.full)

array_api_compat/numpy/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
astype = _aliases.astype
4848
std = get_xp(np)(_aliases.std)
4949
var = get_xp(np)(_aliases.var)
50+
clip = get_xp(np)(_aliases.clip)
5051
permute_dims = get_xp(np)(_aliases.permute_dims)
5152
reshape = get_xp(np)(_aliases.reshape)
5253
argsort = get_xp(np)(_aliases.argsort)

0 commit comments

Comments
 (0)