Skip to content

Commit 76d412f

Browse files
authored
BUG: frame.mask(foo, bar, inplace=True) with EAs incorrectly raising (#45577)
1 parent 96cf851 commit 76d412f

File tree

4 files changed

+19
-1
lines changed

4 files changed

+19
-1
lines changed

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ Indexing
264264
- Bug in :meth:`loc.__setitem__` treating ``range`` keys as positional instead of label-based (:issue:`45479`)
265265
- Bug in :meth:`Series.__setitem__` when setting ``boolean`` dtype values containing ``NA`` incorrectly raising instead of casting to ``boolean`` dtype (:issue:`45462`)
266266
- Bug in :meth:`Series.__setitem__` where setting :attr:`NA` into a numeric-dtpye :class:`Series` would incorrectly upcast to object-dtype rather than treating the value as ``np.nan`` (:issue:`44199`)
267+
- Bug in :meth:`DataFrame.mask` with ``inplace=True`` and ``ExtensionDtype`` columns incorrectly raising (:issue:`45577`)
267268
- Bug in getting a column from a DataFrame with an object-dtype row index with datetime-like values: the resulting Series now preserves the exact object-dtype Index from the parent DataFrame (:issue:`42950`)
268269
-
269270

pandas/core/internals/blocks.py

+1
Original file line numberDiff line numberDiff line change
@@ -1434,6 +1434,7 @@ def putmask(self, mask, new) -> list[Block]:
14341434

14351435
values = self.values
14361436

1437+
new = self._maybe_squeeze_arg(new)
14371438
mask = self._maybe_squeeze_arg(mask)
14381439

14391440
try:

pandas/tests/extension/base/methods.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,8 @@ def test_where_series(self, data, na_value, as_frame):
443443
cls = type(data)
444444
a, b = data[:2]
445445

446-
ser = pd.Series(cls._from_sequence([a, a, b, b], dtype=data.dtype))
446+
orig = pd.Series(cls._from_sequence([a, a, b, b], dtype=data.dtype))
447+
ser = orig.copy()
447448
cond = np.array([True, True, False, False])
448449

449450
if as_frame:
@@ -459,7 +460,13 @@ def test_where_series(self, data, na_value, as_frame):
459460
expected = expected.to_frame(name="a")
460461
self.assert_equal(result, expected)
461462

463+
ser.mask(~cond, inplace=True)
464+
self.assert_equal(ser, expected)
465+
462466
# array other
467+
ser = orig.copy()
468+
if as_frame:
469+
ser = ser.to_frame(name="a")
463470
cond = np.array([True, False, True, True])
464471
other = cls._from_sequence([a, b, a, b], dtype=data.dtype)
465472
if as_frame:
@@ -471,6 +478,9 @@ def test_where_series(self, data, na_value, as_frame):
471478
expected = expected.to_frame(name="a")
472479
self.assert_equal(result, expected)
473480

481+
ser.mask(~cond, other, inplace=True)
482+
self.assert_equal(ser, expected)
483+
474484
@pytest.mark.parametrize("repeats", [0, 1, 2, [1, 2, 3]])
475485
def test_repeat(self, data, repeats, as_series, use_numpy):
476486
arr = type(data)._from_sequence(data[:3], dtype=data.dtype)

pandas/tests/frame/indexing/test_where.py

+6
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,12 @@ def test_where_string_dtype(frame_or_series):
812812
)
813813
tm.assert_equal(result, expected)
814814

815+
result = obj.mask(~filter_ser, filtered_obj)
816+
tm.assert_equal(result, expected)
817+
818+
obj.mask(~filter_ser, filtered_obj, inplace=True)
819+
tm.assert_equal(result, expected)
820+
815821

816822
def test_where_bool_comparison():
817823
# GH 10336

0 commit comments

Comments
 (0)