Skip to content

Commit b96e84b

Browse files
committed
Merge branch 'main' into more-2023
2 parents d751db6 + f3145b0 commit b96e84b

File tree

7 files changed

+40
-45
lines changed

7 files changed

+40
-45
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ jobs:
66
strategy:
77
matrix:
88
python-version: ['3.9', '3.10', '3.11', '3.12']
9-
numpy-version: ['1.21', '1.26', 'dev']
9+
numpy-version: ['1.21', '1.26', '2.0', 'dev']
1010
exclude:
1111
- python-version: '3.11'
1212
numpy-version: '1.21'

array_api_compat/common/_helpers.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def is_ndonnx_array(x):
145145

146146
import ndonnx as ndx
147147

148-
return isinstance(x, ndx.Array)
148+
return isinstance(x, ndx.Array)
149149

150150
def is_dask_array(x):
151151
"""
@@ -340,12 +340,9 @@ def your_function(x, y):
340340
elif use_compat is False:
341341
namespaces.add(np)
342342
else:
343-
# numpy 2.0 has __array_namespace__ and is fully array API
343+
# numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
344344
# compatible.
345-
if hasattr(np.empty(0), '__array_namespace__'):
346-
namespaces.add(np.empty(0).__array_namespace__(api_version=api_version))
347-
else:
348-
namespaces.add(numpy_namespace)
345+
namespaces.add(numpy_namespace)
349346
elif is_cupy_array(x):
350347
if _use_compat:
351348
_check_api_version(api_version)
@@ -377,9 +374,13 @@ def your_function(x, y):
377374
elif use_compat is False:
378375
import jax.numpy as jnp
379376
else:
380-
# jax.experimental.array_api is already an array namespace. We do
381-
# not have a wrapper submodule for it.
382-
import jax.experimental.array_api as jnp
377+
# JAX v0.4.32 and newer implements the array API directly in jax.numpy.
378+
# For older JAX versions, it is available via jax.experimental.array_api.
379+
import jax.numpy
380+
if hasattr(jax.numpy, "__array_api_version__"):
381+
jnp = jax.numpy
382+
else:
383+
import jax.experimental.array_api as jnp
383384
namespaces.add(jnp)
384385
elif is_pydata_sparse_array(x):
385386
if use_compat is True:
@@ -613,8 +614,9 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
613614
return x
614615
raise ValueError(f"Unsupported device {device!r}")
615616
elif is_jax_array(x):
616-
# This import adds to_device to x
617-
import jax.experimental.array_api # noqa: F401
617+
if not hasattr(x, "__array_namespace__"):
618+
# In JAX v0.4.31 and older, this import adds to_device method to x.
619+
import jax.experimental.array_api # noqa: F401
618620
return x.to_device(device, stream=stream)
619621
elif is_pydata_sparse_array(x) and device == _device(x):
620622
# Perform trivial check to return the same array if

array_api_compat/torch/_aliases.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
146146
# Basic renames
147147
bitwise_invert = torch.bitwise_not
148148
newaxis = None
149+
# torch.conj sets the conjugation bit, which breaks conversion to other
150+
# libraries. See https://github.com/data-apis/array-api-compat/issues/173
151+
conj = torch.conj_physical
149152

150153
# Two-arg elementwise functions
151154
# These require a wrapper to do the correct type promotion on 0-D tensors
@@ -707,7 +710,7 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
707710
return torch.index_select(x, axis, indices, **kwargs)
708711

709712
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
710-
'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
713+
'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
711714
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign',
712715
'divide', 'equal', 'floor_divide', 'greater', 'greater_equal',
713716
'hypot', 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal',

docs/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ import array_api_compat.dask as da
6363
```{note}
6464
There are no `array_api_compat` submodules for JAX, sparse, or ndonnx. These
6565
support for these libraries is contained in the libraries themselves (JAX
66-
support is in the `jax.experimental.array_api` module). The
66+
support is in the `jax.numpy` module in JAX v0.4.32 or newer, and in the
67+
`jax.experimental.array_api` module for older JAX versions). The
6768
array-api-compat support for these libraries consists of supporting them in
6869
the [helper functions](helper-functions).
6970
```

numpy-dev-xfails.txt

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,11 @@
11
# finfo(float32).eps returns float32 but should return float
22
array_api_tests/test_data_type_functions.py::test_finfo[float32]
33

4-
# NumPy deviates in some special cases for floordiv
5-
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
6-
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
7-
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
8-
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
9-
array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
10-
array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
11-
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
12-
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
13-
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
14-
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
15-
array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
16-
array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
17-
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
18-
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
19-
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
20-
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
21-
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
22-
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
23-
244
# https://github.com/numpy/numpy/issues/21213
255
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
266
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
277
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
288
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
29-
array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
309

3110
# The test suite is incorrectly checking sums that have loss of significance
3211
# (https://github.com/data-apis/array-api-tests/issues/168)

tests/_helpers.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33

44
import pytest
55

6-
wrapped_libraries = ["cupy", "torch", "dask.array"]
7-
all_libraries = wrapped_libraries + ["numpy", "jax.numpy", "sparse"]
8-
import numpy as np
9-
if np.__version__[0] == '1':
10-
wrapped_libraries.append("numpy")
6+
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
7+
all_libraries = wrapped_libraries + ["jax.numpy"]
118

129
# `sparse` added array API support as of Python 3.10.
1310
if sys.version_info >= (3, 10):
@@ -18,7 +15,11 @@ def import_(library, wrapper=False):
1815
pytest.importorskip(library)
1916
if wrapper:
2017
if 'jax' in library:
21-
library = 'jax.experimental.array_api'
18+
# JAX v0.4.32 implements the array API directly in jax.numpy
19+
# Older jax versions use jax.experimental.array_api
20+
jax_numpy = import_module("jax.numpy")
21+
if not hasattr(jax_numpy, "__array_api_version__"):
22+
library = 'jax.experimental.array_api'
2223
elif library.startswith('sparse'):
2324
library = 'sparse'
2425
else:

tests/test_array_namespace.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,14 @@ def test_array_namespace(library, api_version, use_compat):
2626

2727
if use_compat is False or use_compat is None and library not in wrapped_libraries:
2828
if library == "jax.numpy" and use_compat is None:
29-
import jax.experimental.array_api
30-
assert namespace == jax.experimental.array_api
29+
import jax.numpy
30+
if hasattr(jax.numpy, "__array_api_version__"):
31+
# JAX v0.4.32 or later uses jax.numpy directly
32+
assert namespace == jax.numpy
33+
else:
34+
# JAX v0.4.31 or earlier uses jax.experimental.array_api
35+
import jax.experimental.array_api
36+
assert namespace == jax.experimental.array_api
3137
else:
3238
assert namespace == xp
3339
else:
@@ -58,8 +64,11 @@ def test_array_namespace(library, api_version, use_compat):
5864
assert 'jax.experimental.array_api' not in sys.modules
5965
namespace = array_api_compat.array_namespace(array, api_version={api_version!r})
6066
61-
import jax.experimental.array_api
62-
assert namespace == jax.experimental.array_api
67+
if hasattr(jax.numpy, '__array_api_version__'):
68+
assert namespace == jax.numpy
69+
else:
70+
import jax.experimental.array_api
71+
assert namespace == jax.experimental.array_api
6372
"""
6473
subprocess.run([sys.executable, "-c", code], check=True)
6574

0 commit comments

Comments
 (0)