diff --git a/pandas/core/array_algos/putmask.py b/pandas/core/array_algos/putmask.py index 77e38e6c6e3fc..1f37e0e5d249a 100644 --- a/pandas/core/array_algos/putmask.py +++ b/pandas/core/array_algos/putmask.py @@ -4,7 +4,6 @@ from __future__ import annotations from typing import Any -import warnings import numpy as np @@ -15,16 +14,12 @@ ) from pandas.core.dtypes.cast import ( + can_hold_element, convert_scalar_for_putitemlike, find_common_type, infer_dtype_from, ) -from pandas.core.dtypes.common import ( - is_float_dtype, - is_integer_dtype, - is_list_like, -) -from pandas.core.dtypes.missing import isna_compat +from pandas.core.dtypes.common import is_list_like from pandas.core.arrays import ExtensionArray @@ -75,7 +70,7 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd `values`, updated in-place. mask : np.ndarray[bool] Applies to both sides (array like). - new : `new values` either scalar or an array like aligned with `values` + new : listlike `new values` aligned with `values` Returns ------- @@ -89,9 +84,6 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd # we cannot use np.asarray() here as we cannot have conversions # that numpy does when numeric are mixed with strings - if not is_list_like(new): - new = np.broadcast_to(new, mask.shape) - # see if we are only masking values that if putted # will work in the current dtype try: @@ -100,27 +92,12 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd # TypeError: only integer scalar arrays can be converted to a scalar index pass else: - # make sure that we have a nullable type if we have nulls - if not isna_compat(values, nn[0]): - pass - elif not (is_float_dtype(nn.dtype) or is_integer_dtype(nn.dtype)): - # only compare integers/floats - pass - elif not (is_float_dtype(values.dtype) or is_integer_dtype(values.dtype)): - # only compare integers/floats - pass - else: - - # we ignore ComplexWarning here - with warnings.catch_warnings(record=True): - warnings.simplefilter("ignore", np.ComplexWarning) - nn_at = nn.astype(values.dtype) - - comp = nn == nn_at - if is_list_like(comp) and comp.all(): - nv = values.copy() - nv[mask] = nn_at - return nv + # We only get to putmask_smart when we cannot hold 'new' in values. + # The "smart" part of putmask_smart is checking if we can hold new[mask] + # in values, in which case we can still avoid the need to cast. + if can_hold_element(values, nn): + values[mask] = nn + return values new = np.asarray(new) diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 55e5b0d0439fa..e20bbb0d90fba 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -952,7 +952,8 @@ def putmask(self, mask, new) -> list[Block]: List[Block] """ orig_mask = mask - mask, noop = validate_putmask(self.values.T, mask) + values = cast(np.ndarray, self.values) + mask, noop = validate_putmask(values.T, mask) assert not isinstance(new, (ABCIndex, ABCSeries, ABCDataFrame)) # if we are passed a scalar None, convert it here @@ -960,7 +961,6 @@ def putmask(self, mask, new) -> list[Block]: new = self.fill_value if self._can_hold_element(new): - # error: Argument 1 to "putmask_without_repeat" has incompatible type # "Union[ndarray, ExtensionArray]"; expected "ndarray" putmask_without_repeat(self.values.T, mask, new) # type: ignore[arg-type] @@ -979,9 +979,15 @@ def putmask(self, mask, new) -> list[Block]: elif self.ndim == 1 or self.shape[0] == 1: # no need to split columns - # error: Argument 1 to "putmask_smart" has incompatible type "Union[ndarray, - # ExtensionArray]"; expected "ndarray" - nv = putmask_smart(self.values.T, mask, new).T # type: ignore[arg-type] + if not is_list_like(new): + # putmask_smart can't save us the need to cast + return self.coerce_to_target_dtype(new).putmask(mask, new) + + # This differs from + # `self.coerce_to_target_dtype(new).putmask(mask, new)` + # because putmask_smart will check if new[mask] may be held + # by our dtype. + nv = putmask_smart(values.T, mask, new).T return [self.make_block(nv)] else: