Skip to content

Commit 7f3ac91

Browse files
jbrockmendel authored and phofl committed
REF: avoid upcast/downcast in Block.where (pandas-dev#45582)
1 parent 471e4b3 commit 7f3ac91

File tree

4 files changed

+38
-54
lines changed

4 files changed

+38
-54
lines changed

pandas/core/dtypes/cast.py

+11 -1
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,7 @@
8787
)
8888
from pandas.core.dtypes.inference import is_list_like
8989
from pandas.core.dtypes.missing import (
90+
array_equivalent,
9091
is_valid_na_for_dtype,
9192
isna,
9293
na_value_for_dtype,
@@ -1970,7 +1971,7 @@ def np_can_hold_element(dtype: np.dtype, element: Any) -> Any:
19701971
# in smaller int dtypes.
19711972
info = np.iinfo(dtype)
19721973
if info.min <= element <= info.max:
1973-
return element
1974+
return dtype.type(element)
19741975
raise ValueError
19751976

19761977
if tipo is not None:
@@ -2026,6 +2027,15 @@ def np_can_hold_element(dtype: np.dtype, element: Any) -> Any:
20262027
if element._hasna:
20272028
raise ValueError
20282029
return element
2030+
elif tipo.itemsize > dtype.itemsize:
2031+
if isinstance(element, np.ndarray):
2032+
# e.g. TestDataFrameIndexingWhere::test_where_alignment
2033+
casted = element.astype(dtype)
2034+
# TODO(np>=1.20): we can just use np.array_equal with equal_nan
2035+
if array_equivalent(casted, element):
2036+
return casted
2037+
raise ValueError
2038+
20292039
return element
20302040

20312041
if lib.is_integer(element) or lib.is_float(element):

pandas/core/internals/blocks.py

+13 -32
Original file line number | Diff line number | Diff line change
@@ -38,8 +38,8 @@
3838
from pandas.core.dtypes.cast import (
3939
can_hold_element,
4040
find_result_type,
41-
maybe_downcast_numeric,
4241
maybe_downcast_to_dtype,
42+
np_can_hold_element,
4343
soft_convert_objects,
4444
)
4545
from pandas.core.dtypes.common import (
@@ -1186,13 +1186,19 @@ def where(self, other, cond) -> list[Block]:
11861186

11871187
other = self._standardize_fill_value(other)
11881188

1189-
if not self._can_hold_element(other):
1189+
try:
1190+
# try/except here is equivalent to a self._can_hold_element check,
1191+
# but this gets us back 'casted' which we will re-use below;
1192+
# without using 'casted', expressions.where may do unwanted upcasts.
1193+
casted = np_can_hold_element(values.dtype, other)
1194+
except (ValueError, TypeError):
11901195
# we cannot coerce, return a compat dtype
11911196
block = self.coerce_to_target_dtype(other)
11921197
blocks = block.where(orig_other, cond)
11931198
return self._maybe_downcast(blocks, "infer")
11941199

11951200
else:
1201+
other = casted
11961202
alt = setitem_datetimelike_compat(values, icond.sum(), other)
11971203
if alt is not other:
11981204
if is_list_like(other) and len(other) < len(values):
@@ -1222,38 +1228,13 @@ def where(self, other, cond) -> list[Block]:
12221228

12231229
# Note: expressions.where may upcast.
12241230
result = expressions.where(~icond, values, other)
1231+
# The np_can_hold_element check _should_ ensure that we always
1232+
# have result.dtype == self.dtype here.
12251233

1226-
if self._can_hold_na or self.ndim == 1:
1227-
1228-
if transpose:
1229-
result = result.T
1230-
1231-
return [self.make_block(result)]
1232-
1233-
# might need to separate out blocks
1234-
cond = ~icond
1235-
axis = cond.ndim - 1
1236-
cond = cond.swapaxes(axis, 0)
1237-
mask = cond.all(axis=1)
1238-
1239-
result_blocks: list[Block] = []
1240-
for m in [mask, ~mask]:
1241-
if m.any():
1242-
taken = result.take(m.nonzero()[0], axis=axis)
1243-
r = maybe_downcast_numeric(taken, self.dtype)
1244-
if r.dtype != taken.dtype:
1245-
warnings.warn(
1246-
"Downcasting integer-dtype results in .where is "
1247-
"deprecated and will change in a future version. "
1248-
"To retain the old behavior, explicitly cast the results "
1249-
"to the desired dtype.",
1250-
FutureWarning,
1251-
stacklevel=find_stack_level(),
1252-
)
1253-
nb = self.make_block(r.T, placement=self._mgr_locs[m])
1254-
result_blocks.append(nb)
1234+
if transpose:
1235+
result = result.T
12551236

1256-
return result_blocks
1237+
return [self.make_block(result)]
12571238

12581239
def _unstack(
12591240
self,

pandas/tests/frame/indexing/test_where.py

+12 -17
Original file line number | Diff line number | Diff line change
@@ -141,11 +141,7 @@ def _check_align(df, cond, other, check_dtypes=True):
141141

142142
# check other is ndarray
143143
cond = df > 0
144-
warn = None
145-
if df is mixed_int_frame:
146-
warn = FutureWarning
147-
with tm.assert_produces_warning(warn, match="Downcasting integer-dtype"):
148-
_check_align(df, cond, (_safe_add(df).values))
144+
_check_align(df, cond, (_safe_add(df).values))
149145

150146
# integers are upcast, so don't check the dtypes
151147
cond = df > 0
@@ -469,44 +465,43 @@ def test_where_axis(self, using_array_manager):
469465
# GH 9736
470466
df = DataFrame(np.random.randn(2, 2))
471467
mask = DataFrame([[False, False], [False, False]])
472-
s = Series([0, 1])
468+
ser = Series([0, 1])
473469

474470
expected = DataFrame([[0, 0], [1, 1]], dtype="float64")
475-
result = df.where(mask, s, axis="index")
471+
result = df.where(mask, ser, axis="index")
476472
tm.assert_frame_equal(result, expected)
477473

478474
result = df.copy()
479-
return_value = result.where(mask, s, axis="index", inplace=True)
475+
return_value = result.where(mask, ser, axis="index", inplace=True)
480476
assert return_value is None
481477
tm.assert_frame_equal(result, expected)
482478

483479
expected = DataFrame([[0, 1], [0, 1]], dtype="float64")
484-
result = df.where(mask, s, axis="columns")
480+
result = df.where(mask, ser, axis="columns")
485481
tm.assert_frame_equal(result, expected)
486482

487483
result = df.copy()
488-
return_value = result.where(mask, s, axis="columns", inplace=True)
484+
return_value = result.where(mask, ser, axis="columns", inplace=True)
489485
assert return_value is None
490486
tm.assert_frame_equal(result, expected)
491487

488+
def test_where_axis_with_upcast(self):
492489
# Upcast needed
493490
df = DataFrame([[1, 2], [3, 4]], dtype="int64")
494491
mask = DataFrame([[False, False], [False, False]])
495-
s = Series([0, np.nan])
492+
ser = Series([0, np.nan])
496493

497494
expected = DataFrame([[0, 0], [np.nan, np.nan]], dtype="float64")
498-
result = df.where(mask, s, axis="index")
495+
result = df.where(mask, ser, axis="index")
499496
tm.assert_frame_equal(result, expected)
500497

501498
result = df.copy()
502-
return_value = result.where(mask, s, axis="index", inplace=True)
499+
return_value = result.where(mask, ser, axis="index", inplace=True)
503500
assert return_value is None
504501
tm.assert_frame_equal(result, expected)
505502

506-
warn = FutureWarning if using_array_manager else None
507503
expected = DataFrame([[0, np.nan], [0, np.nan]])
508-
with tm.assert_produces_warning(warn, match="Downcasting integer-dtype"):
509-
result = df.where(mask, s, axis="columns")
504+
result = df.where(mask, ser, axis="columns")
510505
tm.assert_frame_equal(result, expected)
511506

512507
expected = DataFrame(
@@ -516,7 +511,7 @@ def test_where_axis(self, using_array_manager):
516511
}
517512
)
518513
result = df.copy()
519-
return_value = result.where(mask, s, axis="columns", inplace=True)
514+
return_value = result.where(mask, ser, axis="columns", inplace=True)
520515
assert return_value is None
521516
tm.assert_frame_equal(result, expected)
522517

pandas/tests/frame/methods/test_clip.py

+2 -4
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,7 @@ def test_clip_against_unordered_columns(self):
136136
tm.assert_frame_equal(result_lower, expected_lower)
137137
tm.assert_frame_equal(result_lower_upper, expected_lower_upper)
138138

139-
def test_clip_with_na_args(self, float_frame, using_array_manager):
139+
def test_clip_with_na_args(self, float_frame):
140140
"""Should process np.nan argument as None"""
141141
# GH#17276
142142
tm.assert_frame_equal(float_frame.clip(np.nan), float_frame)
@@ -151,9 +151,7 @@ def test_clip_with_na_args(self, float_frame, using_array_manager):
151151
)
152152
tm.assert_frame_equal(result, expected)
153153

154-
warn = FutureWarning if using_array_manager else None
155-
with tm.assert_produces_warning(warn, match="Downcasting integer-dtype"):
156-
result = df.clip(lower=[4, 5, np.nan], axis=1)
154+
result = df.clip(lower=[4, 5, np.nan], axis=1)
157155
expected = DataFrame(
158156
{"col_0": [4, 4, 4], "col_1": [5, 5, 6], "col_2": [7, 8, 9]}
159157
)

0 commit comments

Comments
 (0)