-
Notifications
You must be signed in to change notification settings - Fork 34
Add dask to array-api-compat #76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 21 commits
8ee1613
3cc2857
ad6bf56
f178e28
8c9c784
6305d7e
06afd75
6e0ef29
9abe56d
20622ba
a484d5a
0d66160
df69086
1f9799a
424f25d
64d20c4
762a03c
69cc93b
f52b3d5
5edc5ec
565666a
6841758
4be5517
cd381a0
54f4838
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
name: Array API Tests (Dask) | ||
|
||
on: [push, pull_request] | ||
|
||
jobs: | ||
array-api-tests-dask: | ||
uses: ./.github/workflows/array-api-tests.yml | ||
with: | ||
package-name: dask | ||
module-name: dask.array | ||
extra-requires: numpy | ||
pytest-extra-args: --disable-deadline --max-examples=5 |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -77,7 +77,8 @@ def matrix_rank(x: ndarray, | |||||||||
# dimensional arrays. | ||||||||||
if x.ndim < 2: | ||||||||||
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") | ||||||||||
S = xp.linalg.svd(x, compute_uv=False, **kwargs) | ||||||||||
S = xp.linalg.svdvals(x, **kwargs) | ||||||||||
#S = xp.linalg.svd(x, compute_uv=False, **kwargs) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change looks like it is causing the numpy array api tests to fail. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current numpy dev CI failures can be ignored (they should go away once numpy/numpy#25668 is merged), but the other ones are important. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the catch. I think this should be equal to svdvals (if svdvals exists), so I just used a hasattr check. Lemme know if that's too hacky There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should use As for not wrapping svdvals when it is present, we could do that too (I did this for array-api-compat/array_api_compat/numpy/linalg.py Lines 30 to 33 in 916a84b
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, updated. |
||||||||||
if rtol is None: | ||||||||||
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps | ||||||||||
else: | ||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from dask.array import * | ||
|
||
# These imports may overwrite names from the import * above. | ||
from ._aliases import * | ||
|
||
__array_api_version__ = '2022.12' | ||
|
||
__import__(__package__ + '.linalg') |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
from __future__ import annotations | ||
|
||
from ...common import _aliases | ||
from ...common._helpers import _check_device | ||
|
||
from ..._internal import get_xp | ||
|
||
import numpy as np | ||
from numpy import ( | ||
# Constants | ||
e, | ||
inf, | ||
nan, | ||
pi, | ||
newaxis, | ||
# Dtypes | ||
bool_ as bool, | ||
float32, | ||
float64, | ||
int8, | ||
int16, | ||
int32, | ||
int64, | ||
uint8, | ||
uint16, | ||
uint32, | ||
uint64, | ||
complex64, | ||
complex128, | ||
iinfo, | ||
finfo, | ||
can_cast, | ||
result_type, | ||
) | ||
|
||
from typing import TYPE_CHECKING | ||
if TYPE_CHECKING: | ||
from typing import Optional, Union | ||
from ...common._typing import ndarray, Device, Dtype | ||
|
||
import dask.array as da | ||
|
||
isdtype = get_xp(np)(_aliases.isdtype) | ||
astype = _aliases.astype | ||
|
||
# Common aliases | ||
|
||
# This arange func is modified from the common one to | ||
# not pass stop/step as keyword arguments, which will cause | ||
# an error with dask | ||
|
||
# TODO: delete the xp stuff, it shouldn't be necessary | ||
def dask_arange( | ||
start: Union[int, float], | ||
/, | ||
stop: Optional[Union[int, float]] = None, | ||
step: Union[int, float] = 1, | ||
*, | ||
xp, | ||
dtype: Optional[Dtype] = None, | ||
device: Optional[Device] = None, | ||
**kwargs | ||
) -> ndarray: | ||
_check_device(xp, device) | ||
args = [start] | ||
if stop is not None: | ||
args.append(stop) | ||
else: | ||
# stop is None, so start is actually stop | ||
# prepend the default value for start which is 0 | ||
args.insert(0, 0) | ||
args.append(step) | ||
return xp.arange(*args, dtype=dtype, **kwargs) | ||
|
||
arange = get_xp(da)(dask_arange) | ||
eye = get_xp(da)(_aliases.eye) | ||
|
||
from functools import partial | ||
asarray = partial(_aliases._asarray, namespace='dask.array') | ||
asarray.__doc__ = _aliases._asarray.__doc__ | ||
|
||
linspace = get_xp(da)(_aliases.linspace) | ||
eye = get_xp(da)(_aliases.eye) | ||
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult) | ||
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult) | ||
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult) | ||
unique_all = get_xp(da)(_aliases.unique_all) | ||
unique_counts = get_xp(da)(_aliases.unique_counts) | ||
unique_inverse = get_xp(da)(_aliases.unique_inverse) | ||
unique_values = get_xp(da)(_aliases.unique_values) | ||
permute_dims = get_xp(da)(_aliases.permute_dims) | ||
std = get_xp(da)(_aliases.std) | ||
var = get_xp(da)(_aliases.var) | ||
empty = get_xp(da)(_aliases.empty) | ||
empty_like = get_xp(da)(_aliases.empty_like) | ||
full = get_xp(da)(_aliases.full) | ||
full_like = get_xp(da)(_aliases.full_like) | ||
ones = get_xp(da)(_aliases.ones) | ||
ones_like = get_xp(da)(_aliases.ones_like) | ||
zeros = get_xp(da)(_aliases.zeros) | ||
zeros_like = get_xp(da)(_aliases.zeros_like) | ||
reshape = get_xp(da)(_aliases.reshape) | ||
matrix_transpose = get_xp(da)(_aliases.matrix_transpose) | ||
vecdot = get_xp(da)(_aliases.vecdot) | ||
|
||
from dask.array import ( | ||
# Element wise aliases | ||
arccos as acos, | ||
arccosh as acosh, | ||
arcsin as asin, | ||
arcsinh as asinh, | ||
arctan as atan, | ||
arctan2 as atan2, | ||
arctanh as atanh, | ||
left_shift as bitwise_left_shift, | ||
right_shift as bitwise_right_shift, | ||
invert as bitwise_invert, | ||
power as pow, | ||
# Other | ||
concatenate as concat, | ||
) | ||
|
||
del da | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's also manage There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated for everything except linalg. I'm not sure if we can do linalg since there's no all for dask.array.linalg. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See how it's done for cupy.linalg, which also doesn't define There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be updated as well. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from __future__ import annotations | ||
|
||
from dask.array.linalg import * | ||
from ...common import _linalg | ||
from ..._internal import get_xp | ||
from dask.array import matmul, tensordot, trace, outer | ||
from ._aliases import matrix_transpose, vecdot | ||
|
||
import dask.array as da | ||
|
||
from typing import TYPE_CHECKING | ||
if TYPE_CHECKING: | ||
from typing import Optional, Union, Tuple | ||
from ...common._typing import ndarray, Device, Dtype | ||
|
||
#cross = get_xp(da)(_linalg.cross) | ||
#outer = get_xp(da)(_linalg.outer) | ||
EighResult = _linalg.EighResult | ||
QRResult = _linalg.QRResult | ||
SlogdetResult = _linalg.SlogdetResult | ||
SVDResult = _linalg.SVDResult | ||
qr = get_xp(da)(_linalg.qr) | ||
#svd = get_xp(da)(_linalg.svd) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you get rid of all the little commented out bits of code like this. Or if one of them actually is useful to keep, add an explanatory comment above it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, had some code left over from debugging. Everything should be cleaned up now. 🤞 |
||
cholesky = get_xp(da)(_linalg.cholesky) | ||
matrix_rank = get_xp(da)(_linalg.matrix_rank) | ||
#pinv = get_xp(da)(_linalg.pinv) | ||
matrix_norm = get_xp(da)(_linalg.matrix_norm) | ||
|
||
def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: | ||
# TODO: can't avoid computing U or V for dask | ||
_, s, _ = svd(x) | ||
return s | ||
|
||
vector_norm = get_xp(da)(_linalg.vector_norm) | ||
diagonal = get_xp(da)(_linalg.diagonal) | ||
|
||
#__all__ = linalg_all + _linalg.__all__ | ||
|
||
del get_xp | ||
del da | ||
#del linalg_all | ||
del _linalg |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# FFT isn't conformant | ||
array_api_tests/test_fft.py | ||
|
||
# slow and not implemented in dask | ||
array_api_tests/test_linalg.py::test_matrix_power |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think dask supports cupy on the GPU.
Is this something we also need to take into consideration?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes indeed, it does support CuPy. I"m not sure that a
stream
input can be supported though. If a Dask array spans multiple machines, I think there'd be multiple streams with numbers that are unrelated to each other. Which therefore can't be supported at the Dask level in this API.That's probably fine - you'd only move arrays in a single process to another device like this I think, so maybe the whole
to_device
method doesn't quite work for Dask? @jakirkham any thoughts on this?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even if dask can't really support device transfers or anything I would not call the default device "cpu" as that's misleading. We could just create a proxy
DaskDevice
object that serves as the device for dask arrays.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lithomas1 what's the status of this comment? Does the
device()
helper above need to be updated for dask?