Skip to content

Commit 1a4a689

Browse files
jbrockmendelnickleus27
authored andcommitted
REF: simplify putmask_smart (pandas-dev#44435)
1 parent 07261dd commit 1a4a689

File tree

2 files changed

+20
-37
lines changed

2 files changed

+20
-37
lines changed

pandas/core/array_algos/putmask.py

+9-32
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from __future__ import annotations
55

66
from typing import Any
7-
import warnings
87

98
import numpy as np
109

@@ -15,16 +14,12 @@
1514
)
1615

1716
from pandas.core.dtypes.cast import (
17+
can_hold_element,
1818
convert_scalar_for_putitemlike,
1919
find_common_type,
2020
infer_dtype_from,
2121
)
22-
from pandas.core.dtypes.common import (
23-
is_float_dtype,
24-
is_integer_dtype,
25-
is_list_like,
26-
)
27-
from pandas.core.dtypes.missing import isna_compat
22+
from pandas.core.dtypes.common import is_list_like
2823

2924
from pandas.core.arrays import ExtensionArray
3025

@@ -75,7 +70,7 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd
7570
`values`, updated in-place.
7671
mask : np.ndarray[bool]
7772
Applies to both sides (array like).
78-
new : `new values` either scalar or an array like aligned with `values`
73+
new : listlike `new values` aligned with `values`
7974
8075
Returns
8176
-------
@@ -89,9 +84,6 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd
8984
# we cannot use np.asarray() here as we cannot have conversions
9085
# that numpy does when numeric are mixed with strings
9186

92-
if not is_list_like(new):
93-
new = np.broadcast_to(new, mask.shape)
94-
9587
# see if we are only masking values that if putted
9688
# will work in the current dtype
9789
try:
@@ -100,27 +92,12 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd
10092
# TypeError: only integer scalar arrays can be converted to a scalar index
10193
pass
10294
else:
103-
# make sure that we have a nullable type if we have nulls
104-
if not isna_compat(values, nn[0]):
105-
pass
106-
elif not (is_float_dtype(nn.dtype) or is_integer_dtype(nn.dtype)):
107-
# only compare integers/floats
108-
pass
109-
elif not (is_float_dtype(values.dtype) or is_integer_dtype(values.dtype)):
110-
# only compare integers/floats
111-
pass
112-
else:
113-
114-
# we ignore ComplexWarning here
115-
with warnings.catch_warnings(record=True):
116-
warnings.simplefilter("ignore", np.ComplexWarning)
117-
nn_at = nn.astype(values.dtype)
118-
119-
comp = nn == nn_at
120-
if is_list_like(comp) and comp.all():
121-
nv = values.copy()
122-
nv[mask] = nn_at
123-
return nv
95+
# We only get to putmask_smart when we cannot hold 'new' in values.
96+
# The "smart" part of putmask_smart is checking if we can hold new[mask]
97+
# in values, in which case we can still avoid the need to cast.
98+
if can_hold_element(values, nn):
99+
values[mask] = nn
100+
return values
124101

125102
new = np.asarray(new)
126103

pandas/core/internals/blocks.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -952,15 +952,15 @@ def putmask(self, mask, new) -> list[Block]:
952952
List[Block]
953953
"""
954954
orig_mask = mask
955-
mask, noop = validate_putmask(self.values.T, mask)
955+
values = cast(np.ndarray, self.values)
956+
mask, noop = validate_putmask(values.T, mask)
956957
assert not isinstance(new, (ABCIndex, ABCSeries, ABCDataFrame))
957958

958959
# if we are passed a scalar None, convert it here
959960
if not self.is_object and is_valid_na_for_dtype(new, self.dtype):
960961
new = self.fill_value
961962

962963
if self._can_hold_element(new):
963-
964964
# error: Argument 1 to "putmask_without_repeat" has incompatible type
965965
# "Union[ndarray, ExtensionArray]"; expected "ndarray"
966966
putmask_without_repeat(self.values.T, mask, new) # type: ignore[arg-type]
@@ -979,9 +979,15 @@ def putmask(self, mask, new) -> list[Block]:
979979
elif self.ndim == 1 or self.shape[0] == 1:
980980
# no need to split columns
981981

982-
# error: Argument 1 to "putmask_smart" has incompatible type "Union[ndarray,
983-
# ExtensionArray]"; expected "ndarray"
984-
nv = putmask_smart(self.values.T, mask, new).T # type: ignore[arg-type]
982+
if not is_list_like(new):
983+
# putmask_smart can't save us the need to cast
984+
return self.coerce_to_target_dtype(new).putmask(mask, new)
985+
986+
# This differs from
987+
# `self.coerce_to_target_dtype(new).putmask(mask, new)`
988+
# because putmask_smart will check if new[mask] may be held
989+
# by our dtype.
990+
nv = putmask_smart(values.T, mask, new).T
985991
return [self.make_block(nv)]
986992

987993
else:

0 commit comments

Comments
 (0)