Skip to content

REF: avoid upcast/downcast in Block.where #45582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 30, 2022
16 changes: 15 additions & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
DtypeObj,
Scalar,
)
from pandas.compat import np_version_under1p20
from pandas.errors import IntCastingNaNError
from pandas.util._exceptions import find_stack_level
from pandas.util._validators import validate_bool_kwarg
Expand Down Expand Up @@ -87,6 +88,7 @@
)
from pandas.core.dtypes.inference import is_list_like
from pandas.core.dtypes.missing import (
array_equivalent,
is_valid_na_for_dtype,
isna,
na_value_for_dtype,
Expand Down Expand Up @@ -1961,7 +1963,7 @@ def np_can_hold_element(dtype: np.dtype, element: Any) -> Any:
# in smaller int dtypes.
info = np.iinfo(dtype)
if info.min <= element <= info.max:
return element
return dtype.type(element)
raise ValueError

if tipo is not None:
Expand Down Expand Up @@ -2017,6 +2019,18 @@ def np_can_hold_element(dtype: np.dtype, element: Any) -> Any:
if element._hasna:
raise ValueError
return element
elif tipo.itemsize > dtype.itemsize:
if isinstance(element, np.ndarray):
# e.g. TestDataFrameIndexingWhere::test_where_alignment
casted = element.astype(dtype)
if np_version_under1p20:
if array_equivalent(casted, element):
return casted
else:
if np.array_equal(casted, element, equal_nan=True):
return casted
raise ValueError

return element

if lib.is_integer(element) or lib.is_float(element):
Expand Down
45 changes: 13 additions & 32 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
from pandas.core.dtypes.cast import (
can_hold_element,
find_result_type,
maybe_downcast_numeric,
maybe_downcast_to_dtype,
np_can_hold_element,
soft_convert_objects,
)
from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -1190,13 +1190,19 @@ def where(self, other, cond) -> list[Block]:

other = self._standardize_fill_value(other)

if not self._can_hold_element(other):
try:
# try/except here is equivalent to a self._can_hold_element check,
# but this gets us back 'casted' which we will re-use below;
# without using 'casted', expressions.where may do unwanted upcasts.
casted = np_can_hold_element(values.dtype, other)
except (ValueError, TypeError):
# we cannot coerce, return a compat dtype
block = self.coerce_to_target_dtype(other)
blocks = block.where(orig_other, cond)
return self._maybe_downcast(blocks, "infer")

else:
other = casted
alt = setitem_datetimelike_compat(values, icond.sum(), other)
if alt is not other:
if is_list_like(other) and len(other) < len(values):
Expand Down Expand Up @@ -1226,38 +1232,13 @@ def where(self, other, cond) -> list[Block]:

# Note: expressions.where may upcast.
result = expressions.where(~icond, values, other)
# The np_can_hold_element check _should_ ensure that we always
# have result.dtype == self.dtype here.

if self._can_hold_na or self.ndim == 1:

if transpose:
result = result.T

return [self.make_block(result)]

# might need to separate out blocks
cond = ~icond
axis = cond.ndim - 1
cond = cond.swapaxes(axis, 0)
mask = cond.all(axis=1)

result_blocks: list[Block] = []
for m in [mask, ~mask]:
if m.any():
taken = result.take(m.nonzero()[0], axis=axis)
r = maybe_downcast_numeric(taken, self.dtype)
if r.dtype != taken.dtype:
warnings.warn(
"Downcasting integer-dtype results in .where is "
"deprecated and will change in a future version. "
"To retain the old behavior, explicitly cast the results "
"to the desired dtype.",
FutureWarning,
stacklevel=find_stack_level(),
)
nb = self.make_block(r.T, placement=self._mgr_locs[m])
result_blocks.append(nb)
if transpose:
result = result.T

return result_blocks
return [self.make_block(result)]

def _unstack(
self,
Expand Down
29 changes: 12 additions & 17 deletions pandas/tests/frame/indexing/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,7 @@ def _check_align(df, cond, other, check_dtypes=True):

# check other is ndarray
cond = df > 0
warn = None
if df is mixed_int_frame:
warn = FutureWarning
with tm.assert_produces_warning(warn, match="Downcasting integer-dtype"):
_check_align(df, cond, (_safe_add(df).values))
_check_align(df, cond, (_safe_add(df).values))

# integers are upcast, so don't check the dtypes
cond = df > 0
Expand Down Expand Up @@ -469,44 +465,43 @@ def test_where_axis(self, using_array_manager):
# GH 9736
df = DataFrame(np.random.randn(2, 2))
mask = DataFrame([[False, False], [False, False]])
s = Series([0, 1])
ser = Series([0, 1])

expected = DataFrame([[0, 0], [1, 1]], dtype="float64")
result = df.where(mask, s, axis="index")
result = df.where(mask, ser, axis="index")
tm.assert_frame_equal(result, expected)

result = df.copy()
return_value = result.where(mask, s, axis="index", inplace=True)
return_value = result.where(mask, ser, axis="index", inplace=True)
assert return_value is None
tm.assert_frame_equal(result, expected)

expected = DataFrame([[0, 1], [0, 1]], dtype="float64")
result = df.where(mask, s, axis="columns")
result = df.where(mask, ser, axis="columns")
tm.assert_frame_equal(result, expected)

result = df.copy()
return_value = result.where(mask, s, axis="columns", inplace=True)
return_value = result.where(mask, ser, axis="columns", inplace=True)
assert return_value is None
tm.assert_frame_equal(result, expected)

def test_where_axis_with_upcast(self):
# Upcast needed
df = DataFrame([[1, 2], [3, 4]], dtype="int64")
mask = DataFrame([[False, False], [False, False]])
s = Series([0, np.nan])
ser = Series([0, np.nan])

expected = DataFrame([[0, 0], [np.nan, np.nan]], dtype="float64")
result = df.where(mask, s, axis="index")
result = df.where(mask, ser, axis="index")
tm.assert_frame_equal(result, expected)

result = df.copy()
return_value = result.where(mask, s, axis="index", inplace=True)
return_value = result.where(mask, ser, axis="index", inplace=True)
assert return_value is None
tm.assert_frame_equal(result, expected)

warn = FutureWarning if using_array_manager else None
expected = DataFrame([[0, np.nan], [0, np.nan]])
with tm.assert_produces_warning(warn, match="Downcasting integer-dtype"):
result = df.where(mask, s, axis="columns")
result = df.where(mask, ser, axis="columns")
tm.assert_frame_equal(result, expected)

expected = DataFrame(
Expand All @@ -516,7 +511,7 @@ def test_where_axis(self, using_array_manager):
}
)
result = df.copy()
return_value = result.where(mask, s, axis="columns", inplace=True)
return_value = result.where(mask, ser, axis="columns", inplace=True)
assert return_value is None
tm.assert_frame_equal(result, expected)

Expand Down
6 changes: 2 additions & 4 deletions pandas/tests/frame/methods/test_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_clip_against_unordered_columns(self):
tm.assert_frame_equal(result_lower, expected_lower)
tm.assert_frame_equal(result_lower_upper, expected_lower_upper)

def test_clip_with_na_args(self, float_frame, using_array_manager):
def test_clip_with_na_args(self, float_frame):
"""Should process np.nan argument as None"""
# GH#17276
tm.assert_frame_equal(float_frame.clip(np.nan), float_frame)
Expand All @@ -151,9 +151,7 @@ def test_clip_with_na_args(self, float_frame, using_array_manager):
)
tm.assert_frame_equal(result, expected)

warn = FutureWarning if using_array_manager else None
with tm.assert_produces_warning(warn, match="Downcasting integer-dtype"):
result = df.clip(lower=[4, 5, np.nan], axis=1)
result = df.clip(lower=[4, 5, np.nan], axis=1)
expected = DataFrame(
{"col_0": [4, 4, 4], "col_1": [5, 5, 6], "col_2": [7, 8, 9]}
)
Expand Down