Skip to content

Commit 82b69f7

Browse files
authored
ENH: implement _maybe_squeeze_arg (#44520)
1 parent c676844 commit 82b69f7

File tree

1 file changed

+55
-46
lines changed

1 file changed

+55
-46
lines changed

pandas/core/internals/blocks.py

+55-46
Original file line number | Diff line number | Diff line change
@@ -869,6 +869,12 @@ def _replace_coerce(
869869

870870
# ---------------------------------------------------------------------
871871

872+
def _maybe_squeeze_arg(self, arg: np.ndarray) -> np.ndarray:
873+
"""
874+
For compatibility with 1D-only ExtensionArrays.
875+
"""
876+
return arg
877+
872878
def setitem(self, indexer, value):
873879
"""
874880
Attempt self.values[indexer] = value, possibly creating a new array.
@@ -1314,6 +1320,46 @@ class EABackedBlock(Block):
13141320

13151321
values: ExtensionArray
13161322

1323+
def putmask(self, mask, new) -> list[Block]:
    """
    See Block.putmask.__doc__

    Writes ``new`` into ``self.values`` where ``mask`` is True by calling
    the backing array's ``_putmask``.  If that raises and the block is
    interval-dtype or an NDArrayBackedExtensionBlock, the block is coerced
    with ``coerce_to_target_dtype(new)`` and the operation is retried on
    the coerced block.

    Returns
    -------
    list[Block]
        ``[self]`` when the values were set in place, otherwise whatever
        the coerced block's ``putmask`` returns.

    Raises
    ------
    TypeError, ValueError
        Re-raised from ``_putmask`` when no fallback applies (including an
        interval block that could only be coerced to object dtype).
    """
    mask = extract_bool_array(mask)

    values = self.values
    if values.ndim == 2:
        # The NDArrayBackedExtensionBlock.putmask that this method replaces
        # applied the mask to ``arr.T``; keep that orientation so a 2D
        # backing array is not masked along the wrong axis.
        # NOTE(review): relies on ``.T`` being a view so the write still
        # lands in self.values -- same assumption the replaced method made.
        values = values.T

    mask = self._maybe_squeeze_arg(mask)

    try:
        # Caller is responsible for ensuring matching lengths
        values._putmask(mask, new)
    except (TypeError, ValueError) as err:
        if isinstance(err, ValueError) and "Timezones don't match" not in str(err):
            # TODO(2.0): remove catching ValueError at all since
            #  DTA raising here is deprecated
            raise

        if is_interval_dtype(self.dtype):
            # Discussion about what we want to support in the general
            #  case GH#39584
            blk = self.coerce_to_target_dtype(new)
            if blk.dtype == _dtype_obj:
                # For now at least, only support casting e.g.
                #  Interval[int64]->Interval[float64],
                raise
            return blk.putmask(mask, new)

        elif isinstance(self, NDArrayBackedExtensionBlock):
            # NB: not (yet) the same as
            #  isinstance(values, NDArrayBackedExtensionArray)
            blk = self.coerce_to_target_dtype(new)
            return blk.putmask(mask, new)

        else:
            raise

    return [self]
1362+
13171363
def delete(self, loc) -> None:
13181364
"""
13191365
Delete given loc(-s) from block in-place.
@@ -1410,36 +1456,16 @@ def set_inplace(self, locs, values) -> None:
14101456
# _cache not yet initialized
14111457
pass
14121458

1413-
def putmask(self, mask, new) -> list[Block]:
1459+
def _maybe_squeeze_arg(self, arg):
14141460
"""
1415-
See Block.putmask.__doc__
1461+
If necessary, squeeze a (N, 1) ndarray to (N,)
14161462
"""
1417-
mask = extract_bool_array(mask)
1418-
1419-
new_values = self.values
1420-
1421-
if mask.ndim == new_values.ndim + 1:
1463+
# e.g. if we are passed a 2D mask for putmask
1464+
if isinstance(arg, np.ndarray) and arg.ndim == self.values.ndim + 1:
14221465
# TODO(EA2D): unnecessary with 2D EAs
1423-
mask = mask.reshape(new_values.shape)
1424-
1425-
try:
1426-
# Caller is responsible for ensuring matching lengths
1427-
new_values._putmask(mask, new)
1428-
except TypeError:
1429-
if not is_interval_dtype(self.dtype):
1430-
# Discussion about what we want to support in the general
1431-
# case GH#39584
1432-
raise
1433-
1434-
blk = self.coerce_to_target_dtype(new)
1435-
if blk.dtype == _dtype_obj:
1436-
# For now at least, only support casting e.g.
1437-
# Interval[int64]->Interval[float64],
1438-
raise
1439-
return blk.putmask(mask, new)
1440-
1441-
nb = type(self)(new_values, placement=self._mgr_locs, ndim=self.ndim)
1442-
return [nb]
1466+
assert arg.shape[1] == 1
1467+
arg = arg[:, 0]
1468+
return arg
14431469

14441470
@property
14451471
def is_view(self) -> bool:
@@ -1595,15 +1621,8 @@ def where(self, other, cond) -> list[Block]:
15951621
cond = extract_bool_array(cond)
15961622
assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame))
15971623

1598-
if isinstance(other, np.ndarray) and other.ndim == 2:
1599-
# TODO(EA2D): unnecessary with 2D EAs
1600-
assert other.shape[1] == 1
1601-
other = other[:, 0]
1602-
1603-
if isinstance(cond, np.ndarray) and cond.ndim == 2:
1604-
# TODO(EA2D): unnecessary with 2D EAs
1605-
assert cond.shape[1] == 1
1606-
cond = cond[:, 0]
1624+
other = self._maybe_squeeze_arg(other)
1625+
cond = self._maybe_squeeze_arg(cond)
16071626

16081627
if lib.is_scalar(other) and isna(other):
16091628
# The default `other` for Series / Frame is np.nan
@@ -1698,16 +1717,6 @@ def setitem(self, indexer, value):
16981717
values[indexer] = value
16991718
return self
17001719

1701-
def putmask(self, mask, new) -> list[Block]:
1702-
mask = extract_bool_array(mask)
1703-
1704-
if not self._can_hold_element(new):
1705-
return self.coerce_to_target_dtype(new).putmask(mask, new)
1706-
1707-
arr = self.values
1708-
arr.T._putmask(mask, new)
1709-
return [self]
1710-
17111720
def where(self, other, cond) -> list[Block]:
17121721
arr = self.values
17131722

0 commit comments

Comments
 (0)