Skip to content

Commit 4863400

Browse files
committed
bool mask support
1 parent 7218499 commit 4863400

File tree

5 files changed

+154
-25
lines changed

5 files changed

+154
-25
lines changed

src/array_api_extra/_lib/_at.py

+80-12
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
from types import ModuleType
1010
from typing import ClassVar, cast
1111

12-
from ._utils._compat import array_namespace, is_jax_array, is_writeable_array
12+
from ._utils._compat import (
13+
array_namespace,
14+
is_dask_array,
15+
is_jax_array,
16+
is_writeable_array,
17+
)
1318
from ._utils._typing import Array, Index
1419

1520

@@ -141,6 +146,25 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
141146
not explicitly covered by ``array-api-compat``, are not supported by update
142147
methods.
143148
149+
Boolean masks are supported on Dask and jitted JAX arrays exclusively
150+
when `idx` has the same shape as `x` and `y` is 0-dimensional.
151+
Note that this is support is not available in JAX's native
152+
``x.at[mask].set(y)``.
153+
154+
This pattern::
155+
156+
>>> mask = m(x)
157+
>>> x[mask] = f(x[mask])
158+
159+
Can't be replaced by `at`, as it won't work on Dask and JAX inside jax.jit::
160+
161+
>>> mask = m(x)
162+
>>> x = xpx.at(x, mask).set(f(x[mask]) # Crash on Dask and jax.jit
163+
164+
You should instead use::
165+
166+
>>> x = xp.where(m(x), f(x), x)
167+
144168
Examples
145169
--------
146170
Given either of these equivalent expressions::
@@ -189,6 +213,7 @@ def _op(
189213
self,
190214
at_op: _AtOp,
191215
in_place_op: Callable[[Array, Array | object], Array] | None,
216+
out_of_place_op: Callable[[Array, Array], Array] | None,
192217
y: Array | object,
193218
/,
194219
copy: bool | None,
@@ -210,6 +235,16 @@ def _op(
210235
211236
x[idx] = y
212237
238+
out_of_place_op : Callable[[Array, Array], Array] | None
239+
Out-of-place operation to apply when idx is a boolean mask and the backend
240+
doesn't support in-place updates::
241+
242+
x = xp.where(idx, out_of_place_op(x, y), x)
243+
244+
If None::
245+
246+
x = xp.where(idx, y, x)
247+
213248
y : array or object
214249
Right-hand side of the operation.
215250
copy : bool or None
@@ -223,6 +258,7 @@ def _op(
223258
Updated `x`.
224259
"""
225260
x, idx = self._x, self._idx
261+
xp = array_namespace(x, y) if xp is None else xp
226262

227263
if idx is _undef:
228264
msg = (
@@ -247,15 +283,41 @@ def _op(
247283
else:
248284
writeable = is_writeable_array(x)
249285

286+
# JAX inside jax.jit and Dask don't support in-place updates with boolean
287+
# mask. However we can handle the common special case of 0-dimensional y
288+
# with where(idx, y, x) instead.
289+
if (
290+
(is_dask_array(idx) or is_jax_array(idx))
291+
and idx.dtype == xp.bool
292+
and idx.shape == x.shape
293+
):
294+
y_xp = xp.asarray(y, dtype=x.dtype)
295+
if y_xp.ndim == 0:
296+
if out_of_place_op:
297+
# FIXME: suppress inf warnings on dask with lazywhere
298+
out = xp.where(idx, out_of_place_op(x, y_xp), x)
299+
# Undo int->float promotion on JAX after _AtOp.DIVIDE
300+
out = xp.astype(out, x.dtype, copy=False)
301+
else:
302+
out = xp.where(idx, y_xp, x)
303+
304+
if copy:
305+
return out
306+
x[()] = out
307+
return x
308+
# else: this will work on eager JAX and crash on jax.jit and Dask
309+
250310
if copy:
251311
if is_jax_array(x):
252312
# Use JAX's at[]
253313
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
254-
return func(y)
314+
out = func(y)
315+
# Undo int->float promotion on JAX after _AtOp.DIVIDE
316+
return xp.astype(out, x.dtype, copy=False)
317+
255318
# Emulate at[] behaviour for non-JAX arrays
256319
# with a copy followed by an update
257-
if xp is None:
258-
xp = array_namespace(x)
320+
259321
x = xp.asarray(x, copy=True)
260322
if writeable is False:
261323
# A copy of a read-only numpy array is writeable
@@ -283,7 +345,7 @@ def set(
283345
xp: ModuleType | None = None,
284346
) -> Array: # numpydoc ignore=PR01,RT01
285347
"""Apply ``x[idx] = y`` and return the update array."""
286-
return self._op(_AtOp.SET, None, y, copy=copy, xp=xp)
348+
return self._op(_AtOp.SET, None, None, y, copy=copy, xp=xp)
287349

288350
def add(
289351
self,
@@ -297,7 +359,7 @@ def add(
297359
# Note for this and all other methods based on _iop:
298360
# operator.iadd and operator.add subtly differ in behaviour, as
299361
# only iadd will trigger exceptions when y has an incompatible dtype.
300-
return self._op(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp)
362+
return self._op(_AtOp.ADD, operator.iadd, operator.add, y, copy=copy, xp=xp)
301363

302364
def subtract(
303365
self,
@@ -307,7 +369,9 @@ def subtract(
307369
xp: ModuleType | None = None,
308370
) -> Array: # numpydoc ignore=PR01,RT01
309371
"""Apply ``x[idx] -= y`` and return the updated array."""
310-
return self._op(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp)
372+
return self._op(
373+
_AtOp.SUBTRACT, operator.isub, operator.sub, y, copy=copy, xp=xp
374+
)
311375

312376
def multiply(
313377
self,
@@ -317,7 +381,9 @@ def multiply(
317381
xp: ModuleType | None = None,
318382
) -> Array: # numpydoc ignore=PR01,RT01
319383
"""Apply ``x[idx] *= y`` and return the updated array."""
320-
return self._op(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp)
384+
return self._op(
385+
_AtOp.MULTIPLY, operator.imul, operator.mul, y, copy=copy, xp=xp
386+
)
321387

322388
def divide(
323389
self,
@@ -327,7 +393,9 @@ def divide(
327393
xp: ModuleType | None = None,
328394
) -> Array: # numpydoc ignore=PR01,RT01
329395
"""Apply ``x[idx] /= y`` and return the updated array."""
330-
return self._op(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp)
396+
return self._op(
397+
_AtOp.DIVIDE, operator.itruediv, operator.truediv, y, copy=copy, xp=xp
398+
)
331399

332400
def power(
333401
self,
@@ -337,7 +405,7 @@ def power(
337405
xp: ModuleType | None = None,
338406
) -> Array: # numpydoc ignore=PR01,RT01
339407
"""Apply ``x[idx] **= y`` and return the updated array."""
340-
return self._op(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp)
408+
return self._op(_AtOp.POWER, operator.ipow, operator.pow, y, copy=copy, xp=xp)
341409

342410
def min(
343411
self,
@@ -349,7 +417,7 @@ def min(
349417
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
350418
xp = array_namespace(self._x) if xp is None else xp
351419
y = xp.asarray(y)
352-
return self._op(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp)
420+
return self._op(_AtOp.MIN, xp.minimum, xp.minimum, y, copy=copy, xp=xp)
353421

354422
def max(
355423
self,
@@ -361,4 +429,4 @@ def max(
361429
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
362430
xp = array_namespace(self._x) if xp is None else xp
363431
y = xp.asarray(y)
364-
return self._op(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp)
432+
return self._op(_AtOp.MAX, xp.maximum, xp.maximum, y, copy=copy, xp=xp)

src/array_api_extra/_lib/_utils/_compat.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
device,
99
is_array_api_strict_namespace,
1010
is_cupy_namespace,
11+
is_dask_array,
1112
is_dask_namespace,
1213
is_jax_array,
1314
is_jax_namespace,
@@ -23,6 +24,7 @@
2324
device,
2425
is_array_api_strict_namespace,
2526
is_cupy_namespace,
27+
is_dask_array,
2628
is_dask_namespace,
2729
is_jax_array,
2830
is_jax_namespace,
@@ -38,6 +40,7 @@
3840
"device",
3941
"is_array_api_strict_namespace",
4042
"is_cupy_namespace",
43+
"is_dask_array",
4144
"is_dask_namespace",
4245
"is_jax_array",
4346
"is_jax_namespace",

src/array_api_extra/_lib/_utils/_compat.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def is_jax_namespace(xp: ModuleType, /) -> bool: ...
2525
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
2626
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
2727
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
28+
def is_dask_array(x: object, /) -> bool: ...
2829
def is_jax_array(x: object, /) -> bool: ...
2930
def is_writeable_array(x: object, /) -> bool: ...
3031
def size(x: Array, /) -> int | None: ...

tests/test_at.py

+68-13
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import pickle
23
from collections.abc import Callable, Generator
34
from contextlib import contextmanager
@@ -100,14 +101,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
100101
[
101102
(False, False),
102103
(False, True),
103-
pytest.param(
104-
True,
105-
False,
106-
marks=(
107-
pytest.mark.skip_xp_backend(Backend.JAX, reason="TODO special case"),
108-
pytest.mark.skip_xp_backend(Backend.DASK, reason="TODO special case"),
109-
),
110-
),
104+
(True, False), # Uses xp.where(idx, y, x) on JAX and Dask
111105
pytest.param(
112106
True,
113107
True,
@@ -176,11 +170,16 @@ def test_alternate_index_syntax():
176170
at(a, 0)[0].set(4)
177171

178172

179-
@pytest.mark.parametrize("copy", [True, False])
173+
@pytest.mark.skip_xp_backend(
174+
Backend.SPARSE, reason="read-only backend without .at support"
175+
)
176+
@pytest.mark.parametrize("copy", [True, None])
180177
@pytest.mark.parametrize(
181178
"op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER]
182179
)
183-
def test_iops_incompatible_dtype(op: _AtOp, copy: bool):
180+
def test_iops_incompatible_dtype(
181+
xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None
182+
):
184183
"""Test that at() replicates the backend's behaviour for
185184
in-place operations with incompatible dtypes.
186185
@@ -192,6 +191,62 @@ def test_iops_incompatible_dtype(op: _AtOp, copy: bool):
192191
UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64')
193192
to dtype('int64') with casting rule 'same_kind'
194193
"""
195-
x = np.asarray([2, 4])
196-
with pytest.raises(TypeError, match="Cannot cast ufunc"):
197-
at_op(x, slice(None), op, 1.1, copy=copy)
194+
x = xp.asarray([2, 4])
195+
196+
if library is Backend.DASK:
197+
z = at_op(x, slice(None), op, 1.1, copy=copy)
198+
assert z.dtype == x.dtype
199+
200+
elif library is Backend.JAX:
201+
with pytest.warns(FutureWarning, match="cannot safely cast"):
202+
z = at_op(x, slice(None), op, 1.1, copy=copy)
203+
assert z.dtype == x.dtype
204+
205+
else:
206+
with pytest.raises(Exception, match=r"cast|promote|dtype"):
207+
at_op(x, slice(None), op, 1.1, copy=copy)
208+
209+
210+
@pytest.mark.skip_xp_backend(
211+
Backend.SPARSE, reason="read-only backend without .at support"
212+
)
213+
@pytest.mark.parametrize(
214+
"op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER]
215+
)
216+
def test_bool_mask_incompatible_dtype(xp: ModuleType, library: Backend, op: _AtOp):
217+
"""
218+
When xp.where(idx, y, x) would promote the dtype of the output
219+
to y.dtype, at(x, idx).set(y) must retain x.dtype instead
220+
"""
221+
x = xp.asarray([1, 2])
222+
idx = xp.asarray([True, False])
223+
if library in (Backend.DASK, Backend.JAX):
224+
z = at_op(x, idx, op, 1.1)
225+
assert z.dtype == x.dtype
226+
227+
else:
228+
with pytest.raises(Exception, match=r"cast|promote|dtype"):
229+
at_op(x, idx, op, 1.1)
230+
231+
232+
@pytest.mark.skip_xp_backend(
233+
Backend.SPARSE, reason="read-only backend without .at support"
234+
)
235+
def test_bool_mask_nd(xp: ModuleType):
236+
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
237+
idx = xp.asarray([[True, False, False], [False, True, True]])
238+
z = at_op(x, idx, _AtOp.SET, 0)
239+
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))
240+
241+
242+
@pytest.mark.skip_xp_backend(
243+
Backend.SPARSE, reason="read-only backend without .at support"
244+
)
245+
@pytest.mark.skip_xp_backend(Backend.DASK, reason="FIXME need scipy's lazywhere")
246+
@pytest.mark.parametrize("bool_mask", [False, True])
247+
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
248+
x = xp.asarray([math.inf, 1.0, 2.0])
249+
idx = ~xp.isinf(x) if bool_mask else slice(1, None)
250+
# inf - inf -> nan with a warning
251+
z = at_op(x, idx, _AtOp.SUBTRACT, math.inf)
252+
xp_assert_equal(z, xp.asarray([math.inf, -math.inf, -math.inf]))

vendor_tests/test_vendor.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ def test_vendor_compat():
77
array_namespace,
88
device,
99
is_cupy_namespace,
10+
is_dask_array,
1011
is_dask_namespace,
1112
is_jax_array,
1213
is_jax_namespace,
@@ -20,6 +21,7 @@ def test_vendor_compat():
2021
assert array_namespace(x) is xp
2122
device(x)
2223
assert not is_cupy_namespace(xp)
24+
assert not is_dask_array(x)
2325
assert not is_dask_namespace(xp)
2426
assert not is_jax_array(x)
2527
assert not is_jax_namespace(xp)

0 commit comments

Comments
 (0)