Skip to content

Commit ede6234

Browse files
authored
REF: simplify Block.where (and subtle alignment bug) (#44691)
1 parent bd6eb7e commit ede6234

File tree

1 file changed

+30
-28
lines changed

1 file changed

+30
-28
lines changed

pandas/core/internals/blocks.py

+30-28
Original file line numberDiff line numberDiff line change
@@ -1139,17 +1139,11 @@ def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> list[Blo
11391139
# convert integer to float if necessary. need to do a lot more than
11401140
# that, handle boolean etc also
11411141

1142-
# error: Value of type variable "NumpyArrayT" of "maybe_upcast" cannot be
1143-
# "Union[ndarray[Any, Any], ExtensionArray]"
1144-
new_values, fill_value = maybe_upcast(
1145-
self.values, fill_value # type: ignore[type-var]
1146-
)
1142+
values = cast(np.ndarray, self.values)
11471143

1148-
# error: Argument 1 to "shift" has incompatible type "Union[ndarray[Any, Any],
1149-
# ExtensionArray]"; expected "ndarray[Any, Any]"
1150-
new_values = shift(
1151-
new_values, periods, axis, fill_value # type: ignore[arg-type]
1152-
)
1144+
new_values, fill_value = maybe_upcast(values, fill_value)
1145+
1146+
new_values = shift(new_values, periods, axis, fill_value)
11531147

11541148
return [self.make_block(new_values)]
11551149

@@ -1171,7 +1165,8 @@ def where(self, other, cond) -> list[Block]:
11711165

11721166
transpose = self.ndim == 2
11731167

1174-
values = self.values
1168+
# EABlocks override where
1169+
values = cast(np.ndarray, self.values)
11751170
orig_other = other
11761171
if transpose:
11771172
values = values.T
@@ -1185,22 +1180,15 @@ def where(self, other, cond) -> list[Block]:
11851180
# TODO: avoid the downcasting at the end in this case?
11861181
# GH-39595: Always return a copy
11871182
result = values.copy()
1183+
1184+
elif not self._can_hold_element(other):
1185+
# we cannot coerce, return a compat dtype
1186+
block = self.coerce_to_target_dtype(other)
1187+
blocks = block.where(orig_other, cond)
1188+
return self._maybe_downcast(blocks, "infer")
1189+
11881190
else:
1189-
# see if we can operate on the entire block, or need item-by-item
1190-
# or if we are a single block (ndim == 1)
1191-
if not self._can_hold_element(other):
1192-
# we cannot coerce, return a compat dtype
1193-
block = self.coerce_to_target_dtype(other)
1194-
blocks = block.where(orig_other, cond)
1195-
return self._maybe_downcast(blocks, "infer")
1196-
1197-
# error: Argument 1 to "setitem_datetimelike_compat" has incompatible type
1198-
# "Union[ndarray, ExtensionArray]"; expected "ndarray"
1199-
# error: Argument 2 to "setitem_datetimelike_compat" has incompatible type
1200-
# "number[Any]"; expected "int"
1201-
alt = setitem_datetimelike_compat(
1202-
values, icond.sum(), other # type: ignore[arg-type]
1203-
)
1191+
alt = setitem_datetimelike_compat(values, icond.sum(), other)
12041192
if alt is not other:
12051193
if is_list_like(other) and len(other) < len(values):
12061194
# call np.where with other to get the appropriate ValueError
@@ -1215,6 +1203,19 @@ def where(self, other, cond) -> list[Block]:
12151203
else:
12161204
# By the time we get here, we should have all Series/Index
12171205
# args extracted to ndarray
1206+
if (
1207+
is_list_like(other)
1208+
and not isinstance(other, np.ndarray)
1209+
and len(other) == self.shape[-1]
1210+
):
1211+
# If we don't do this broadcasting here, then expressions.where
1212+
# will broadcast a 1D other to be row-like instead of
1213+
# column-like.
1214+
other = np.array(other).reshape(values.shape)
1215+
# If lengths don't match (or len(other)==1), we will raise
1216+
# inside expressions.where, see test_series_where
1217+
1218+
# Note: expressions.where may upcast.
12181219
result = expressions.where(~icond, values, other)
12191220

12201221
if self._can_hold_na or self.ndim == 1:
@@ -1233,7 +1234,6 @@ def where(self, other, cond) -> list[Block]:
12331234
result_blocks: list[Block] = []
12341235
for m in [mask, ~mask]:
12351236
if m.any():
1236-
result = cast(np.ndarray, result) # EABlock overrides where
12371237
taken = result.take(m.nonzero()[0], axis=axis)
12381238
r = maybe_downcast_numeric(taken, self.dtype)
12391239
nb = self.make_block(r.T, placement=self._mgr_locs[m])
@@ -1734,7 +1734,9 @@ def where(self, other, cond) -> list[Block]:
17341734
try:
17351735
res_values = arr.T._where(cond, other).T
17361736
except (ValueError, TypeError):
1737-
return Block.where(self, other, cond)
1737+
blk = self.coerce_to_target_dtype(other)
1738+
nbs = blk.where(other, cond)
1739+
return self._maybe_downcast(nbs, "infer")
17381740

17391741
nb = self.make_block_same_class(res_values)
17401742
return [nb]

0 commit comments

Comments
 (0)