Skip to content

REF: de-duplicate dt64/td64 putmask/setitem shims #39778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions pandas/core/array_algos/putmask.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
"""
EA-compatible analogue to np.putmask
"""
from typing import Any
from typing import Any, Tuple
import warnings

import numpy as np

from pandas._libs import lib
from pandas._typing import ArrayLike

from pandas.core.dtypes.cast import convert_scalar_for_putitemlike, find_common_type
from pandas.core.dtypes.cast import (
convert_scalar_for_putitemlike,
find_common_type,
infer_dtype_from,
)
from pandas.core.dtypes.common import is_float_dtype, is_integer_dtype, is_list_like
from pandas.core.dtypes.missing import isna_compat

from pandas.core.arrays import ExtensionArray


def putmask_inplace(values: ArrayLike, mask: np.ndarray, value: Any) -> None:
"""
Expand All @@ -22,7 +28,7 @@ def putmask_inplace(values: ArrayLike, mask: np.ndarray, value: Any) -> None:
Parameters
----------
mask : np.ndarray[bool]
We assume _extract_bool_array has already been called.
We assume extract_bool_array has already been called.
value : Any
"""

Expand Down Expand Up @@ -152,3 +158,52 @@ def putmask_without_repeat(values: np.ndarray, mask: np.ndarray, new: Any) -> No
raise ValueError("cannot assign mismatch length to masked array")
else:
np.putmask(values, mask, new)


def validate_putmask(values: ArrayLike, mask: np.ndarray) -> Tuple[np.ndarray, bool]:
    """
    Normalize a putmask mask and report whether applying it would be a no-op.

    Parameters
    ----------
    values : ArrayLike
        The data the mask is meant for; only its ``shape`` is read here.
    mask : np.ndarray
        Candidate mask; it is first passed through ``extract_bool_array``.

    Returns
    -------
    bool_mask : np.ndarray
        The result of ``extract_bool_array(mask)``.
    noop : bool
        True when no entry of ``bool_mask`` is set.

    Raises
    ------
    ValueError
        If the converted mask's shape differs from ``values.shape``.
    """
    bool_mask = extract_bool_array(mask)
    if bool_mask.shape == values.shape:
        # An all-False mask selects nothing, so the caller can skip the write.
        return bool_mask, not bool_mask.any()

    raise ValueError("putmask: mask and data must be the same size")


def extract_bool_array(mask: ArrayLike) -> np.ndarray:
    """
    Coerce a mask to a plain ndarray[bool].

    Parameters
    ----------
    mask : ArrayLike
        Either an ExtensionArray (e.g. BooleanArray, Sparse[bool]) or anything
        ``np.asarray(..., dtype=bool)`` accepts.

    Returns
    -------
    np.ndarray
        Boolean ndarray built from ``mask``.
    """
    if not isinstance(mask, ExtensionArray):
        return np.asarray(mask, dtype=bool)

    # ExtensionArray masks go through to_numpy with na_value=False; apart from
    # BooleanArray this is equivalent to just np.asarray(mask, dtype=bool).
    dense = mask.to_numpy(dtype=bool, na_value=False)
    return np.asarray(dense, dtype=bool)


def setitem_datetimelike_compat(values: np.ndarray, num_set: int, other):
    """
    Pre-process ``other`` before it is set into ``values``.

    When ``values`` is object-dtype and ``other`` is inferred to be
    datetime64/timedelta64, ``other`` is turned into a list; otherwise it is
    returned unchanged.

    Parameters
    ----------
    values : np.ndarray
    num_set : int
        For putmask, this is mask.sum()
    other : Any

    Returns
    -------
    Any
        ``other`` itself, or a list built from it.
    """
    if values.dtype != object:
        return other

    inferred, _ = infer_dtype_from(other, pandas_dtype=True)
    is_dt_or_td = isinstance(inferred, np.dtype) and inferred.kind in ["m", "M"]
    if not is_dt_or_td:
        return other

    # https://github.com/numpy/numpy/issues/12550
    # timedelta64 will incorrectly cast to int
    if is_list_like(other):
        return list(other)
    # scalar: repeat it once per position being set
    return [other] * num_set
29 changes: 13 additions & 16 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@
from pandas.core import missing, ops
from pandas.core.accessor import CachedAccessor
import pandas.core.algorithms as algos
from pandas.core.array_algos.putmask import (
setitem_datetimelike_compat,
validate_putmask,
)
from pandas.core.arrays import Categorical, ExtensionArray
from pandas.core.arrays.datetimes import tz_to_dtype, validate_tz_from_dtype
from pandas.core.arrays.sparse import SparseDtype
Expand Down Expand Up @@ -4274,6 +4278,7 @@ def memory_usage(self, deep: bool = False) -> int:
result += self._engine.sizeof(deep=deep)
return result

@final
def where(self, cond, other=None):
"""
Replace values where the condition is False.
Expand Down Expand Up @@ -4306,6 +4311,10 @@ def where(self, cond, other=None):
>>> idx.where(idx.isin(['car', 'train']), 'other')
Index(['car', 'other', 'train', 'other'], dtype='object')
"""
if isinstance(self, ABCMultiIndex):
raise NotImplementedError(
".where is not supported for MultiIndex operations"
)
cond = np.asarray(cond, dtype=bool)
return self.putmask(~cond, other)

Expand Down Expand Up @@ -4522,10 +4531,8 @@ def putmask(self, mask, value):
numpy.ndarray.putmask : Changes elements of an array
based on conditional and input values.
"""
mask = np.asarray(mask, dtype=bool)
if mask.shape != self.shape:
raise ValueError("putmask: mask and data must be the same size")
if not mask.any():
mask, noop = validate_putmask(self._values, mask)
if noop:
return self.copy()

if value is None and (self._is_numeric_dtype or self.dtype == object):
Expand All @@ -4540,18 +4547,8 @@ def putmask(self, mask, value):
return self.astype(dtype).putmask(mask, value)

values = self._values.copy()
dtype, _ = infer_dtype_from(converted, pandas_dtype=True)
if dtype.kind in ["m", "M"]:
# https://github.com/numpy/numpy/issues/12550
# timedelta64 will incorrectly cast to int
if not is_list_like(converted):
converted = [converted] * mask.sum()
values[mask] = converted
else:
converted = list(converted)
np.putmask(values, mask, converted)
else:
np.putmask(values, mask, converted)
converted = setitem_datetimelike_compat(values, mask.sum(), converted)
np.putmask(values, mask, converted)

return type(self)._simple_new(values, name=self.name)

Expand Down
7 changes: 3 additions & 4 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from pandas.core.dtypes.dtypes import IntervalDtype

from pandas.core.algorithms import take_nd, unique
from pandas.core.array_algos.putmask import validate_putmask
from pandas.core.arrays.interval import IntervalArray, _interval_shared_docs
import pandas.core.common as com
from pandas.core.indexers import is_valid_positional_slice
Expand Down Expand Up @@ -799,10 +800,8 @@ def length(self):
return Index(self._data.length, copy=False)

def putmask(self, mask, value):
mask = np.asarray(mask, dtype=bool)
if mask.shape != self.shape:
raise ValueError("putmask: mask and data must be the same size")
if not mask.any():
mask, noop = validate_putmask(self._data, mask)
if noop:
return self.copy()

try:
Expand Down
3 changes: 0 additions & 3 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2151,9 +2151,6 @@ def repeat(self, repeats: int, axis=None) -> MultiIndex:
verify_integrity=False,
)

def where(self, cond, other=None):
    """
    Not supported for MultiIndex; always raises.

    Parameters
    ----------
    cond : Any
        Ignored.
    other : Any, default None
        Ignored.

    Raises
    ------
    NotImplementedError
        Unconditionally.
    """
    msg = ".where is not supported for MultiIndex operations"
    raise NotImplementedError(msg)

def drop(self, codes, level=None, errors="raise"):
"""
Make new MultiIndex with passed list of codes deleted
Expand Down
75 changes: 23 additions & 52 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
astype_dt64_to_dt64tz,
astype_nansafe,
can_hold_element,
convert_scalar_for_putitemlike,
find_common_type,
infer_dtype_from,
maybe_downcast_numeric,
Expand Down Expand Up @@ -52,9 +51,12 @@

import pandas.core.algorithms as algos
from pandas.core.array_algos.putmask import (
extract_bool_array,
putmask_inplace,
putmask_smart,
putmask_without_repeat,
setitem_datetimelike_compat,
validate_putmask,
)
from pandas.core.array_algos.quantile import quantile_with_mask
from pandas.core.array_algos.replace import (
Expand Down Expand Up @@ -425,7 +427,8 @@ def fillna(
inplace = validate_bool_kwarg(inplace, "inplace")

mask = isna(self.values)
mask = _extract_bool_array(mask)
mask, noop = validate_putmask(self.values, mask)

if limit is not None:
limit = libalgos.validate_limit(None, limit=limit)
mask[mask.cumsum(self.ndim - 1) > limit] = False
Expand All @@ -442,8 +445,8 @@ def fillna(
# TODO: should be nb._maybe_downcast?
return self._maybe_downcast([nb], downcast)

# we can't process the value, but nothing to do
if not mask.any():
if noop:
# we can't process the value, but nothing to do
return [self] if inplace else [self.copy()]

# operate column-by-column
Expand Down Expand Up @@ -846,7 +849,7 @@ def _replace_list(
# GH#38086 faster if we know we don't need to check for regex
masks = [missing.mask_missing(self.values, s[0]) for s in pairs]

masks = [_extract_bool_array(x) for x in masks]
masks = [extract_bool_array(x) for x in masks]

rb = [self if inplace else self.copy()]
for i, (src, dest) in enumerate(pairs):
Expand Down Expand Up @@ -968,18 +971,8 @@ def setitem(self, indexer, value):
# TODO(EA2D): special case not needed with 2D EA
values[indexer] = value.to_numpy(values.dtype).reshape(-1, 1)

elif self.is_object and not is_ea_value and arr_value.dtype.kind in ["m", "M"]:
# https://github.com/numpy/numpy/issues/12550
# numpy will incorrectly cast to int if we're not careful
if is_list_like(value):
value = list(value)
else:
value = [value] * len(values[indexer])

values[indexer] = value

else:

value = setitem_datetimelike_compat(values, len(values[indexer]), value)
values[indexer] = value

if transpose:
Expand All @@ -1004,7 +997,7 @@ def putmask(self, mask, new) -> List[Block]:
List[Block]
"""
transpose = self.ndim == 2
mask = _extract_bool_array(mask)
mask, noop = validate_putmask(self.values.T, mask)
assert not isinstance(new, (ABCIndex, ABCSeries, ABCDataFrame))

new_values = self.values # delay copy if possible.
Expand All @@ -1020,7 +1013,7 @@ def putmask(self, mask, new) -> List[Block]:
putmask_without_repeat(new_values, mask, new)
return [self]

elif not mask.any():
elif noop:
return [self]

dtype, _ = infer_dtype_from(new)
Expand Down Expand Up @@ -1296,12 +1289,13 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
if transpose:
values = values.T

cond = _extract_bool_array(cond)
icond, noop = validate_putmask(values, ~cond)

if is_valid_na_for_dtype(other, self.dtype) and not self.is_object:
other = self.fill_value

if cond.ravel("K").all():
if noop:
# TODO: avoid the downcasting at the end in this case?
result = values
else:
# see if we can operate on the entire block, or need item-by-item
Expand All @@ -1313,23 +1307,14 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
blocks = block.where(orig_other, cond, errors=errors, axis=axis)
return self._maybe_downcast(blocks, "infer")

dtype, _ = infer_dtype_from(other, pandas_dtype=True)
if dtype.kind in ["m", "M"] and dtype.kind != values.dtype.kind:
# expressions.where would cast np.timedelta64 to int
if not is_list_like(other):
other = [other] * (~cond).sum()
else:
other = list(other)
alt = setitem_datetimelike_compat(values, icond.sum(), other)
if alt is not other:
result = values.copy()
np.putmask(result, ~cond, other)

np.putmask(result, icond, alt)
else:
# convert datetime to datetime64, timedelta to timedelta64
other = convert_scalar_for_putitemlike(other, values.dtype)

# By the time we get here, we should have all Series/Index
# args extracted to ndarray
result = expressions.where(cond, values, other)
result = expressions.where(~icond, values, other)

if self._can_hold_na or self.ndim == 1:

Expand All @@ -1339,6 +1324,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
return [self.make_block(result)]

# might need to separate out blocks
cond = ~icond
axis = cond.ndim - 1
cond = cond.swapaxes(axis, 0)
mask = np.array([cond[i].all() for i in range(cond.shape[0])], dtype=bool)
Expand Down Expand Up @@ -1545,7 +1531,7 @@ def putmask(self, mask, new) -> List[Block]:
"""
See Block.putmask.__doc__
"""
mask = _extract_bool_array(mask)
mask = extract_bool_array(mask)

new_values = self.values

Expand Down Expand Up @@ -1775,7 +1761,7 @@ def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> List[Blo

def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:

cond = _extract_bool_array(cond)
cond = extract_bool_array(cond)
assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame))

if isinstance(other, np.ndarray) and other.ndim == 2:
Expand Down Expand Up @@ -2019,7 +2005,7 @@ def to_native_types(self, na_rep="NaT", **kwargs):
return self.make_block(result)

def putmask(self, mask, new) -> List[Block]:
mask = _extract_bool_array(mask)
mask = extract_bool_array(mask)

if not self._can_hold_element(new):
return self.astype(object).putmask(mask, new)
Expand All @@ -2034,7 +2020,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
# TODO(EA2D): reshape unnecessary with 2D EAs
arr = self.array_values().reshape(self.shape)

cond = _extract_bool_array(cond)
cond = extract_bool_array(cond)

try:
res_values = arr.T.where(cond, other).T
Expand Down Expand Up @@ -2513,18 +2499,3 @@ def safe_reshape(arr: ArrayLike, new_shape: Shape) -> ArrayLike:
# TODO(EA2D): special case will be unnecessary with 2D EAs
arr = np.asarray(arr).reshape(new_shape)
return arr


def _extract_bool_array(mask: ArrayLike) -> np.ndarray:
    """
    Return ``mask`` as an ndarray[bool], converting ExtensionArray masks.

    Parameters
    ----------
    mask : ArrayLike
        An ExtensionArray (e.g. SparseArray or BooleanArray) or an ndarray
        that is already boolean.

    Returns
    -------
    np.ndarray
        Boolean ndarray; a non-ExtensionArray input is returned as-is.
    """
    # We could have BooleanArray, Sparse[bool], ...
    # Except for BooleanArray, the conversion is equivalent to just
    # np.asarray(mask, dtype=bool)
    converted = (
        mask.to_numpy(dtype=bool, na_value=False)
        if isinstance(mask, ExtensionArray)
        else mask
    )

    # Sanity checks only: anything else must already be a boolean ndarray.
    assert isinstance(converted, np.ndarray), type(converted)
    assert converted.dtype == bool, converted.dtype
    return converted