diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index ca81d54a0fb86..a8e7224eb524f 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -869,6 +869,12 @@ def _replace_coerce( # --------------------------------------------------------------------- + def _maybe_squeeze_arg(self, arg: np.ndarray) -> np.ndarray: + """ + For compatibility with 1D-only ExtensionArrays. + """ + return arg + def setitem(self, indexer, value): """ Attempt self.values[indexer] = value, possibly creating a new array. @@ -1314,6 +1320,46 @@ class EABackedBlock(Block): values: ExtensionArray + def putmask(self, mask, new) -> list[Block]: + """ + See Block.putmask.__doc__ + """ + mask = extract_bool_array(mask) + + values = self.values + + mask = self._maybe_squeeze_arg(mask) + + try: + # Caller is responsible for ensuring matching lengths + values._putmask(mask, new) + except (TypeError, ValueError) as err: + if isinstance(err, ValueError) and "Timezones don't match" not in str(err): + # TODO(2.0): remove catching ValueError at all since + # DTA raising here is deprecated + raise + + if is_interval_dtype(self.dtype): + # Discussion about what we want to support in the general + # case GH#39584 + blk = self.coerce_to_target_dtype(new) + if blk.dtype == _dtype_obj: + # For now at least, only support casting e.g. + # Interval[int64]->Interval[float64], + raise + return blk.putmask(mask, new) + + elif isinstance(self, NDArrayBackedExtensionBlock): + # NB: not (yet) the same as + # isinstance(values, NDArrayBackedExtensionArray) + blk = self.coerce_to_target_dtype(new) + return blk.putmask(mask, new) + + else: + raise + + return [self] + def delete(self, loc) -> None: """ Delete given loc(-s) from block in-place. @@ -1410,36 +1456,16 @@ def set_inplace(self, locs, values) -> None: # _cache not yet initialized pass - def putmask(self, mask, new) -> list[Block]: + def _maybe_squeeze_arg(self, arg): """ - See Block.putmask.__doc__ + If necessary, squeeze a (N, 1) ndarray to (N,) """ - mask = extract_bool_array(mask) - - new_values = self.values - - if mask.ndim == new_values.ndim + 1: + # e.g. if we are passed a 2D mask for putmask + if isinstance(arg, np.ndarray) and arg.ndim == self.values.ndim + 1: # TODO(EA2D): unnecessary with 2D EAs - mask = mask.reshape(new_values.shape) - - try: - # Caller is responsible for ensuring matching lengths - new_values._putmask(mask, new) - except TypeError: - if not is_interval_dtype(self.dtype): - # Discussion about what we want to support in the general - # case GH#39584 - raise - - blk = self.coerce_to_target_dtype(new) - if blk.dtype == _dtype_obj: - # For now at least, only support casting e.g. - # Interval[int64]->Interval[float64], - raise - return blk.putmask(mask, new) - - nb = type(self)(new_values, placement=self._mgr_locs, ndim=self.ndim) - return [nb] + assert arg.shape[1] == 1 + arg = arg[:, 0] + return arg @property def is_view(self) -> bool: @@ -1595,15 +1621,8 @@ def where(self, other, cond) -> list[Block]: cond = extract_bool_array(cond) assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame)) - if isinstance(other, np.ndarray) and other.ndim == 2: - # TODO(EA2D): unnecessary with 2D EAs - assert other.shape[1] == 1 - other = other[:, 0] - - if isinstance(cond, np.ndarray) and cond.ndim == 2: - # TODO(EA2D): unnecessary with 2D EAs - assert cond.shape[1] == 1 - cond = cond[:, 0] + other = self._maybe_squeeze_arg(other) + cond = self._maybe_squeeze_arg(cond) if lib.is_scalar(other) and isna(other): # The default `other` for Series / Frame is np.nan @@ -1698,16 +1717,6 @@ def setitem(self, indexer, value): values[indexer] = value return self - def putmask(self, mask, new) -> list[Block]: - mask = extract_bool_array(mask) - - if not self._can_hold_element(new): - return self.coerce_to_target_dtype(new).putmask(mask, new) - - arr = self.values - arr.T._putmask(mask, new) - return [self] - def where(self, other, cond) -> list[Block]: arr = self.values