@@ -1139,17 +1139,11 @@ def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> list[Blo
1139
1139
# convert integer to float if necessary. need to do a lot more than
1140
1140
# that, handle boolean etc also
1141
1141
1142
- # error: Value of type variable "NumpyArrayT" of "maybe_upcast" cannot be
1143
- # "Union[ndarray[Any, Any], ExtensionArray]"
1144
- new_values , fill_value = maybe_upcast (
1145
- self .values , fill_value # type: ignore[type-var]
1146
- )
1142
+ values = cast (np .ndarray , self .values )
1147
1143
1148
- # error: Argument 1 to "shift" has incompatible type "Union[ndarray[Any, Any],
1149
- # ExtensionArray]"; expected "ndarray[Any, Any]"
1150
- new_values = shift (
1151
- new_values , periods , axis , fill_value # type: ignore[arg-type]
1152
- )
1144
+ new_values , fill_value = maybe_upcast (values , fill_value )
1145
+
1146
+ new_values = shift (new_values , periods , axis , fill_value )
1153
1147
1154
1148
return [self .make_block (new_values )]
1155
1149
@@ -1171,7 +1165,8 @@ def where(self, other, cond) -> list[Block]:
1171
1165
1172
1166
transpose = self .ndim == 2
1173
1167
1174
- values = self .values
1168
+ # EABlocks override where
1169
+ values = cast (np .ndarray , self .values )
1175
1170
orig_other = other
1176
1171
if transpose :
1177
1172
values = values .T
@@ -1185,22 +1180,15 @@ def where(self, other, cond) -> list[Block]:
1185
1180
# TODO: avoid the downcasting at the end in this case?
1186
1181
# GH-39595: Always return a copy
1187
1182
result = values .copy ()
1183
+
1184
+ elif not self ._can_hold_element (other ):
1185
+ # we cannot coerce, return a compat dtype
1186
+ block = self .coerce_to_target_dtype (other )
1187
+ blocks = block .where (orig_other , cond )
1188
+ return self ._maybe_downcast (blocks , "infer" )
1189
+
1188
1190
else :
1189
- # see if we can operate on the entire block, or need item-by-item
1190
- # or if we are a single block (ndim == 1)
1191
- if not self ._can_hold_element (other ):
1192
- # we cannot coerce, return a compat dtype
1193
- block = self .coerce_to_target_dtype (other )
1194
- blocks = block .where (orig_other , cond )
1195
- return self ._maybe_downcast (blocks , "infer" )
1196
-
1197
- # error: Argument 1 to "setitem_datetimelike_compat" has incompatible type
1198
- # "Union[ndarray, ExtensionArray]"; expected "ndarray"
1199
- # error: Argument 2 to "setitem_datetimelike_compat" has incompatible type
1200
- # "number[Any]"; expected "int"
1201
- alt = setitem_datetimelike_compat (
1202
- values , icond .sum (), other # type: ignore[arg-type]
1203
- )
1191
+ alt = setitem_datetimelike_compat (values , icond .sum (), other )
1204
1192
if alt is not other :
1205
1193
if is_list_like (other ) and len (other ) < len (values ):
1206
1194
# call np.where with other to get the appropriate ValueError
@@ -1215,6 +1203,19 @@ def where(self, other, cond) -> list[Block]:
1215
1203
else :
1216
1204
# By the time we get here, we should have all Series/Index
1217
1205
# args extracted to ndarray
1206
+ if (
1207
+ is_list_like (other )
1208
+ and not isinstance (other , np .ndarray )
1209
+ and len (other ) == self .shape [- 1 ]
1210
+ ):
1211
+ # If we don't do this broadcasting here, then expressions.where
1212
+ # will broadcast a 1D other to be row-like instead of
1213
+ # column-like.
1214
+ other = np .array (other ).reshape (values .shape )
1215
+ # If lengths don't match (or len(other)==1), we will raise
1216
+ # inside expressions.where, see test_series_where
1217
+
1218
+ # Note: expressions.where may upcast.
1218
1219
result = expressions .where (~ icond , values , other )
1219
1220
1220
1221
if self ._can_hold_na or self .ndim == 1 :
@@ -1233,7 +1234,6 @@ def where(self, other, cond) -> list[Block]:
1233
1234
result_blocks : list [Block ] = []
1234
1235
for m in [mask , ~ mask ]:
1235
1236
if m .any ():
1236
- result = cast (np .ndarray , result ) # EABlock overrides where
1237
1237
taken = result .take (m .nonzero ()[0 ], axis = axis )
1238
1238
r = maybe_downcast_numeric (taken , self .dtype )
1239
1239
nb = self .make_block (r .T , placement = self ._mgr_locs [m ])
@@ -1734,7 +1734,9 @@ def where(self, other, cond) -> list[Block]:
1734
1734
try :
1735
1735
res_values = arr .T ._where (cond , other ).T
1736
1736
except (ValueError , TypeError ):
1737
- return Block .where (self , other , cond )
1737
+ blk = self .coerce_to_target_dtype (other )
1738
+ nbs = blk .where (other , cond )
1739
+ return self ._maybe_downcast (nbs , "infer" )
1738
1740
1739
1741
nb = self .make_block_same_class (res_values )
1740
1742
return [nb ]
0 commit comments