Skip to content

Commit a46f8a8

Browse files
committed
da.asarray should not materialize the graph
1 parent 5ef0e18 commit a46f8a8

File tree

2 files changed

+61
-17
lines changed

2 files changed

+61
-17
lines changed

array_api_compat/dask/array/_aliases.py

+21-17
Original file line numberDiff line numberDiff line change
@@ -129,24 +129,28 @@ def asarray(
129129
See the corresponding documentation in the array library and/or the array API
130130
specification for more details.
131131
"""
132+
if isinstance(obj, da.Array):
133+
if dtype is not None:
134+
# Note: at the moment of writing, dask ignores the copy parameter
135+
# and always behaves with copy=False. We pass the parameter anyway
136+
# for the sake of forward compatibility.
137+
res = obj.astype(dtype, copy=True if copy is True else False)
138+
if copy is False and res is not obj:
139+
raise ValueError("Unable to avoid copy")
140+
else:
141+
res = obj
142+
return obj.copy() if copy else obj
143+
132144
if copy is False:
133-
# copy=False is not yet implemented in dask
134-
raise NotImplementedError("copy=False is not yet implemented")
135-
elif copy is True:
136-
if isinstance(obj, da.Array) and dtype is None:
137-
return obj.copy()
138-
# Go through numpy, since dask copy is no-op by default
139-
obj = np.array(obj, dtype=dtype, copy=True)
140-
return da.array(obj, dtype=dtype)
141-
else:
142-
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
143-
# copy=True to be uniform across dask < 2024.12 and >= 2024.12
144-
# see https://github.com/dask/dask/pull/11524/
145-
obj = np.array(obj, dtype=dtype, copy=True)
146-
return da.from_array(obj)
147-
return obj
148-
149-
return da.asarray(obj, dtype=dtype, **kwargs)
145+
raise NotImplementedError(
146+
"copy=False is not possible when converting a non-dask object to dask"
147+
)
148+
149+
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
150+
# see https://github.com/dask/dask/pull/11524/
151+
obj = np.asarray(obj, dtype=dtype, copy=True)
152+
return da.from_array(obj)
153+
150154

151155
from dask.array import (
152156
# Element wise aliases

tests/test_dask.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import dask
2+
import numpy as np
3+
import pytest
4+
import array_api_compat.dask.array as xp
5+
6+
@pytest.fixture
7+
def no_compute():
8+
"""
9+
Cause the test to raise if at any point anything calls compute() or persist(),
10+
e.g. as it can be triggered implicitly by __bool__, __array__, etc.
11+
"""
12+
def get(dsk, *args, **kwargs):
13+
raise AssertionError("Called compute() or persist()")
14+
15+
with dask.config.set(scheduler=get):
16+
yield
17+
18+
19+
def test_no_compute(no_compute):
20+
"""Test the no_compute_fixture"""
21+
a = xp.asarray(True)
22+
with pytest.raises(AssertionError, match="Called compute"):
23+
bool(a)
24+
25+
26+
def test_asarray_no_compute(no_compute):
27+
a = xp.arange(10)
28+
xp.asarray(a)
29+
xp.asarray(a, dtype=np.int16)
30+
xp.asarray(a, dtype=a.dtype)
31+
xp.asarray(a, copy=True)
32+
xp.asarray(a, copy=True, dtype=np.int16)
33+
xp.asarray(a, copy=True, dtype=a.dtype)
34+
35+
36+
def test_clip_no_compute(no_compute):
37+
a = xp.arange(10)
38+
xp.clip(a)
39+
xp.clip(a, 1)
40+
xp.clip(a, 1, 8)

0 commit comments

Comments
 (0)