@@ -869,6 +869,12 @@ def _replace_coerce(
869
869
870
870
# ---------------------------------------------------------------------
871
871
872
+ def _maybe_squeeze_arg (self , arg : np .ndarray ) -> np .ndarray :
873
+ """
874
+ For compatibility with 1D-only ExtensionArrays.
875
+ """
876
+ return arg
877
+
872
878
def setitem (self , indexer , value ):
873
879
"""
874
880
Attempt self.values[indexer] = value, possibly creating a new array.
@@ -1314,6 +1320,46 @@ class EABackedBlock(Block):
1314
1320
1315
1321
values : ExtensionArray
1316
1322
1323
+ def putmask (self , mask , new ) -> list [Block ]:
1324
+ """
1325
+ See Block.putmask.__doc__
1326
+ """
1327
+ mask = extract_bool_array (mask )
1328
+
1329
+ values = self .values
1330
+
1331
+ mask = self ._maybe_squeeze_arg (mask )
1332
+
1333
+ try :
1334
+ # Caller is responsible for ensuring matching lengths
1335
+ values ._putmask (mask , new )
1336
+ except (TypeError , ValueError ) as err :
1337
+ if isinstance (err , ValueError ) and "Timezones don't match" not in str (err ):
1338
+ # TODO(2.0): remove catching ValueError at all since
1339
+ # DTA raising here is deprecated
1340
+ raise
1341
+
1342
+ if is_interval_dtype (self .dtype ):
1343
+ # Discussion about what we want to support in the general
1344
+ # case GH#39584
1345
+ blk = self .coerce_to_target_dtype (new )
1346
+ if blk .dtype == _dtype_obj :
1347
+ # For now at least, only support casting e.g.
1348
+ # Interval[int64]->Interval[float64],
1349
+ raise
1350
+ return blk .putmask (mask , new )
1351
+
1352
+ elif isinstance (self , NDArrayBackedExtensionBlock ):
1353
+ # NB: not (yet) the same as
1354
+ # isinstance(values, NDArrayBackedExtensionArray)
1355
+ blk = self .coerce_to_target_dtype (new )
1356
+ return blk .putmask (mask , new )
1357
+
1358
+ else :
1359
+ raise
1360
+
1361
+ return [self ]
1362
+
1317
1363
def delete (self , loc ) -> None :
1318
1364
"""
1319
1365
Delete given loc(-s) from block in-place.
@@ -1410,36 +1456,16 @@ def set_inplace(self, locs, values) -> None:
1410
1456
# _cache not yet initialized
1411
1457
pass
1412
1458
1413
- def putmask (self , mask , new ) -> list [ Block ] :
1459
+ def _maybe_squeeze_arg (self , arg ) :
1414
1460
"""
1415
- See Block.putmask.__doc__
1461
+ If necessary, squeeze a (N, 1) ndarray to (N,)
1416
1462
"""
1417
- mask = extract_bool_array (mask )
1418
-
1419
- new_values = self .values
1420
-
1421
- if mask .ndim == new_values .ndim + 1 :
1463
+ # e.g. if we are passed a 2D mask for putmask
1464
+ if isinstance (arg , np .ndarray ) and arg .ndim == self .values .ndim + 1 :
1422
1465
# TODO(EA2D): unnecessary with 2D EAs
1423
- mask = mask .reshape (new_values .shape )
1424
-
1425
- try :
1426
- # Caller is responsible for ensuring matching lengths
1427
- new_values ._putmask (mask , new )
1428
- except TypeError :
1429
- if not is_interval_dtype (self .dtype ):
1430
- # Discussion about what we want to support in the general
1431
- # case GH#39584
1432
- raise
1433
-
1434
- blk = self .coerce_to_target_dtype (new )
1435
- if blk .dtype == _dtype_obj :
1436
- # For now at least, only support casting e.g.
1437
- # Interval[int64]->Interval[float64],
1438
- raise
1439
- return blk .putmask (mask , new )
1440
-
1441
- nb = type (self )(new_values , placement = self ._mgr_locs , ndim = self .ndim )
1442
- return [nb ]
1466
+ assert arg .shape [1 ] == 1
1467
+ arg = arg [:, 0 ]
1468
+ return arg
1443
1469
1444
1470
@property
1445
1471
def is_view (self ) -> bool :
@@ -1595,15 +1621,8 @@ def where(self, other, cond) -> list[Block]:
1595
1621
cond = extract_bool_array (cond )
1596
1622
assert not isinstance (other , (ABCIndex , ABCSeries , ABCDataFrame ))
1597
1623
1598
- if isinstance (other , np .ndarray ) and other .ndim == 2 :
1599
- # TODO(EA2D): unnecessary with 2D EAs
1600
- assert other .shape [1 ] == 1
1601
- other = other [:, 0 ]
1602
-
1603
- if isinstance (cond , np .ndarray ) and cond .ndim == 2 :
1604
- # TODO(EA2D): unnecessary with 2D EAs
1605
- assert cond .shape [1 ] == 1
1606
- cond = cond [:, 0 ]
1624
+ other = self ._maybe_squeeze_arg (other )
1625
+ cond = self ._maybe_squeeze_arg (cond )
1607
1626
1608
1627
if lib .is_scalar (other ) and isna (other ):
1609
1628
# The default `other` for Series / Frame is np.nan
@@ -1698,16 +1717,6 @@ def setitem(self, indexer, value):
1698
1717
values [indexer ] = value
1699
1718
return self
1700
1719
1701
- def putmask (self , mask , new ) -> list [Block ]:
1702
- mask = extract_bool_array (mask )
1703
-
1704
- if not self ._can_hold_element (new ):
1705
- return self .coerce_to_target_dtype (new ).putmask (mask , new )
1706
-
1707
- arr = self .values
1708
- arr .T ._putmask (mask , new )
1709
- return [self ]
1710
-
1711
1720
def where (self , other , cond ) -> list [Block ]:
1712
1721
arr = self .values
1713
1722
0 commit comments