Skip to content

Commit 885a1c4

Browse files
jbrockmendelnickleus27
authored andcommitted
ENH: implement EA._putmask (pandas-dev#44387)
1 parent c4316b5 commit 885a1c4

File tree

5 files changed

+37
-13
lines changed

5 files changed

+37
-13
lines changed

pandas/core/arrays/_mixins.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def _wrap_reduction_result(self, axis: int | None, result):
310310
# ------------------------------------------------------------------------
311311
# __array_function__ methods
312312

313-
def putmask(self, mask: np.ndarray, value) -> None:
313+
def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
314314
"""
315315
Analogue to np.putmask(self, mask, value)
316316

pandas/core/arrays/base.py

+27
Original file line numberDiff line numberDiff line change
@@ -1409,6 +1409,33 @@ def insert(self: ExtensionArrayT, loc: int, item) -> ExtensionArrayT:
14091409

14101410
return type(self)._concat_same_type([self[:loc], item_arr, self[loc:]])
14111411

1412+
def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
1413+
"""
1414+
Analogue to np.putmask(self, mask, value)
1415+
1416+
Parameters
1417+
----------
1418+
mask : np.ndarray[bool]
1419+
value : scalar or listlike
1420+
If listlike, must be arraylike with same length as self.
1421+
1422+
Returns
1423+
-------
1424+
None
1425+
1426+
Notes
1427+
-----
1428+
Unlike np.putmask, we do not repeat listlike values with mismatched length.
1429+
'value' should either be a scalar or an arraylike with the same length
1430+
as self.
1431+
"""
1432+
if is_list_like(value):
1433+
val = value[mask]
1434+
else:
1435+
val = value
1436+
1437+
self[mask] = val
1438+
14121439
def _where(
14131440
self: ExtensionArrayT, mask: npt.NDArray[np.bool_], value
14141441
) -> ExtensionArrayT:

pandas/core/arrays/interval.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
PositionalIndexer,
3737
ScalarIndexer,
3838
SequenceIndexer,
39+
npt,
3940
)
4041
from pandas.compat.numpy import function as nv
4142
from pandas.util._decorators import Appender
@@ -1482,15 +1483,15 @@ def to_tuples(self, na_tuple=True) -> np.ndarray:
14821483

14831484
# ---------------------------------------------------------------------
14841485

1485-
def putmask(self, mask: np.ndarray, value) -> None:
1486+
def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
14861487
value_left, value_right = self._validate_setitem_value(value)
14871488

14881489
if isinstance(self._left, np.ndarray):
14891490
np.putmask(self._left, mask, value_left)
14901491
np.putmask(self._right, mask, value_right)
14911492
else:
1492-
self._left.putmask(mask, value_left)
1493-
self._right.putmask(mask, value_right)
1493+
self._left._putmask(mask, value_left)
1494+
self._right._putmask(mask, value_right)
14941495

14951496
def insert(self: IntervalArrayT, loc: int, item: Interval) -> IntervalArrayT:
14961497
"""

pandas/core/indexes/base.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -4444,8 +4444,7 @@ def _join_non_unique(
44444444
if isinstance(join_array, np.ndarray):
44454445
np.putmask(join_array, mask, right)
44464446
else:
4447-
# error: "ExtensionArray" has no attribute "putmask"
4448-
join_array.putmask(mask, right) # type: ignore[attr-defined]
4447+
join_array._putmask(mask, right)
44494448

44504449
join_index = self._wrap_joined_index(join_array, other)
44514450

@@ -5051,8 +5050,7 @@ def putmask(self, mask, value) -> Index:
50515050
else:
50525051
# Note: we use the original value here, not converted, as
50535052
# _validate_fill_value is not idempotent
5054-
# error: "ExtensionArray" has no attribute "putmask"
5055-
values.putmask(mask, value) # type: ignore[attr-defined]
5053+
values._putmask(mask, value)
50565054

50575055
return self._shallow_copy(values)
50585056

pandas/core/internals/blocks.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1415,15 +1415,13 @@ def putmask(self, mask, new) -> list[Block]:
14151415

14161416
new_values = self.values
14171417

1418-
if isinstance(new, (np.ndarray, ExtensionArray)) and len(new) == len(mask):
1419-
new = new[mask]
1420-
14211418
if mask.ndim == new_values.ndim + 1:
14221419
# TODO(EA2D): unnecessary with 2D EAs
14231420
mask = mask.reshape(new_values.shape)
14241421

14251422
try:
1426-
new_values[mask] = new
1423+
# Caller is responsible for ensuring matching lengths
1424+
new_values._putmask(mask, new)
14271425
except TypeError:
14281426
if not is_interval_dtype(self.dtype):
14291427
# Discussion about what we want to support in the general
@@ -1704,7 +1702,7 @@ def putmask(self, mask, new) -> list[Block]:
17041702
return self.coerce_to_target_dtype(new).putmask(mask, new)
17051703

17061704
arr = self.values
1707-
arr.T.putmask(mask, new)
1705+
arr.T._putmask(mask, new)
17081706
return [self]
17091707

17101708
def where(self, other, cond) -> list[Block]:

0 commit comments

Comments
 (0)