Add dask to array-api-compat #76


Merged: 25 commits, Feb 6, 2024
12 changes: 12 additions & 0 deletions .github/workflows/array-api-tests-dask.yml
@@ -0,0 +1,12 @@
name: Array API Tests (Dask)

on: [push, pull_request]

jobs:
array-api-tests-dask:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: dask
module-name: dask.array
extra-requires: numpy
pytest-extra-args: --disable-deadline --max-examples=5
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -15,7 +15,7 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest numpy torch
python -m pip install pytest numpy torch dask[array]

- name: Run Tests
run: |
14 changes: 12 additions & 2 deletions array_api_compat/common/_aliases.py
@@ -303,6 +303,8 @@ def _asarray(
import numpy as xp
elif namespace == 'cupy':
import cupy as xp
elif namespace == 'dask.array':
import dask.array as xp
else:
raise ValueError("Unrecognized namespace argument to asarray()")

@@ -322,11 +322,19 @@
if copy in COPY_FALSE:
# copy=False is not yet implemented in xp.asarray
raise NotImplementedError("copy=False is not yet implemented")
if isinstance(obj, xp.ndarray):
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
if dtype is not None and obj.dtype != dtype:
copy = True
if copy in COPY_TRUE:
return xp.array(obj, copy=True, dtype=dtype)
copy_kwargs = {}
if namespace != "dask.array":
copy_kwargs["copy"] = True
else:
# No copy kwarg in dask.asarray, so we go through np.asarray first
# (as dask itself does) and copy afterwards
import numpy as np
obj = np.asarray(obj).copy()
return xp.array(obj, dtype=dtype, **copy_kwargs)
return obj

return xp.asarray(obj, dtype=dtype, **kwargs)
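For the dask path, copy=True therefore round-trips through NumPy before rewrapping the result as a dask array. A small usage sketch of that branch (illustrative only, assuming dask[array] is installed; not code from this PR):

# Illustrative sketch of the copy=True branch above for the dask namespace.
import numpy as np
import array_api_compat.dask.array as xp

a = np.arange(4)
b = xp.asarray(a, copy=True)   # np.asarray(a).copy() first, then wrapped by dask
a[0] = 99
print(int(b[0].compute()))     # 0 -- the copy does not share memory with `a`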
24 changes: 24 additions & 0 deletions array_api_compat/common/_helpers.py
@@ -40,13 +40,23 @@ def _is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)

def _is_dask_array(x):
# Avoid importing dask if it isn't already
if 'dask.array' not in sys.modules:
return False

import dask.array

return isinstance(x, dask.array.Array)

def is_array_api_obj(x):
"""
Check if x is an array API compatible array object.
"""
return _is_numpy_array(x) \
or _is_cupy_array(x) \
or _is_torch_array(x) \
or _is_dask_array(x) \
or hasattr(x, '__array_namespace__')

def _check_api_version(api_version):
@@ -95,6 +105,13 @@ def your_function(x, y):
else:
import torch
namespaces.add(torch)
elif _is_dask_array(x):
_check_api_version(api_version)
if _use_compat:
from ..dask import array as dask_namespace
namespaces.add(dask_namespace)
else:
raise TypeError("_use_compat cannot be False if input array is a dask array!")
elif hasattr(x, '__array_namespace__'):
namespaces.add(x.__array_namespace__(api_version=api_version))
else:
@@ -219,6 +236,13 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
return _cupy_to_device(x, device, stream=stream)
elif _is_torch_array(x):
return _torch_to_device(x, device, stream=stream)
elif _is_dask_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
# TODO: What if our array is on the GPU already?
Contributor Author:
I think dask supports cupy on the GPU.

Is this something we also need to take into consideration?

Member:
Yes indeed, it does support CuPy. I'm not sure that a stream input can be supported, though. If a Dask array spans multiple machines, I think there'd be multiple streams with numbers that are unrelated to each other, which therefore can't be supported at the Dask level in this API.

That's probably fine - you'd only move arrays in a single process to another device like this I think, so maybe the whole to_device method doesn't quite work for Dask? @jakirkham any thoughts on this?

Member:
Even if dask can't really support device transfers or anything I would not call the default device "cpu" as that's misleading. We could just create a proxy DaskDevice object that serves as the device for dask arrays.
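A minimal sketch of what such a proxy device could look like (the DASK_DEVICE name and singleton approach here are assumptions for illustration; the PR as written still treats 'cpu' as the only accepted device):

# Hypothetical sketch of a proxy device object for dask arrays.
class _DaskDevice:
    # A dask graph may span several workers and devices, so no single concrete
    # device string is accurate; this object stands in for "wherever the chunks
    # live" and could be returned by device() and accepted by to_device() as a
    # no-op target.
    def __repr__(self):
        return "DASK_DEVICE"

DASK_DEVICE = _DaskDevice()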

Member:
@lithomas1 what's the status of this comment? Does the device() helper above need to be updated for dask?

if device == 'cpu':
return x
raise ValueError(f"Unsupported device {device!r}")
return x.to_device(device, stream=stream)

def size(x):
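Taken together, these hunks route dask arrays through the helpers like the other supported libraries. A hedged usage sketch (assuming dask[array] and this compat layer are installed, and that array_namespace and to_device are the public helper names exported by the package):

# Usage sketch for the new dask branches in array_namespace() and to_device().
import dask.array as da
from array_api_compat import array_namespace, to_device

x = da.ones((4, 4), chunks=2)
xp = array_namespace(x)        # resolves to array_api_compat.dask.array
y = xp.zeros_like(x)

y = to_device(y, 'cpu')        # accepted and returned as-is for dask
# to_device(y, 'cuda')         # any other device raises ValueError in this PR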
3 changes: 2 additions & 1 deletion array_api_compat/common/_linalg.py
@@ -77,7 +77,8 @@ def matrix_rank(x: ndarray,
# dimensional arrays.
if x.ndim < 2:
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
S = xp.linalg.svd(x, compute_uv=False, **kwargs)
S = xp.linalg.svdvals(x, **kwargs)
#S = xp.linalg.svd(x, compute_uv=False, **kwargs)
Member:
This change looks like it is causing the numpy array api tests to fail.

Member:
The current numpy dev CI failures can be ignored (they should go away once numpy/numpy#25668 is merged), but the other ones are important.

Contributor Author:
Thanks for the catch. I think this should be equal to svdvals (if svdvals exists), so I just used a hasattr check. Lemme know if that's too hacky

Member:
You should use get_xp(xp).linalg.svdvals so that it uses the svdvals alias. This pattern is already used in a few places in this file.

As for not wrapping svdvals when it is present, we could do that too (I did this for vector_norm

if hasattr(np.linalg, 'vector_norm'):
vector_norm = np.linalg.vector_norm
else:
vector_norm = get_xp(np)(_linalg.vector_norm)
but I forgot about svdvals).
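The analogous handling for svdvals would look roughly like this (a sketch of the suggested pattern, reusing the np, get_xp, and _linalg names from the quoted snippet; not the code that ended up in the PR):

# Sketch: prefer the native svdvals when numpy provides one, otherwise fall
# back to the wrapped compat implementation, mirroring vector_norm above.
if hasattr(np.linalg, 'svdvals'):
    svdvals = np.linalg.svdvals
else:
    svdvals = get_xp(np)(_linalg.svdvals)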

Contributor Author:
Thanks, updated.

if rtol is None:
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
else:
8 changes: 8 additions & 0 deletions array_api_compat/dask/array/__init__.py
@@ -0,0 +1,8 @@
from dask.array import *

# These imports may overwrite names from the import * above.
from ._aliases import *

__array_api_version__ = '2022.12'

__import__(__package__ + '.linalg')
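The trailing __import__ loads the linalg submodule eagerly, so it is reachable as an attribute of the package without a separate import. A small usage sketch (assuming dask[array] is installed):

# Usage sketch: linalg needs no separate import because __init__.py above
# imports it eagerly via __import__.
import array_api_compat.dask.array as xp

x = xp.asarray([[4.0, 0.0], [0.0, 9.0]])
L = xp.linalg.cholesky(x)      # lazy dask computation
print(L.compute())             # prints the 2x2 Cholesky factor diag(2, 3)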
123 changes: 123 additions & 0 deletions array_api_compat/dask/array/_aliases.py
@@ -0,0 +1,123 @@
from __future__ import annotations

from ...common import _aliases
from ...common._helpers import _check_device

from ..._internal import get_xp

import numpy as np
from numpy import (
# Constants
e,
inf,
nan,
pi,
newaxis,
# Dtypes
bool_ as bool,
float32,
float64,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
complex64,
complex128,
iinfo,
finfo,
can_cast,
result_type,
)

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
from ...common._typing import ndarray, Device, Dtype

import dask.array as da

isdtype = get_xp(np)(_aliases.isdtype)
astype = _aliases.astype

# Common aliases

# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask

# TODO: delete the xp stuff, it shouldn't be necessary
def dask_arange(
start: Union[int, float],
/,
stop: Optional[Union[int, float]] = None,
step: Union[int, float] = 1,
*,
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs
) -> ndarray:
_check_device(xp, device)
args = [start]
if stop is not None:
args.append(stop)
else:
# stop is None, so start is actually stop
# prepend the default value for start which is 0
args.insert(0, 0)
args.append(step)
return xp.arange(*args, dtype=dtype, **kwargs)

arange = get_xp(da)(dask_arange)
eye = get_xp(da)(_aliases.eye)

from functools import partial
asarray = partial(_aliases._asarray, namespace='dask.array')
asarray.__doc__ = _aliases._asarray.__doc__

linspace = get_xp(da)(_aliases.linspace)
eye = get_xp(da)(_aliases.eye)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
unique_all = get_xp(da)(_aliases.unique_all)
unique_counts = get_xp(da)(_aliases.unique_counts)
unique_inverse = get_xp(da)(_aliases.unique_inverse)
unique_values = get_xp(da)(_aliases.unique_values)
permute_dims = get_xp(da)(_aliases.permute_dims)
std = get_xp(da)(_aliases.std)
var = get_xp(da)(_aliases.var)
empty = get_xp(da)(_aliases.empty)
empty_like = get_xp(da)(_aliases.empty_like)
full = get_xp(da)(_aliases.full)
full_like = get_xp(da)(_aliases.full_like)
ones = get_xp(da)(_aliases.ones)
ones_like = get_xp(da)(_aliases.ones_like)
zeros = get_xp(da)(_aliases.zeros)
zeros_like = get_xp(da)(_aliases.zeros_like)
reshape = get_xp(da)(_aliases.reshape)
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)

from dask.array import (
# Element wise aliases
arccos as acos,
arccosh as acosh,
arcsin as asin,
arcsinh as asinh,
arctan as atan,
arctan2 as atan2,
arctanh as atanh,
left_shift as bitwise_left_shift,
right_shift as bitwise_right_shift,
invert as bitwise_invert,
power as pow,
# Other
concatenate as concat,
)

del da
Member:
Let's also manage __all__ in these files the same as in the other submodules.

Contributor Author:
Updated for everything except linalg.

I'm not sure if we can do linalg, since there's no __all__ for dask.array.linalg.

Member:
See how it's done for cupy.linalg, which also doesn't define __all__
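A sketch of that approach as module-level code inside array_api_compat/dask/array/linalg.py (illustrative; the variable names follow the linalg_all / _linalg.__all__ convention visible in the commented-out line in the linalg diff below):

# Sketch: collect the public names of a module that does not define __all__,
# then extend with the compat-layer names defined in this file.
import dask.array.linalg as _da_linalg
from ...common import _linalg

linalg_all = [name for name in dir(_da_linalg) if not name.startswith('_')]
__all__ = linalg_all + _linalg.__all__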

Contributor Author:
This should be updated as well.

42 changes: 42 additions & 0 deletions array_api_compat/dask/array/linalg.py
@@ -0,0 +1,42 @@
from __future__ import annotations

from dask.array.linalg import *
from ...common import _linalg
from ..._internal import get_xp
from dask.array import matmul, tensordot, trace, outer
from ._aliases import matrix_transpose, vecdot

import dask.array as da

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union, Tuple
from ...common._typing import ndarray, Device, Dtype

#cross = get_xp(da)(_linalg.cross)
#outer = get_xp(da)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
qr = get_xp(da)(_linalg.qr)
#svd = get_xp(da)(_linalg.svd)
Member (@asmeurer, Jan 24, 2024):
Can you get rid of all the little commented out bits of code like this. Or if one of them actually is useful to keep, add an explanatory comment above it.

Contributor Author:
Sorry, had some code left over from debugging.

Everything should be cleaned up now. 🤞

cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
#pinv = get_xp(da)(_linalg.pinv)
matrix_norm = get_xp(da)(_linalg.matrix_norm)

def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
# TODO: can't avoid computing U or V for dask
_, s, _ = svd(x)
return s

vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)

#__all__ = linalg_all + _linalg.__all__

del get_xp
del da
#del linalg_all
del _linalg
5 changes: 5 additions & 0 deletions dask-skips.txt
@@ -0,0 +1,5 @@
# FFT isn't conformant
array_api_tests/test_fft.py

# slow and not implemented in dask
array_api_tests/test_linalg.py::test_matrix_power