Skip to content

REF: simplify putmask_smart #44435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 9 additions & 32 deletions pandas/core/array_algos/putmask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from __future__ import annotations

from typing import Any
import warnings

import numpy as np

Expand All @@ -15,16 +14,12 @@
)

from pandas.core.dtypes.cast import (
can_hold_element,
convert_scalar_for_putitemlike,
find_common_type,
infer_dtype_from,
)
from pandas.core.dtypes.common import (
is_float_dtype,
is_integer_dtype,
is_list_like,
)
from pandas.core.dtypes.missing import isna_compat
from pandas.core.dtypes.common import is_list_like

from pandas.core.arrays import ExtensionArray

Expand Down Expand Up @@ -75,7 +70,7 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd
`values`, updated in-place.
mask : np.ndarray[bool]
Applies to both sides (array like).
new : `new values` either scalar or an array like aligned with `values`
new : listlike `new values` aligned with `values`

Returns
-------
Expand All @@ -89,9 +84,6 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd
# we cannot use np.asarray() here as we cannot have the conversions
# that numpy does when numerics are mixed with strings

if not is_list_like(new):
new = np.broadcast_to(new, mask.shape)

# see if we are only masking values that, if put,
# will work in the current dtype
try:
Expand All @@ -100,27 +92,12 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd
# TypeError: only integer scalar arrays can be converted to a scalar index
pass
else:
# make sure that we have a nullable type if we have nulls
if not isna_compat(values, nn[0]):
pass
elif not (is_float_dtype(nn.dtype) or is_integer_dtype(nn.dtype)):
# only compare integers/floats
pass
elif not (is_float_dtype(values.dtype) or is_integer_dtype(values.dtype)):
# only compare integers/floats
pass
else:

# we ignore ComplexWarning here
with warnings.catch_warnings(record=True):
warnings.simplefilter("ignore", np.ComplexWarning)
nn_at = nn.astype(values.dtype)

comp = nn == nn_at
if is_list_like(comp) and comp.all():
nv = values.copy()
nv[mask] = nn_at
return nv
# We only get to putmask_smart when we cannot hold 'new' in values.
# The "smart" part of putmask_smart is checking if we can hold new[mask]
# in values, in which case we can still avoid the need to cast.
if can_hold_element(values, nn):
values[mask] = nn
return values

new = np.asarray(new)

Expand Down
16 changes: 11 additions & 5 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,15 +952,15 @@ def putmask(self, mask, new) -> list[Block]:
List[Block]
"""
orig_mask = mask
mask, noop = validate_putmask(self.values.T, mask)
values = cast(np.ndarray, self.values)
mask, noop = validate_putmask(values.T, mask)
assert not isinstance(new, (ABCIndex, ABCSeries, ABCDataFrame))

# if we are passed a scalar None, convert it here
if not self.is_object and is_valid_na_for_dtype(new, self.dtype):
new = self.fill_value

if self._can_hold_element(new):

# error: Argument 1 to "putmask_without_repeat" has incompatible type
# "Union[ndarray, ExtensionArray]"; expected "ndarray"
putmask_without_repeat(self.values.T, mask, new) # type: ignore[arg-type]
Expand All @@ -979,9 +979,15 @@ def putmask(self, mask, new) -> list[Block]:
elif self.ndim == 1 or self.shape[0] == 1:
# no need to split columns

# error: Argument 1 to "putmask_smart" has incompatible type "Union[ndarray,
# ExtensionArray]"; expected "ndarray"
nv = putmask_smart(self.values.T, mask, new).T # type: ignore[arg-type]
if not is_list_like(new):
# putmask_smart can't save us the need to cast
return self.coerce_to_target_dtype(new).putmask(mask, new)

# This differs from
# `self.coerce_to_target_dtype(new).putmask(mask, new)`
# because putmask_smart will check if new[mask] may be held
# by our dtype.
nv = putmask_smart(values.T, mask, new).T
return [self.make_block(nv)]

else:
Expand Down