Skip to content

Commit f7a44db

Browse files
authored
REF: de-duplicate dt64/td64 putmask/setitem shims (#39778)
1 parent 3b36529 commit f7a44db

File tree

5 files changed

+97
-78
lines changed

5 files changed

+97
-78
lines changed

pandas/core/array_algos/putmask.py

+58-3
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
"""
22
EA-compatible analogue to to np.putmask
33
"""
4-
from typing import Any
4+
from typing import Any, Tuple
55
import warnings
66

77
import numpy as np
88

99
from pandas._libs import lib
1010
from pandas._typing import ArrayLike
1111

12-
from pandas.core.dtypes.cast import convert_scalar_for_putitemlike, find_common_type
12+
from pandas.core.dtypes.cast import (
13+
convert_scalar_for_putitemlike,
14+
find_common_type,
15+
infer_dtype_from,
16+
)
1317
from pandas.core.dtypes.common import is_float_dtype, is_integer_dtype, is_list_like
1418
from pandas.core.dtypes.missing import isna_compat
1519

20+
from pandas.core.arrays import ExtensionArray
21+
1622

1723
def putmask_inplace(values: ArrayLike, mask: np.ndarray, value: Any) -> None:
1824
"""
@@ -22,7 +28,7 @@ def putmask_inplace(values: ArrayLike, mask: np.ndarray, value: Any) -> None:
2228
Parameters
2329
----------
2430
mask : np.ndarray[bool]
25-
We assume _extract_bool_array has already been called.
31+
We assume extract_bool_array has already been called.
2632
value : Any
2733
"""
2834

@@ -152,3 +158,52 @@ def putmask_without_repeat(values: np.ndarray, mask: np.ndarray, new: Any) -> No
152158
raise ValueError("cannot assign mismatch length to masked array")
153159
else:
154160
np.putmask(values, mask, new)
161+
162+
163+
def validate_putmask(values: ArrayLike, mask: np.ndarray) -> Tuple[np.ndarray, bool]:
164+
"""
165+
Validate mask and check if this putmask operation is a no-op.
166+
"""
167+
mask = extract_bool_array(mask)
168+
if mask.shape != values.shape:
169+
raise ValueError("putmask: mask and data must be the same size")
170+
171+
noop = not mask.any()
172+
return mask, noop
173+
174+
175+
def extract_bool_array(mask: ArrayLike) -> np.ndarray:
176+
"""
177+
If we have a SparseArray or BooleanArray, convert it to ndarray[bool].
178+
"""
179+
if isinstance(mask, ExtensionArray):
180+
# We could have BooleanArray, Sparse[bool], ...
181+
# Except for BooleanArray, this is equivalent to just
182+
# np.asarray(mask, dtype=bool)
183+
mask = mask.to_numpy(dtype=bool, na_value=False)
184+
185+
mask = np.asarray(mask, dtype=bool)
186+
return mask
187+
188+
189+
def setitem_datetimelike_compat(values: np.ndarray, num_set: int, other):
190+
"""
191+
Parameters
192+
----------
193+
values : np.ndarray
194+
num_set : int
195+
For putmask, this is mask.sum()
196+
other : Any
197+
"""
198+
if values.dtype == object:
199+
dtype, _ = infer_dtype_from(other, pandas_dtype=True)
200+
201+
if isinstance(dtype, np.dtype) and dtype.kind in ["m", "M"]:
202+
# https://github.com/numpy/numpy/issues/12550
203+
# timedelta64 will incorrectly cast to int
204+
if not is_list_like(other):
205+
other = [other] * num_set
206+
else:
207+
other = list(other)
208+
209+
return other

pandas/core/indexes/base.py

+13-16
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@
9393
from pandas.core import missing, ops
9494
from pandas.core.accessor import CachedAccessor
9595
import pandas.core.algorithms as algos
96+
from pandas.core.array_algos.putmask import (
97+
setitem_datetimelike_compat,
98+
validate_putmask,
99+
)
96100
from pandas.core.arrays import Categorical, ExtensionArray
97101
from pandas.core.arrays.datetimes import tz_to_dtype, validate_tz_from_dtype
98102
from pandas.core.arrays.sparse import SparseDtype
@@ -4274,6 +4278,7 @@ def memory_usage(self, deep: bool = False) -> int:
42744278
result += self._engine.sizeof(deep=deep)
42754279
return result
42764280

4281+
@final
42774282
def where(self, cond, other=None):
42784283
"""
42794284
Replace values where the condition is False.
@@ -4306,6 +4311,10 @@ def where(self, cond, other=None):
43064311
>>> idx.where(idx.isin(['car', 'train']), 'other')
43074312
Index(['car', 'other', 'train', 'other'], dtype='object')
43084313
"""
4314+
if isinstance(self, ABCMultiIndex):
4315+
raise NotImplementedError(
4316+
".where is not supported for MultiIndex operations"
4317+
)
43094318
cond = np.asarray(cond, dtype=bool)
43104319
return self.putmask(~cond, other)
43114320

@@ -4522,10 +4531,8 @@ def putmask(self, mask, value):
45224531
numpy.ndarray.putmask : Changes elements of an array
45234532
based on conditional and input values.
45244533
"""
4525-
mask = np.asarray(mask, dtype=bool)
4526-
if mask.shape != self.shape:
4527-
raise ValueError("putmask: mask and data must be the same size")
4528-
if not mask.any():
4534+
mask, noop = validate_putmask(self._values, mask)
4535+
if noop:
45294536
return self.copy()
45304537

45314538
if value is None and (self._is_numeric_dtype or self.dtype == object):
@@ -4540,18 +4547,8 @@ def putmask(self, mask, value):
45404547
return self.astype(dtype).putmask(mask, value)
45414548

45424549
values = self._values.copy()
4543-
dtype, _ = infer_dtype_from(converted, pandas_dtype=True)
4544-
if dtype.kind in ["m", "M"]:
4545-
# https://github.com/numpy/numpy/issues/12550
4546-
# timedelta64 will incorrectly cast to int
4547-
if not is_list_like(converted):
4548-
converted = [converted] * mask.sum()
4549-
values[mask] = converted
4550-
else:
4551-
converted = list(converted)
4552-
np.putmask(values, mask, converted)
4553-
else:
4554-
np.putmask(values, mask, converted)
4550+
converted = setitem_datetimelike_compat(values, mask.sum(), converted)
4551+
np.putmask(values, mask, converted)
45554552

45564553
return type(self)._simple_new(values, name=self.name)
45574554

pandas/core/indexes/interval.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from pandas.core.dtypes.dtypes import IntervalDtype
4444

4545
from pandas.core.algorithms import take_nd, unique
46+
from pandas.core.array_algos.putmask import validate_putmask
4647
from pandas.core.arrays.interval import IntervalArray, _interval_shared_docs
4748
import pandas.core.common as com
4849
from pandas.core.indexers import is_valid_positional_slice
@@ -799,10 +800,8 @@ def length(self):
799800
return Index(self._data.length, copy=False)
800801

801802
def putmask(self, mask, value):
802-
mask = np.asarray(mask, dtype=bool)
803-
if mask.shape != self.shape:
804-
raise ValueError("putmask: mask and data must be the same size")
805-
if not mask.any():
803+
mask, noop = validate_putmask(self._data, mask)
804+
if noop:
806805
return self.copy()
807806

808807
try:

pandas/core/indexes/multi.py

-3
Original file line numberDiff line numberDiff line change
@@ -2151,9 +2151,6 @@ def repeat(self, repeats: int, axis=None) -> MultiIndex:
21512151
verify_integrity=False,
21522152
)
21532153

2154-
def where(self, cond, other=None):
2155-
raise NotImplementedError(".where is not supported for MultiIndex operations")
2156-
21572154
def drop(self, codes, level=None, errors="raise"):
21582155
"""
21592156
Make new MultiIndex with passed list of codes deleted

pandas/core/internals/blocks.py

+23-52
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
astype_dt64_to_dt64tz,
2525
astype_nansafe,
2626
can_hold_element,
27-
convert_scalar_for_putitemlike,
2827
find_common_type,
2928
infer_dtype_from,
3029
maybe_downcast_numeric,
@@ -52,9 +51,12 @@
5251

5352
import pandas.core.algorithms as algos
5453
from pandas.core.array_algos.putmask import (
54+
extract_bool_array,
5555
putmask_inplace,
5656
putmask_smart,
5757
putmask_without_repeat,
58+
setitem_datetimelike_compat,
59+
validate_putmask,
5860
)
5961
from pandas.core.array_algos.quantile import quantile_with_mask
6062
from pandas.core.array_algos.replace import (
@@ -425,7 +427,8 @@ def fillna(
425427
inplace = validate_bool_kwarg(inplace, "inplace")
426428

427429
mask = isna(self.values)
428-
mask = _extract_bool_array(mask)
430+
mask, noop = validate_putmask(self.values, mask)
431+
429432
if limit is not None:
430433
limit = libalgos.validate_limit(None, limit=limit)
431434
mask[mask.cumsum(self.ndim - 1) > limit] = False
@@ -442,8 +445,8 @@ def fillna(
442445
# TODO: should be nb._maybe_downcast?
443446
return self._maybe_downcast([nb], downcast)
444447

445-
# we can't process the value, but nothing to do
446-
if not mask.any():
448+
if noop:
449+
# we can't process the value, but nothing to do
447450
return [self] if inplace else [self.copy()]
448451

449452
# operate column-by-column
@@ -846,7 +849,7 @@ def _replace_list(
846849
# GH#38086 faster if we know we dont need to check for regex
847850
masks = [missing.mask_missing(self.values, s[0]) for s in pairs]
848851

849-
masks = [_extract_bool_array(x) for x in masks]
852+
masks = [extract_bool_array(x) for x in masks]
850853

851854
rb = [self if inplace else self.copy()]
852855
for i, (src, dest) in enumerate(pairs):
@@ -968,18 +971,8 @@ def setitem(self, indexer, value):
968971
# TODO(EA2D): special case not needed with 2D EA
969972
values[indexer] = value.to_numpy(values.dtype).reshape(-1, 1)
970973

971-
elif self.is_object and not is_ea_value and arr_value.dtype.kind in ["m", "M"]:
972-
# https://github.com/numpy/numpy/issues/12550
973-
# numpy will incorrect cast to int if we're not careful
974-
if is_list_like(value):
975-
value = list(value)
976-
else:
977-
value = [value] * len(values[indexer])
978-
979-
values[indexer] = value
980-
981974
else:
982-
975+
value = setitem_datetimelike_compat(values, len(values[indexer]), value)
983976
values[indexer] = value
984977

985978
if transpose:
@@ -1004,7 +997,7 @@ def putmask(self, mask, new) -> List[Block]:
1004997
List[Block]
1005998
"""
1006999
transpose = self.ndim == 2
1007-
mask = _extract_bool_array(mask)
1000+
mask, noop = validate_putmask(self.values.T, mask)
10081001
assert not isinstance(new, (ABCIndex, ABCSeries, ABCDataFrame))
10091002

10101003
new_values = self.values # delay copy if possible.
@@ -1020,7 +1013,7 @@ def putmask(self, mask, new) -> List[Block]:
10201013
putmask_without_repeat(new_values, mask, new)
10211014
return [self]
10221015

1023-
elif not mask.any():
1016+
elif noop:
10241017
return [self]
10251018

10261019
dtype, _ = infer_dtype_from(new)
@@ -1296,12 +1289,13 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
12961289
if transpose:
12971290
values = values.T
12981291

1299-
cond = _extract_bool_array(cond)
1292+
icond, noop = validate_putmask(values, ~cond)
13001293

13011294
if is_valid_na_for_dtype(other, self.dtype) and not self.is_object:
13021295
other = self.fill_value
13031296

1304-
if cond.ravel("K").all():
1297+
if noop:
1298+
# TODO: avoid the downcasting at the end in this case?
13051299
result = values
13061300
else:
13071301
# see if we can operate on the entire block, or need item-by-item
@@ -1313,23 +1307,14 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
13131307
blocks = block.where(orig_other, cond, errors=errors, axis=axis)
13141308
return self._maybe_downcast(blocks, "infer")
13151309

1316-
dtype, _ = infer_dtype_from(other, pandas_dtype=True)
1317-
if dtype.kind in ["m", "M"] and dtype.kind != values.dtype.kind:
1318-
# expressions.where would cast np.timedelta64 to int
1319-
if not is_list_like(other):
1320-
other = [other] * (~cond).sum()
1321-
else:
1322-
other = list(other)
1310+
alt = setitem_datetimelike_compat(values, icond.sum(), other)
1311+
if alt is not other:
13231312
result = values.copy()
1324-
np.putmask(result, ~cond, other)
1325-
1313+
np.putmask(result, icond, alt)
13261314
else:
1327-
# convert datetime to datetime64, timedelta to timedelta64
1328-
other = convert_scalar_for_putitemlike(other, values.dtype)
1329-
13301315
# By the time we get here, we should have all Series/Index
13311316
# args extracted to ndarray
1332-
result = expressions.where(cond, values, other)
1317+
result = expressions.where(~icond, values, other)
13331318

13341319
if self._can_hold_na or self.ndim == 1:
13351320

@@ -1339,6 +1324,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
13391324
return [self.make_block(result)]
13401325

13411326
# might need to separate out blocks
1327+
cond = ~icond
13421328
axis = cond.ndim - 1
13431329
cond = cond.swapaxes(axis, 0)
13441330
mask = np.array([cond[i].all() for i in range(cond.shape[0])], dtype=bool)
@@ -1545,7 +1531,7 @@ def putmask(self, mask, new) -> List[Block]:
15451531
"""
15461532
See Block.putmask.__doc__
15471533
"""
1548-
mask = _extract_bool_array(mask)
1534+
mask = extract_bool_array(mask)
15491535

15501536
new_values = self.values
15511537

@@ -1775,7 +1761,7 @@ def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> List[Blo
17751761

17761762
def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
17771763

1778-
cond = _extract_bool_array(cond)
1764+
cond = extract_bool_array(cond)
17791765
assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame))
17801766

17811767
if isinstance(other, np.ndarray) and other.ndim == 2:
@@ -2019,7 +2005,7 @@ def to_native_types(self, na_rep="NaT", **kwargs):
20192005
return self.make_block(result)
20202006

20212007
def putmask(self, mask, new) -> List[Block]:
2022-
mask = _extract_bool_array(mask)
2008+
mask = extract_bool_array(mask)
20232009

20242010
if not self._can_hold_element(new):
20252011
return self.astype(object).putmask(mask, new)
@@ -2034,7 +2020,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
20342020
# TODO(EA2D): reshape unnecessary with 2D EAs
20352021
arr = self.array_values().reshape(self.shape)
20362022

2037-
cond = _extract_bool_array(cond)
2023+
cond = extract_bool_array(cond)
20382024

20392025
try:
20402026
res_values = arr.T.where(cond, other).T
@@ -2513,18 +2499,3 @@ def safe_reshape(arr: ArrayLike, new_shape: Shape) -> ArrayLike:
25132499
# TODO(EA2D): special case will be unnecessary with 2D EAs
25142500
arr = np.asarray(arr).reshape(new_shape)
25152501
return arr
2516-
2517-
2518-
def _extract_bool_array(mask: ArrayLike) -> np.ndarray:
2519-
"""
2520-
If we have a SparseArray or BooleanArray, convert it to ndarray[bool].
2521-
"""
2522-
if isinstance(mask, ExtensionArray):
2523-
# We could have BooleanArray, Sparse[bool], ...
2524-
# Except for BooleanArray, this is equivalent to just
2525-
# np.asarray(mask, dtype=bool)
2526-
mask = mask.to_numpy(dtype=bool, na_value=False)
2527-
2528-
assert isinstance(mask, np.ndarray), type(mask)
2529-
assert mask.dtype == bool, mask.dtype
2530-
return mask

0 commit comments

Comments
 (0)