Skip to content

Commit 83e159e

Browse files
bzahdcherianpre-commit-ci[bot]
authored
Allow indexing unindexed dimensions using dask arrays (#5873)
* Attempt to fix indexing for Dask This is a naive attempt to make `isel` work with Dask Known limitation: it triggers the computation. * Works now. * avoid importorskip * More tests and fixes * Raise nicer error when indexing with boolean dask array * Annotate tests * edit query tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixes #4276 Pass 0d dask arrays through for indexing. * Add xfail notes. * backcompat: vendor np.broadcast_shapes * Small improvement * fix: Handle scalars properly. * fix bad test * Check computes with setitem * Better error * Cleanup * Raise nice error with VectorizedIndexer and dask. * Add whats-new --------- Co-authored-by: dcherian <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <[email protected]>
1 parent 5043223 commit 83e159e

10 files changed

+211
-60
lines changed

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ New Features
2525

2626
- Fix :py:meth:`xr.cov` and :py:meth:`xr.corr` now support complex valued arrays (:issue:`7340`, :pull:`7392`).
2727
By `Michael Niklas <https://github.com/headtr1ck>`_.
28+
- Allow indexing along unindexed dimensions with dask arrays
29+
(:issue:`2511`, :issue:`4276`, :issue:`4663`, :pull:`5873`).
30+
By `Abel Aoun <https://github.com/bzah>`_ and `Deepak Cherian <https://github.com/dcherian>`_.
2831
- Support dask arrays in ``first`` and ``last`` reductions.
2932
By `Deepak Cherian <https://github.com/dcherian>`_.
3033

xarray/core/dataset.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
)
7474
from xarray.core.missing import get_clean_interp_index
7575
from xarray.core.options import OPTIONS, _get_keep_attrs
76-
from xarray.core.pycompat import array_type, is_duck_dask_array
76+
from xarray.core.pycompat import array_type, is_duck_array, is_duck_dask_array
7777
from xarray.core.types import QuantileMethods, T_Dataset
7878
from xarray.core.utils import (
7979
Default,
@@ -2292,7 +2292,8 @@ def _validate_indexers(
22922292
elif isinstance(v, Sequence) and len(v) == 0:
22932293
yield k, np.empty((0,), dtype="int64")
22942294
else:
2295-
v = np.asarray(v)
2295+
if not is_duck_array(v):
2296+
v = np.asarray(v)
22962297

22972298
if v.dtype.kind in "US":
22982299
index = self._indexes[k].to_pandas_index()

xarray/core/indexing.py

+22-10
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
from xarray.core import duck_array_ops
1818
from xarray.core.nputils import NumpyVIndexAdapter
1919
from xarray.core.options import OPTIONS
20-
from xarray.core.pycompat import array_type, integer_types, is_duck_dask_array
20+
from xarray.core.pycompat import (
21+
array_type,
22+
integer_types,
23+
is_duck_array,
24+
is_duck_dask_array,
25+
)
2126
from xarray.core.types import T_Xarray
2227
from xarray.core.utils import (
2328
NDArrayMixin,
@@ -368,17 +373,17 @@ def __init__(self, key):
368373
k = int(k)
369374
elif isinstance(k, slice):
370375
k = as_integer_slice(k)
371-
elif isinstance(k, np.ndarray):
376+
elif is_duck_array(k):
372377
if not np.issubdtype(k.dtype, np.integer):
373378
raise TypeError(
374379
f"invalid indexer array, does not have integer dtype: {k!r}"
375380
)
376-
if k.ndim != 1:
381+
if k.ndim > 1:
377382
raise TypeError(
378-
f"invalid indexer array for {type(self).__name__}; must have "
379-
f"exactly 1 dimension: {k!r}"
383+
f"invalid indexer array for {type(self).__name__}; must be scalar "
384+
f"or have 1 dimension: {k!r}"
380385
)
381-
k = np.asarray(k, dtype=np.int64)
386+
k = k.astype(np.int64)
382387
else:
383388
raise TypeError(
384389
f"unexpected indexer type for {type(self).__name__}: {k!r}"
@@ -409,7 +414,13 @@ def __init__(self, key):
409414
for k in key:
410415
if isinstance(k, slice):
411416
k = as_integer_slice(k)
412-
elif isinstance(k, np.ndarray):
417+
elif is_duck_dask_array(k):
418+
raise ValueError(
419+
"Vectorized indexing with Dask arrays is not supported. "
420+
"Please pass a numpy array by calling ``.compute``. "
421+
"See https://github.com/dask/dask/issues/8958."
422+
)
423+
elif is_duck_array(k):
413424
if not np.issubdtype(k.dtype, np.integer):
414425
raise TypeError(
415426
f"invalid indexer array, does not have integer dtype: {k!r}"
@@ -422,7 +433,7 @@ def __init__(self, key):
422433
"invalid indexer key: ndarray arguments "
423434
f"have different numbers of dimensions: {ndims}"
424435
)
425-
k = np.asarray(k, dtype=np.int64)
436+
k = k.astype(np.int64)
426437
else:
427438
raise TypeError(
428439
f"unexpected indexer type for {type(self).__name__}: {k!r}"
@@ -1351,8 +1362,9 @@ def __getitem__(self, key):
13511362
rewritten_indexer = False
13521363
new_indexer = []
13531364
for idim, k in enumerate(key.tuple):
1354-
if isinstance(k, Iterable) and duck_array_ops.array_equiv(
1355-
k, np.arange(self.array.shape[idim])
1365+
if isinstance(k, Iterable) and (
1366+
not is_duck_dask_array(k)
1367+
and duck_array_ops.array_equiv(k, np.arange(self.array.shape[idim]))
13561368
):
13571369
new_indexer.append(slice(None))
13581370
rewritten_indexer = True

xarray/core/nputils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
88

99
from xarray.core.options import OPTIONS
10+
from xarray.core.pycompat import is_duck_array
1011

1112
try:
1213
import bottleneck as bn
@@ -121,7 +122,10 @@ def _advanced_indexer_subspaces(key):
121122
return (), ()
122123

123124
non_slices = [k for k in key if not isinstance(k, slice)]
124-
ndim = len(np.broadcast(*non_slices).shape)
125+
broadcasted_shape = np.broadcast_shapes(
126+
*[item.shape if is_duck_array(item) else (0,) for item in non_slices]
127+
)
128+
ndim = len(broadcasted_shape)
125129
mixed_positions = advanced_index_positions[0] + np.arange(ndim)
126130
vindex_positions = np.arange(ndim)
127131
return mixed_positions, vindex_positions

xarray/core/pycompat.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
from packaging.version import Version
99

10-
from xarray.core.utils import is_duck_array, module_available
10+
from xarray.core.utils import is_duck_array, is_scalar, module_available
1111

1212
integer_types = (int, np.integer)
1313

@@ -79,3 +79,7 @@ def is_dask_collection(x):
7979

8080
def is_duck_dask_array(x):
8181
return is_duck_array(x) and is_dask_collection(x)
82+
83+
84+
def is_0d_dask_array(x):
85+
return is_duck_dask_array(x) and is_scalar(x)

xarray/core/variable.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
as_indexable,
2727
)
2828
from xarray.core.options import OPTIONS, _get_keep_attrs
29-
from xarray.core.pycompat import array_type, integer_types, is_duck_dask_array
29+
from xarray.core.pycompat import (
30+
array_type,
31+
integer_types,
32+
is_0d_dask_array,
33+
is_duck_dask_array,
34+
)
3035
from xarray.core.utils import (
3136
Frozen,
3237
NdimSizeLenMixin,
@@ -687,11 +692,12 @@ def _broadcast_indexes(self, key):
687692
key = self._item_key_to_tuple(key) # key is a tuple
688693
# key is a tuple of full size
689694
key = indexing.expanded_indexer(key, self.ndim)
690-
# Convert a scalar Variable to an integer
695+
# Convert a scalar Variable to a 0d-array
691696
key = tuple(
692-
k.data.item() if isinstance(k, Variable) and k.ndim == 0 else k for k in key
697+
k.data if isinstance(k, Variable) and k.ndim == 0 else k for k in key
693698
)
694-
# Convert a 0d-array to an integer
699+
# Convert a 0d numpy arrays to an integer
700+
# dask 0d arrays are passed through
695701
key = tuple(
696702
k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k for k in key
697703
)
@@ -732,7 +738,8 @@ def _validate_indexers(self, key):
732738
for dim, k in zip(self.dims, key):
733739
if not isinstance(k, BASIC_INDEXING_TYPES):
734740
if not isinstance(k, Variable):
735-
k = np.asarray(k)
741+
if not is_duck_array(k):
742+
k = np.asarray(k)
736743
if k.ndim > 1:
737744
raise IndexError(
738745
"Unlabeled multi-dimensional array cannot be "
@@ -749,6 +756,13 @@ def _validate_indexers(self, key):
749756
"{}-dimensional boolean indexing is "
750757
"not supported. ".format(k.ndim)
751758
)
759+
if is_duck_dask_array(k.data):
760+
raise KeyError(
761+
"Indexing with a boolean dask array is not allowed. "
762+
"This will result in a dask array of unknown shape. "
763+
"Such arrays are unsupported by Xarray."
764+
"Please compute the indexer first using .compute()"
765+
)
752766
if getattr(k, "dims", (dim,)) != (dim,):
753767
raise IndexError(
754768
"Boolean indexer should be unlabeled or on the "
@@ -759,18 +773,20 @@ def _validate_indexers(self, key):
759773
)
760774

761775
def _broadcast_indexes_outer(self, key):
776+
# drop dim if k is integer or if k is a 0d dask array
762777
dims = tuple(
763778
k.dims[0] if isinstance(k, Variable) else dim
764779
for k, dim in zip(key, self.dims)
765-
if not isinstance(k, integer_types)
780+
if (not isinstance(k, integer_types) and not is_0d_dask_array(k))
766781
)
767782

768783
new_key = []
769784
for k in key:
770785
if isinstance(k, Variable):
771786
k = k.data
772787
if not isinstance(k, BASIC_INDEXING_TYPES):
773-
k = np.asarray(k)
788+
if not is_duck_array(k):
789+
k = np.asarray(k)
774790
if k.size == 0:
775791
# Slice by empty list; numpy could not infer the dtype
776792
k = k.astype(int)

xarray/tests/test_dask.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,15 @@ def test_indexing(self):
123123
(da.array([99, 99, 3, 99]), [0, -1, 1]),
124124
(da.array([99, 99, 99, 4]), np.arange(3)),
125125
(da.array([1, 99, 99, 99]), [False, True, True, True]),
126-
(da.array([1, 99, 99, 99]), np.arange(4) > 0),
127-
(da.array([99, 99, 99, 99]), Variable(("x"), da.array([1, 2, 3, 4])) > 0),
126+
(da.array([1, 99, 99, 99]), np.array([False, True, True, True])),
127+
(da.array([99, 99, 99, 99]), Variable(("x"), np.array([True] * 4))),
128128
],
129129
)
130130
def test_setitem_dask_array(self, expected_data, index):
131131
arr = Variable(("x"), da.array([1, 2, 3, 4]))
132132
expected = Variable(("x"), expected_data)
133-
arr[index] = 99
133+
with raise_if_dask_computes():
134+
arr[index] = 99
134135
assert_identical(arr, expected)
135136

136137
def test_squeeze(self):

xarray/tests/test_dataarray.py

+19-15
Original file line numberDiff line numberDiff line change
@@ -4216,45 +4216,49 @@ def test_query(
42164216
d = np.random.choice(["foo", "bar", "baz"], size=30, replace=True).astype(
42174217
object
42184218
)
4219-
if backend == "numpy":
4220-
aa = DataArray(data=a, dims=["x"], name="a")
4221-
bb = DataArray(data=b, dims=["x"], name="b")
4222-
cc = DataArray(data=c, dims=["y"], name="c")
4223-
dd = DataArray(data=d, dims=["z"], name="d")
4219+
aa = DataArray(data=a, dims=["x"], name="a", coords={"a2": ("x", a)})
4220+
bb = DataArray(data=b, dims=["x"], name="b", coords={"b2": ("x", b)})
4221+
cc = DataArray(data=c, dims=["y"], name="c", coords={"c2": ("y", c)})
4222+
dd = DataArray(data=d, dims=["z"], name="d", coords={"d2": ("z", d)})
42244223

4225-
elif backend == "dask":
4224+
if backend == "dask":
42264225
import dask.array as da
42274226

4228-
aa = DataArray(data=da.from_array(a, chunks=3), dims=["x"], name="a")
4229-
bb = DataArray(data=da.from_array(b, chunks=3), dims=["x"], name="b")
4230-
cc = DataArray(data=da.from_array(c, chunks=7), dims=["y"], name="c")
4231-
dd = DataArray(data=da.from_array(d, chunks=12), dims=["z"], name="d")
4227+
aa = aa.copy(data=da.from_array(a, chunks=3))
4228+
bb = bb.copy(data=da.from_array(b, chunks=3))
4229+
cc = cc.copy(data=da.from_array(c, chunks=7))
4230+
dd = dd.copy(data=da.from_array(d, chunks=12))
42324231

42334232
# query single dim, single variable
4234-
actual = aa.query(x="a > 5", engine=engine, parser=parser)
4233+
with raise_if_dask_computes():
4234+
actual = aa.query(x="a2 > 5", engine=engine, parser=parser)
42354235
expect = aa.isel(x=(a > 5))
42364236
assert_identical(expect, actual)
42374237

42384238
# query single dim, single variable, via dict
4239-
actual = aa.query(dict(x="a > 5"), engine=engine, parser=parser)
4239+
with raise_if_dask_computes():
4240+
actual = aa.query(dict(x="a2 > 5"), engine=engine, parser=parser)
42404241
expect = aa.isel(dict(x=(a > 5)))
42414242
assert_identical(expect, actual)
42424243

42434244
# query single dim, single variable
4244-
actual = bb.query(x="b > 50", engine=engine, parser=parser)
4245+
with raise_if_dask_computes():
4246+
actual = bb.query(x="b2 > 50", engine=engine, parser=parser)
42454247
expect = bb.isel(x=(b > 50))
42464248
assert_identical(expect, actual)
42474249

42484250
# query single dim, single variable
4249-
actual = cc.query(y="c < .5", engine=engine, parser=parser)
4251+
with raise_if_dask_computes():
4252+
actual = cc.query(y="c2 < .5", engine=engine, parser=parser)
42504253
expect = cc.isel(y=(c < 0.5))
42514254
assert_identical(expect, actual)
42524255

42534256
# query single dim, single string variable
42544257
if parser == "pandas":
42554258
# N.B., this query currently only works with the pandas parser
42564259
# xref https://github.com/pandas-dev/pandas/issues/40436
4257-
actual = dd.query(z='d == "bar"', engine=engine, parser=parser)
4260+
with raise_if_dask_computes():
4261+
actual = dd.query(z='d2 == "bar"', engine=engine, parser=parser)
42584262
expect = dd.isel(z=(d == "bar"))
42594263
assert_identical(expect, actual)
42604264

0 commit comments

Comments
 (0)