@@ -1291,7 +1291,7 @@ def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> List[Blo
1291
1291
1292
1292
return [self .make_block (new_values )]
1293
1293
1294
- def where (self , other , cond , errors = "raise" , axis : int = 0 ) -> List [Block ]:
1294
+ def where (self , other , cond , errors = "raise" ) -> List [Block ]:
1295
1295
"""
1296
1296
evaluate the block; return result block(s) from the result
1297
1297
@@ -1302,14 +1302,14 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
1302
1302
errors : str, {'raise', 'ignore'}, default 'raise'
1303
1303
- ``raise`` : allow exceptions to be raised
1304
1304
- ``ignore`` : suppress exceptions. On error return original object
1305
- axis : int, default 0
1306
1305
1307
1306
Returns
1308
1307
-------
1309
1308
List[Block]
1310
1309
"""
1311
1310
import pandas .core .computation .expressions as expressions
1312
1311
1312
+ assert cond .ndim == self .ndim
1313
1313
assert not isinstance (other , (ABCIndex , ABCSeries , ABCDataFrame ))
1314
1314
1315
1315
assert errors in ["raise" , "ignore" ]
@@ -1322,7 +1322,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
1322
1322
1323
1323
icond , noop = validate_putmask (values , ~ cond )
1324
1324
1325
- if is_valid_na_for_dtype (other , self .dtype ) and not self .is_object :
1325
+ if is_valid_na_for_dtype (other , self .dtype ) and self .dtype != _dtype_obj :
1326
1326
other = self .fill_value
1327
1327
1328
1328
if noop :
@@ -1335,7 +1335,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
1335
1335
# we cannot coerce, return a compat dtype
1336
1336
# we are explicitly ignoring errors
1337
1337
block = self .coerce_to_target_dtype (other )
1338
- blocks = block .where (orig_other , cond , errors = errors , axis = axis )
1338
+ blocks = block .where (orig_other , cond , errors = errors )
1339
1339
return self ._maybe_downcast (blocks , "infer" )
1340
1340
1341
1341
# error: Argument 1 to "setitem_datetimelike_compat" has incompatible type
@@ -1364,7 +1364,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
1364
1364
cond = ~ icond
1365
1365
axis = cond .ndim - 1
1366
1366
cond = cond .swapaxes (axis , 0 )
1367
- mask = np . array ([ cond [ i ] .all () for i in range ( cond . shape [ 0 ])], dtype = bool )
1367
+ mask = cond .all (axis = 1 )
1368
1368
1369
1369
result_blocks : List [Block ] = []
1370
1370
for m in [mask , ~ mask ]:
@@ -1670,7 +1670,7 @@ def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> List[Blo
1670
1670
new_values = self .values .shift (periods = periods , fill_value = fill_value )
1671
1671
return [self .make_block_same_class (new_values )]
1672
1672
1673
- def where (self , other , cond , errors = "raise" , axis : int = 0 ) -> List [Block ]:
1673
+ def where (self , other , cond , errors = "raise" ) -> List [Block ]:
1674
1674
1675
1675
cond = extract_bool_array (cond )
1676
1676
assert not isinstance (other , (ABCIndex , ABCSeries , ABCDataFrame ))
@@ -1828,7 +1828,7 @@ def putmask(self, mask, new) -> List[Block]:
1828
1828
arr .T .putmask (mask , new )
1829
1829
return [self ]
1830
1830
1831
- def where (self , other , cond , errors = "raise" , axis : int = 0 ) -> List [Block ]:
1831
+ def where (self , other , cond , errors = "raise" ) -> List [Block ]:
1832
1832
# TODO(EA2D): reshape unnecessary with 2D EAs
1833
1833
arr = self .array_values ().reshape (self .shape )
1834
1834
@@ -1837,7 +1837,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
1837
1837
try :
1838
1838
res_values = arr .T .where (cond , other ).T
1839
1839
except (ValueError , TypeError ):
1840
- return super ().where (other , cond , errors = errors , axis = axis )
1840
+ return super ().where (other , cond , errors = errors )
1841
1841
1842
1842
# TODO(EA2D): reshape not needed with 2D EAs
1843
1843
res_values = res_values .reshape (self .values .shape )
0 commit comments