9
9
from types import ModuleType
10
10
from typing import ClassVar , cast
11
11
12
- from ._utils ._compat import array_namespace , is_jax_array , is_writeable_array
12
+ from ._utils ._compat import (
13
+ array_namespace ,
14
+ is_dask_array ,
15
+ is_jax_array ,
16
+ is_writeable_array ,
17
+ )
13
18
from ._utils ._typing import Array , Index
14
19
15
20
@@ -141,6 +146,25 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
141
146
not explicitly covered by ``array-api-compat``, are not supported by update
142
147
methods.
143
148
149
+ Boolean masks are supported on Dask and jitted JAX arrays exclusively
150
+ when `idx` has the same shape as `x` and `y` is 0-dimensional.
151
+ Note that this is support is not available in JAX's native
152
+ ``x.at[mask].set(y)``.
153
+
154
+ This pattern::
155
+
156
+ >>> mask = m(x)
157
+ >>> x[mask] = f(x[mask])
158
+
159
+ Can't be replaced by `at`, as it won't work on Dask and JAX inside jax.jit::
160
+
161
+ >>> mask = m(x)
162
+ >>> x = xpx.at(x, mask).set(f(x[mask]) # Crash on Dask and jax.jit
163
+
164
+ You should instead use::
165
+
166
+ >>> x = xp.where(m(x), f(x), x)
167
+
144
168
Examples
145
169
--------
146
170
Given either of these equivalent expressions::
@@ -189,6 +213,7 @@ def _op(
189
213
self ,
190
214
at_op : _AtOp ,
191
215
in_place_op : Callable [[Array , Array | object ], Array ] | None ,
216
+ out_of_place_op : Callable [[Array , Array ], Array ] | None ,
192
217
y : Array | object ,
193
218
/ ,
194
219
copy : bool | None ,
@@ -210,6 +235,16 @@ def _op(
210
235
211
236
x[idx] = y
212
237
238
+ out_of_place_op : Callable[[Array, Array], Array] | None
239
+ Out-of-place operation to apply when idx is a boolean mask and the backend
240
+ doesn't support in-place updates::
241
+
242
+ x = xp.where(idx, out_of_place_op(x, y), x)
243
+
244
+ If None::
245
+
246
+ x = xp.where(idx, y, x)
247
+
213
248
y : array or object
214
249
Right-hand side of the operation.
215
250
copy : bool or None
@@ -223,6 +258,7 @@ def _op(
223
258
Updated `x`.
224
259
"""
225
260
x , idx = self ._x , self ._idx
261
+ xp = array_namespace (x , y ) if xp is None else xp
226
262
227
263
if idx is _undef :
228
264
msg = (
@@ -247,15 +283,41 @@ def _op(
247
283
else :
248
284
writeable = is_writeable_array (x )
249
285
286
+ # JAX inside jax.jit and Dask don't support in-place updates with boolean
287
+ # mask. However we can handle the common special case of 0-dimensional y
288
+ # with where(idx, y, x) instead.
289
+ if (
290
+ (is_dask_array (idx ) or is_jax_array (idx ))
291
+ and idx .dtype == xp .bool
292
+ and idx .shape == x .shape
293
+ ):
294
+ y_xp = xp .asarray (y , dtype = x .dtype )
295
+ if y_xp .ndim == 0 :
296
+ if out_of_place_op :
297
+ # FIXME: suppress inf warnings on dask with lazywhere
298
+ out = xp .where (idx , out_of_place_op (x , y_xp ), x )
299
+ # Undo int->float promotion on JAX after _AtOp.DIVIDE
300
+ out = xp .astype (out , x .dtype , copy = False )
301
+ else :
302
+ out = xp .where (idx , y_xp , x )
303
+
304
+ if copy :
305
+ return out
306
+ x [()] = out
307
+ return x
308
+ # else: this will work on eager JAX and crash on jax.jit and Dask
309
+
250
310
if copy :
251
311
if is_jax_array (x ):
252
312
# Use JAX's at[]
253
313
func = cast (Callable [[Array ], Array ], getattr (x .at [idx ], at_op .value ))
254
- return func (y )
314
+ out = func (y )
315
+ # Undo int->float promotion on JAX after _AtOp.DIVIDE
316
+ return xp .astype (out , x .dtype , copy = False )
317
+
255
318
# Emulate at[] behaviour for non-JAX arrays
256
319
# with a copy followed by an update
257
- if xp is None :
258
- xp = array_namespace (x )
320
+
259
321
x = xp .asarray (x , copy = True )
260
322
if writeable is False :
261
323
# A copy of a read-only numpy array is writeable
@@ -283,7 +345,7 @@ def set(
283
345
xp : ModuleType | None = None ,
284
346
) -> Array : # numpydoc ignore=PR01,RT01
285
347
"""Apply ``x[idx] = y`` and return the update array."""
286
- return self ._op (_AtOp .SET , None , y , copy = copy , xp = xp )
348
+ return self ._op (_AtOp .SET , None , None , y , copy = copy , xp = xp )
287
349
288
350
def add (
289
351
self ,
@@ -297,7 +359,7 @@ def add(
297
359
# Note for this and all other methods based on _iop:
298
360
# operator.iadd and operator.add subtly differ in behaviour, as
299
361
# only iadd will trigger exceptions when y has an incompatible dtype.
300
- return self ._op (_AtOp .ADD , operator .iadd , y , copy = copy , xp = xp )
362
+ return self ._op (_AtOp .ADD , operator .iadd , operator . add , y , copy = copy , xp = xp )
301
363
302
364
def subtract (
303
365
self ,
@@ -307,7 +369,9 @@ def subtract(
307
369
xp : ModuleType | None = None ,
308
370
) -> Array : # numpydoc ignore=PR01,RT01
309
371
"""Apply ``x[idx] -= y`` and return the updated array."""
310
- return self ._op (_AtOp .SUBTRACT , operator .isub , y , copy = copy , xp = xp )
372
+ return self ._op (
373
+ _AtOp .SUBTRACT , operator .isub , operator .sub , y , copy = copy , xp = xp
374
+ )
311
375
312
376
def multiply (
313
377
self ,
@@ -317,7 +381,9 @@ def multiply(
317
381
xp : ModuleType | None = None ,
318
382
) -> Array : # numpydoc ignore=PR01,RT01
319
383
"""Apply ``x[idx] *= y`` and return the updated array."""
320
- return self ._op (_AtOp .MULTIPLY , operator .imul , y , copy = copy , xp = xp )
384
+ return self ._op (
385
+ _AtOp .MULTIPLY , operator .imul , operator .mul , y , copy = copy , xp = xp
386
+ )
321
387
322
388
def divide (
323
389
self ,
@@ -327,7 +393,9 @@ def divide(
327
393
xp : ModuleType | None = None ,
328
394
) -> Array : # numpydoc ignore=PR01,RT01
329
395
"""Apply ``x[idx] /= y`` and return the updated array."""
330
- return self ._op (_AtOp .DIVIDE , operator .itruediv , y , copy = copy , xp = xp )
396
+ return self ._op (
397
+ _AtOp .DIVIDE , operator .itruediv , operator .truediv , y , copy = copy , xp = xp
398
+ )
331
399
332
400
def power (
333
401
self ,
@@ -337,7 +405,7 @@ def power(
337
405
xp : ModuleType | None = None ,
338
406
) -> Array : # numpydoc ignore=PR01,RT01
339
407
"""Apply ``x[idx] **= y`` and return the updated array."""
340
- return self ._op (_AtOp .POWER , operator .ipow , y , copy = copy , xp = xp )
408
+ return self ._op (_AtOp .POWER , operator .ipow , operator . pow , y , copy = copy , xp = xp )
341
409
342
410
def min (
343
411
self ,
@@ -349,7 +417,7 @@ def min(
349
417
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
350
418
xp = array_namespace (self ._x ) if xp is None else xp
351
419
y = xp .asarray (y )
352
- return self ._op (_AtOp .MIN , xp .minimum , y , copy = copy , xp = xp )
420
+ return self ._op (_AtOp .MIN , xp .minimum , xp . minimum , y , copy = copy , xp = xp )
353
421
354
422
def max (
355
423
self ,
@@ -361,4 +429,4 @@ def max(
361
429
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
362
430
xp = array_namespace (self ._x ) if xp is None else xp
363
431
y = xp .asarray (y )
364
- return self ._op (_AtOp .MAX , xp .maximum , y , copy = copy , xp = xp )
432
+ return self ._op (_AtOp .MAX , xp .maximum , xp . maximum , y , copy = copy , xp = xp )
0 commit comments