Skip to content

Commit 0afd1a7

Browse files
committed
Merge branch 'main' into np-solve-fix
2 parents 730a214 + 29d948b commit 0afd1a7

34 files changed

+536
-1111
lines changed

.github/workflows/ruff.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ jobs:
1616
pip install ruff
1717
# Update output format to enable automatic inline annotations.
1818
- name: Run Ruff
19-
run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview .
19+
run: ruff check --output-format=github .

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
# macOS specific iles
132+
.DS_Store

README.md

+33-8
Original file line numberDiff line numberDiff line change
@@ -71,23 +71,24 @@ namespace, except that functions that are part of the array API are wrapped so
7171
that they have the correct array API behavior. In each case, the array object
7272
used will be the same array object from the wrapped library.
7373

74-
## Difference between `array_api_compat` and `numpy.array_api`
74+
## Difference between `array_api_compat` and `array_api_strict`
7575

76-
`numpy.array_api` is a strict minimal implementation of the Array API (see
76+
`array_api_strict` is a strict minimal implementation of the array API standard, formerly
77+
known as `numpy.array_api` (see
7778
[NEP 47](https://numpy.org/neps/nep-0047-array-api-standard.html)). For
78-
example, `numpy.array_api` does not include any functions that are not part of
79+
example, `array_api_strict` does not include any functions that are not part of
7980
the array API specification, and will explicitly disallow behaviors that are
8081
not required by the spec (e.g., [cross-kind type
8182
promotions](https://data-apis.org/array-api/latest/API_specification/type_promotion.html)).
82-
(`cupy.array_api` is similar to `numpy.array_api`)
83+
(`cupy.array_api` is similar to `array_api_strict`)
8384

8485
`array_api_compat`, on the other hand, is just an extension of the
8586
corresponding array library namespaces with changes needed to be compliant
8687
with the array API. It includes all additional library functions not mentioned
8788
in the spec, and allows any library behaviors not explicitly disallowed by it,
8889
such as cross-kind casting.
8990

90-
In particular, unlike `numpy.array_api`, this package does not use a separate
91+
In particular, unlike `array_api_strict`, this package does not use a separate
9192
`Array` object, but rather just uses the corresponding array library array
9293
objects (`numpy.ndarray`, `cupy.ndarray`, `torch.Tensor`, etc.) directly. This
9394
is because those are the objects that are going to be passed as inputs to
@@ -96,7 +97,7 @@ functions by end users. This does mean that a few behaviors cannot be wrapped
9697
most things.
9798

9899
Array consuming library authors coding against the array API may wish to test
99-
against `numpy.array_api` to ensure they are not using functionality outside
100+
against `array_api_strict` to ensure they are not using functionality outside
100101
of the standard, but prefer this implementation for the default behavior for
101102
end-users.
102103

@@ -125,11 +126,11 @@ part of the specification but which are useful for using the array API:
125126
[`x.device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.device.html)
126127
in the array API specification. Included because `numpy.ndarray` does not
127128
include the `device` attribute and this library does not wrap or extend the
128-
array object. Note that for NumPy, `device(x)` is always `"cpu"`.
129+
array object. Note that for NumPy and dask, `device(x)` is always `"cpu"`.
129130

130131
- `to_device(x, device, /, *, stream=None)`: Equivalent to
131132
[`x.to_device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.to_device.html).
132-
Included because neither NumPy's, CuPy's, nor PyTorch's array objects
133+
Included because neither NumPy's, CuPy's, Dask's, nor PyTorch's array objects
133134
include this method. For NumPy, this function effectively does nothing since
134135
the only supported device is the CPU, but for CuPy, this method supports
135136
CuPy CUDA
@@ -240,6 +241,30 @@ Unlike the other libraries supported here, JAX array API support is contained
240241
entirely in the JAX library. The JAX array API support is tracked at
241242
https://github.com/google/jax/issues/18353.
242243

244+
## Dask
245+
246+
If you're using dask with numpy, many of the same limitations that apply to numpy
247+
will also apply to dask. Besides those differences, other limitations include missing
248+
sort functionality (no `sort` or `argsort`), and limited support for the optional `linalg`
249+
and `fft` extensions.
250+
251+
In particular, the `fft` namespace is not compliant with the array API spec. Any functions
252+
that you find under the `fft` namespace are the original, unwrapped functions under [`dask.array.fft`](https://docs.dask.org/en/latest/array-api.html#fast-fourier-transforms), which may or may not be Array API compliant. Use at your own risk!
253+
254+
For `linalg`, several methods are missing, for example:
255+
- `cross`
256+
- `det`
257+
- `eigh`
258+
- `eigvalsh`
259+
- `matrix_power`
260+
- `pinv`
261+
- `slogdet`
262+
- `matrix_norm`
263+
- `matrix_rank`
264+
Other methods may only be partially implemented or return incorrect results at times.
265+
266+
The minimum supported Dask version is 2023.12.0.
267+
243268
## Vendoring
244269

245270
This library supports vendoring as an installation method. To vendor the

array_api_compat/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
66
https://numpy.org/neps/nep-0047-array-api-standard.html.
77
8-
Unlike numpy.array_api, this is not a strict minimal implementation of the
8+
Unlike array_api_strict, this is not a strict minimal implementation of the
99
Array API, but rather just an extension of the main NumPy namespace with
1010
changes needed to be compliant with the Array API. See
1111
https://numpy.org/doc/stable/reference/array_api.html for a full list of
12-
changes. In particular, unlike numpy.array_api, this package does not use a
12+
changes. In particular, unlike array_api_strict, this package does not use a
1313
separate Array object, but rather just uses numpy.ndarray directly.
1414
15-
Library authors using the Array API may wish to test against numpy.array_api
15+
Library authors using the Array API may wish to test against array_api_strict
1616
to ensure they are not using functionality outside of the standard, but prefer
1717
this implementation for the default when working with NumPy arrays.
1818

array_api_compat/_internal.py

-29
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from functools import wraps
66
from inspect import signature
77

8-
98
def get_xp(xp):
109
"""
1110
Decorator to automatically replace xp with the corresponding array module.
@@ -45,31 +44,3 @@ def wrapped_f(*args, **kwargs):
4544
return wrapped_f
4645

4746
return inner
48-
49-
50-
def _get_all_public_members(module, exclude=None, extend_all=False):
51-
"""Get all public members of a module.
52-
53-
Parameters
54-
----------
55-
module : module
56-
The module to get members from.
57-
exclude : callable, optional
58-
A callable that takes a name and returns True if the name should be
59-
excluded from the list of members.
60-
extend_all : bool, optional
61-
If True, extend the module's __all__ attribute with the members of the
62-
module derived from dir(module). To be used for libraries that do not have a complete __all__ list.
63-
"""
64-
members = getattr(module, "__all__", [])
65-
66-
if members and not extend_all:
67-
return members
68-
69-
if exclude is None:
70-
exclude = lambda name: name.startswith("_") # noqa: E731
71-
72-
members = members + [_ for _ in dir(module) if not exclude(_)]
73-
74-
# remove duplicates
75-
return list(set(members))

array_api_compat/common/__init__.py

+1-27
Original file line numberDiff line numberDiff line change
@@ -1,27 +1 @@
1-
from ._helpers import (
2-
array_namespace,
3-
device,
4-
get_namespace,
5-
is_array_api_obj,
6-
is_cupy_array,
7-
is_dask_array,
8-
is_jax_array,
9-
is_numpy_array,
10-
is_torch_array,
11-
size,
12-
to_device,
13-
)
14-
15-
__all__ = [
16-
"array_namespace",
17-
"device",
18-
"get_namespace",
19-
"is_array_api_obj",
20-
"is_cupy_array",
21-
"is_dask_array",
22-
"is_jax_array",
23-
"is_numpy_array",
24-
"is_torch_array",
25-
"size",
26-
"to_device",
27-
]
1+
from ._helpers import * # noqa: F403

array_api_compat/common/_aliases.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ def zeros_like(
146146

147147
# The functions here return namedtuples (np.unique() returns a normal
148148
# tuple).
149+
150+
# Note that these named tuples aren't actually part of the standard namespace,
151+
# but I don't see any issue with exporting the names here regardless.
149152
class UniqueAllResult(NamedTuple):
150153
values: ndarray
151154
indices: ndarray
@@ -543,5 +546,13 @@ def isdtype(
543546
# This will allow things that aren't required by the spec, like
544547
# isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
545548
# more strict here to match the type annotation? Note that the
546-
# numpy.array_api implementation will be very strict.
549+
# array_api_strict implementation will be very strict.
547550
return dtype == kind
551+
552+
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
553+
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
554+
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
555+
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
556+
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
557+
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
558+
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/common/_helpers.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,16 @@ def _check_device(xp, device):
159159
if device not in ["cpu", None]:
160160
raise ValueError(f"Unsupported device for NumPy: {device!r}")
161161

162-
# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
162+
# Placeholder object to represent the dask device
163+
# when the array backend is not the CPU.
164+
# (since it is not easy to tell which device a dask array is on)
165+
class _dask_device:
166+
def __repr__(self):
167+
return "DASK_DEVICE"
168+
169+
_DASK_DEVICE = _dask_device()
170+
171+
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
163172
# or cupy.ndarray. They are not included in array objects of this library
164173
# because this library just reuses the respective ndarray classes without
165174
# wrapping or subclassing them. These helper functions can be used instead of
@@ -181,7 +190,17 @@ def device(x: Array, /) -> Device:
181190
"""
182191
if is_numpy_array(x):
183192
return "cpu"
184-
if is_jax_array(x):
193+
elif is_dask_array(x):
194+
# Peek at the metadata of the jax array to determine type
195+
try:
196+
import numpy as np
197+
if isinstance(x._meta, np.ndarray):
198+
# Must be on CPU since backed by numpy
199+
return "cpu"
200+
except ImportError:
201+
pass
202+
return _DASK_DEVICE
203+
elif is_jax_array(x):
185204
# JAX has .device() as a method, but it is being deprecated so that it
186205
# can become a property, in accordance with the standard. In order for
187206
# this function to not break when JAX makes the flip, we check for
@@ -288,3 +307,19 @@ def size(x):
288307
if None in x.shape:
289308
return None
290309
return math.prod(x.shape)
310+
311+
__all__ = [
312+
"array_namespace",
313+
"device",
314+
"get_namespace",
315+
"is_array_api_obj",
316+
"is_cupy_array",
317+
"is_dask_array",
318+
"is_jax_array",
319+
"is_numpy_array",
320+
"is_torch_array",
321+
"size",
322+
"to_device",
323+
]
324+
325+
_all_ignore = ['sys', 'math', 'inspect']

array_api_compat/common/_linalg.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from typing import Literal, Optional, Tuple, Union
66
from ._typing import ndarray
77

8+
import math
9+
810
import numpy as np
911
if np.__version__[0] == "2":
1012
from numpy.lib.array_utils import normalize_axis_tuple
1113
else:
1214
from numpy.core.numeric import normalize_axis_tuple
1315

14-
from ._aliases import matrix_transpose, isdtype
16+
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
1517
from .._internal import get_xp
1618

1719
# These are in the main NumPy namespace but not in numpy.linalg
@@ -110,21 +112,22 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]
110112
# on a single dimension.
111113
if axis is None:
112114
# Note: xp.linalg.norm() doesn't handle 0-D arrays
113-
x = x.ravel()
115+
_x = x.ravel()
114116
_axis = 0
115117
elif isinstance(axis, tuple):
116118
# Note: The axis argument supports any number of axes, whereas
117119
# xp.linalg.norm() only supports a single axis for vector norm.
118120
normalized_axis = normalize_axis_tuple(axis, x.ndim)
119121
rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
120122
newshape = axis + rest
121-
x = xp.transpose(x, newshape).reshape(
122-
(xp.prod([x.shape[i] for i in axis], dtype=int), *[x.shape[i] for i in rest]))
123+
_x = xp.transpose(x, newshape).reshape(
124+
(math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest]))
123125
_axis = 0
124126
else:
127+
_x = x
125128
_axis = axis
126129

127-
res = xp.linalg.norm(x, axis=_axis, ord=ord)
130+
res = xp.linalg.norm(_x, axis=_axis, ord=ord)
128131

129132
if keepdims:
130133
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
@@ -150,3 +153,9 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra
150153
elif x.dtype == xp.complex64:
151154
dtype = xp.complex128
152155
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
156+
157+
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
158+
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
159+
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
160+
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
161+
'trace']

array_api_compat/common/_typing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ def __len__(self, /) -> int: ...
2020
SupportsBufferProtocol = Any
2121

2222
Array = Any
23-
Device = Any
23+
Device = Any

0 commit comments

Comments
 (0)