diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index cd127f44..1d31bc7a 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.6' +__version__ = '1.7' from .common import * # noqa: F401, F403 diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 79354487..32fb0e70 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -50,7 +50,7 @@ def is_numpy_array(x): is_torch_array is_dask_array is_jax_array - is_pydata_sparse + is_pydata_sparse_array """ # Avoid importing NumPy if it isn't already if 'numpy' not in sys.modules: @@ -80,7 +80,7 @@ def is_cupy_array(x): is_torch_array is_dask_array is_jax_array - is_pydata_sparse + is_pydata_sparse_array """ # Avoid importing NumPy if it isn't already if 'cupy' not in sys.modules: @@ -107,7 +107,7 @@ def is_torch_array(x): is_cupy_array is_dask_array is_jax_array - is_pydata_sparse + is_pydata_sparse_array """ # Avoid importing torch if it isn't already if 'torch' not in sys.modules: @@ -134,7 +134,7 @@ def is_dask_array(x): is_cupy_array is_torch_array is_jax_array - is_pydata_sparse + is_pydata_sparse_array """ # Avoid importing dask if it isn't already if 'dask.array' not in sys.modules: @@ -161,7 +161,7 @@ def is_jax_array(x): is_cupy_array is_torch_array is_dask_array - is_pydata_sparse + is_pydata_sparse_array """ # Avoid importing jax if it isn't already if 'jax' not in sys.modules: @@ -172,7 +172,7 @@ def is_jax_array(x): return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) -def is_pydata_sparse(x) -> bool: +def is_pydata_sparse_array(x) -> bool: """ Return True if `x` is an array from the `sparse` package. @@ -219,7 +219,7 @@ def is_array_api_obj(x): or is_torch_array(x) \ or is_dask_array(x) \ or is_jax_array(x) \ - or is_pydata_sparse(x) \ + or is_pydata_sparse_array(x) \ or hasattr(x, '__array_namespace__') def _check_api_version(api_version): @@ -288,7 +288,7 @@ def your_function(x, y): is_torch_array is_dask_array is_jax_array - is_pydata_sparse + is_pydata_sparse_array """ if use_compat not in [None, True, False]: @@ -348,7 +348,7 @@ def your_function(x, y): # not have a wrapper submodule for it. import jax.experimental.array_api as jnp namespaces.add(jnp) - elif is_pydata_sparse(x): + elif is_pydata_sparse_array(x): if use_compat is True: _check_api_version(api_version) raise ValueError("`sparse` does not have an array-api-compat wrapper") @@ -451,7 +451,7 @@ def device(x: Array, /) -> Device: return x.device() else: return x.device - elif is_pydata_sparse(x): + elif is_pydata_sparse_array(x): # `sparse` will gain `.device`, so check for this first. x_device = getattr(x, 'device', None) if x_device is not None: @@ -583,7 +583,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] # This import adds to_device to x import jax.experimental.array_api # noqa: F401 return x.to_device(device, stream=stream) - elif is_pydata_sparse(x) and device == _device(x): + elif is_pydata_sparse_array(x) and device == _device(x): # Perform trivial check to return the same array if # device is same instead of err-ing. return x @@ -613,7 +613,7 @@ def size(x): "is_jax_array", "is_numpy_array", "is_torch_array", - "is_pydata_sparse", + "is_pydata_sparse_array", "size", "to_device", ] diff --git a/docs/changelog.md b/docs/changelog.md index 545a9aa8..f48b713d 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,6 +1,6 @@ # Changelog -## 1.6 (2024-03-29) +## 1.7 (2024-05-24) ## Major Changes @@ -10,7 +10,23 @@ `array_api_compat.sparse` submodule, and `array_namespace()` returns the `sparse` module. -- Added the function `is_pydata_sparse(x)`. +- Added the function `is_pydata_sparse_array(x)`. + +## Minor Changes + +- Fix JAX `float0` arrays. See https://github.com/google/jax/issues/20620. + ([@NeilGirdhar](https://github.com/NeilGirdhar)) + +- Fix `torch.linalg.vector_norm()` when `axis=()`. + +- Fix `torch.linalg.solve()` to apply the array API standard rules for when + `x2` should be treated as a vector vs. a matrix. + +- Fix PyTorch test failures on CI by skipping uint16, uint32, uint64 tests. + +## 1.6 (2024-03-29) + +## Major Changes - Drop support for Python 3.8. diff --git a/docs/dev/tests.md b/docs/dev/tests.md index 3ae7b7a0..6d9d1d7b 100644 --- a/docs/dev/tests.md +++ b/docs/dev/tests.md @@ -11,7 +11,8 @@ dependencies from `requirements-dev.txt` (array-api-compat has [no hard runtime dependencies](no-dependencies)). array-api-tests is run against all supported libraries are tested on CI -([except for JAX](jax-support)). This is achieved by a [reusable GitHub Actions +([except for JAX](jax-support) and [Sparse](sparse-support)). This is achieved +by a [reusable GitHub Actions Workflow](https://github.com/data-apis/array-api-compat/blob/main/.github/workflows/array-api-tests.yml). Most libraries have tests that must be xfailed or skipped for various reasons. These are defined in specific `-xfails.txt` files and are diff --git a/docs/helper-functions.rst b/docs/helper-functions.rst index dcaa2e44..5516bf60 100644 --- a/docs/helper-functions.rst +++ b/docs/helper-functions.rst @@ -49,3 +49,4 @@ yet. .. autofunction:: is_torch_array .. autofunction:: is_dask_array .. autofunction:: is_jax_array +.. autofunction:: is_pydata_sparse_array diff --git a/docs/index.md b/docs/index.md index e4e59d5b..02320330 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,8 +2,8 @@ This is a small wrapper around common array libraries that is compatible with the [Array API standard](https://data-apis.org/array-api/latest/). Currently, -NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array -libraries, or if you encounter any issues, please [open an +NumPy, CuPy, PyTorch, Dask, JAX, and Sparse are supported. If you want support +for other array libraries, or if you encounter any issues, please [open an issue](https://github.com/data-apis/array-api-compat/issues). Note that some of the functionality in this library is backwards incompatible diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md index 88b9edce..a016a636 100644 --- a/docs/supported-array-libraries.md +++ b/docs/supported-array-libraries.md @@ -133,5 +133,7 @@ Other methods may only be partially implemented or return incorrect results at t The minimum supported Dask version is 2023.12.0. -## [`sparse`](https://sparse.pydata.org/en/stable/) +(sparse-support)= +## [Sparse](https://sparse.pydata.org/en/stable/) + Similar to JAX, `sparse` Array API support is contained directly in `sparse`. diff --git a/setup.py b/setup.py index 506e3005..1b260081 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ "jax": "jax", "pytorch": "pytorch", "dask": "dask", + "sprase": "sparse >=0.15.1", }, classifiers=[ "Programming Language :: Python :: 3", diff --git a/tests/test_common.py b/tests/test_common.py index 798dc114..294a112a 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,5 +1,5 @@ from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401 - is_dask_array, is_jax_array, is_pydata_sparse) + is_dask_array, is_jax_array, is_pydata_sparse_array) from array_api_compat import is_array_api_obj, device, to_device @@ -16,7 +16,7 @@ 'torch': 'is_torch_array', 'dask.array': 'is_dask_array', 'jax.numpy': 'is_jax_array', - 'sparse': 'is_pydata_sparse', + 'sparse': 'is_pydata_sparse_array', } @pytest.mark.parametrize('library', is_functions.keys())