diff --git a/pandas/core/array_algos/putmask.py b/pandas/core/array_algos/putmask.py index b17e86e774f60..54324bf721945 100644 --- a/pandas/core/array_algos/putmask.py +++ b/pandas/core/array_algos/putmask.py @@ -126,8 +126,7 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd if values.dtype.kind == new.dtype.kind: # preserves dtype if possible - np.putmask(values, mask, new) - return values + return _putmask_preserve(values, new, mask) dtype = find_common_type([values.dtype, new.dtype]) # error: Argument 1 to "astype" of "_ArrayOrScalarCommon" has incompatible type @@ -136,8 +135,51 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd # List[Any], _DTypeDict, Tuple[Any, Any]]]" values = values.astype(dtype) # type: ignore[arg-type] - np.putmask(values, mask, new) - return values + return _putmask_preserve(values, new, mask) + + +def _putmask_preserve(new_values: np.ndarray, new, mask: npt.NDArray[np.bool_]): + try: + new_values[mask] = new[mask] + except (IndexError, ValueError): + new_values[mask] = new + return new_values + + +def putmask_without_repeat( + values: np.ndarray, mask: npt.NDArray[np.bool_], new: Any +) -> None: + """ + np.putmask will truncate or repeat if `new` is a listlike with + len(new) != len(values). We require an exact match. + + Parameters + ---------- + values : np.ndarray + mask : np.ndarray[bool] + new : Any + """ + if getattr(new, "ndim", 0) >= 1: + new = new.astype(values.dtype, copy=False) + + # TODO: this prob needs some better checking for 2D cases + nlocs = mask.sum() + if nlocs > 0 and is_list_like(new) and getattr(new, "ndim", 1) == 1: + if nlocs == len(new): + # GH#30567 + # If length of ``new`` is less than the length of ``values``, + # `np.putmask` would first repeat the ``new`` array and then + # assign the masked values hence produces incorrect result. + # `np.place` on the other hand uses the ``new`` values at it is + # to place in the masked locations of ``values`` + np.place(values, mask, new) + # i.e. values[mask] = new + elif mask.shape[-1] == len(new) or len(new) == 1: + np.putmask(values, mask, new) + else: + raise ValueError("cannot assign mismatch length to masked array") + else: + np.putmask(values, mask, new) def validate_putmask( diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index a452eabd4ea6f..2589015e0f0b1 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -51,7 +51,6 @@ is_extension_array_dtype, is_interval_dtype, is_list_like, - is_object_dtype, is_string_dtype, ) from pandas.core.dtypes.dtypes import ( @@ -77,6 +76,7 @@ extract_bool_array, putmask_inplace, putmask_smart, + putmask_without_repeat, setitem_datetimelike_compat, validate_putmask, ) @@ -960,7 +960,10 @@ def putmask(self, mask, new) -> list[Block]: new = self.fill_value if self._can_hold_element(new): - np.putmask(self.values.T, mask, new) + + # error: Argument 1 to "putmask_without_repeat" has incompatible type + # "Union[ndarray, ExtensionArray]"; expected "ndarray" + putmask_without_repeat(self.values.T, mask, new) # type: ignore[arg-type] return [self] elif noop: @@ -1412,16 +1415,15 @@ def putmask(self, mask, new) -> list[Block]: new_values = self.values + if isinstance(new, (np.ndarray, ExtensionArray)) and len(new) == len(mask): + new = new[mask] + if mask.ndim == new_values.ndim + 1: # TODO(EA2D): unnecessary with 2D EAs mask = mask.reshape(new_values.shape) try: - if isinstance(new, (np.ndarray, ExtensionArray)): - # Caller is responsible for ensuring matching lengths - new_values[mask] = new[mask] - else: - new_values[mask] = new + new_values[mask] = new except TypeError: if not is_interval_dtype(self.dtype): # Discussion about what we want to support in the general @@ -1479,14 +1481,7 @@ def setitem(self, indexer, value): # we are always 1-D indexer = indexer[0] - try: - check_setitem_lengths(indexer, value, self.values) - except ValueError: - # If we are object dtype (e.g. PandasDtype[object]) then - # we can hold nested data, so can ignore this mismatch. - if not is_object_dtype(self.dtype): - raise - + check_setitem_lengths(indexer, value, self.values) self.values[indexer] = value return self diff --git a/pandas/core/series.py b/pandas/core/series.py index 579c16613ec2e..7ee9a0bcdd9e1 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -1101,6 +1101,7 @@ def __setitem__(self, key, value) -> None: is_list_like(value) and len(value) != len(self) and not isinstance(value, Series) + and not is_object_dtype(self.dtype) ): # Series will be reindexed to have matching length inside # _where call below diff --git a/pandas/tests/extension/test_numpy.py b/pandas/tests/extension/test_numpy.py index df424e649fbe9..e60f7769270bd 100644 --- a/pandas/tests/extension/test_numpy.py +++ b/pandas/tests/extension/test_numpy.py @@ -363,11 +363,6 @@ def test_concat(self, data, in_frame): class TestSetitem(BaseNumPyTests, base.BaseSetitemTests): - @skip_nested - def test_setitem_sequence_mismatched_length_raises(self, data, as_array): - # doesn't raise bc object dtype holds nested data - super().test_setitem_sequence_mismatched_length_raises(data, as_array) - @skip_nested def test_setitem_invalid(self, data, invalid_scalar): # object dtype can hold anything, so doesn't raise