diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 2b0b62ab7facf..8ebb5437978ea 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -830,7 +830,7 @@ ExtensionArray - Bug in :meth:`DataFrame.where` when ``other`` is a :class:`Series` with :class:`ExtensionArray` dtype (:issue:`38729`) - Fixed bug where :meth:`Series.idxmax`, :meth:`Series.idxmin` and ``argmax/min`` fail when the underlying data is :class:`ExtensionArray` (:issue:`32749`, :issue:`33719`, :issue:`36566`) - Fixed a bug where some properties of subclasses of :class:`PandasExtensionDtype` where improperly cached (:issue:`40329`) -- +- Bug in :meth:`DataFrame.mask` where masking a :class:`Dataframe` with an :class:`ExtensionArray` dtype raises ``ValueError`` (:issue:`40941`) Styler ^^^^^^ diff --git a/pandas/core/generic.py b/pandas/core/generic.py index cbc353eead464..bad42a85aeeee 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -8958,7 +8958,7 @@ def _where( join="left", axis=axis, level=level, - fill_value=np.nan, + fill_value=None, copy=False, ) diff --git a/pandas/tests/frame/indexing/test_mask.py b/pandas/tests/frame/indexing/test_mask.py index afa8c757c23e4..364475428e529 100644 --- a/pandas/tests/frame/indexing/test_mask.py +++ b/pandas/tests/frame/indexing/test_mask.py @@ -5,7 +5,10 @@ import numpy as np from pandas import ( + NA, DataFrame, + Series, + StringDtype, isna, ) import pandas._testing as tm @@ -99,3 +102,24 @@ def test_mask_try_cast_deprecated(frame_or_series): with tm.assert_produces_warning(FutureWarning): # try_cast keyword deprecated obj.mask(mask, -1, try_cast=True) + + +def test_mask_stringdtype(): + # GH 40824 + df = DataFrame( + {"A": ["foo", "bar", "baz", NA]}, + index=["id1", "id2", "id3", "id4"], + dtype=StringDtype(), + ) + filtered_df = DataFrame( + {"A": ["this", "that"]}, index=["id2", "id3"], dtype=StringDtype() + ) + filter_ser = Series([False, True, True, False]) + result = df.mask(filter_ser, filtered_df) + + expected = DataFrame( + {"A": [NA, "this", "that", NA]}, + index=["id1", "id2", "id3", "id4"], + dtype=StringDtype(), + ) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py index 574fa46d10f67..7ffe2fb9ab1ff 100644 --- a/pandas/tests/frame/indexing/test_where.py +++ b/pandas/tests/frame/indexing/test_where.py @@ -10,6 +10,7 @@ DataFrame, DatetimeIndex, Series, + StringDtype, Timestamp, date_range, isna, @@ -709,3 +710,22 @@ def test_where_copies_with_noop(frame_or_series): where_res *= 2 tm.assert_equal(result, expected) + + +def test_where_string_dtype(frame_or_series): + # GH40824 + obj = frame_or_series( + ["a", "b", "c", "d"], index=["id1", "id2", "id3", "id4"], dtype=StringDtype() + ) + filtered_obj = frame_or_series( + ["b", "c"], index=["id2", "id3"], dtype=StringDtype() + ) + filter_ser = Series([False, True, True, False]) + + result = obj.where(filter_ser, filtered_obj) + expected = frame_or_series( + [pd.NA, "b", "c", pd.NA], + index=["id1", "id2", "id3", "id4"], + dtype=StringDtype(), + ) + tm.assert_equal(result, expected) diff --git a/pandas/tests/series/indexing/test_mask.py b/pandas/tests/series/indexing/test_mask.py index dc4fb530dbb52..a4dda3a5c0c5b 100644 --- a/pandas/tests/series/indexing/test_mask.py +++ b/pandas/tests/series/indexing/test_mask.py @@ -1,7 +1,11 @@ import numpy as np import pytest -from pandas import Series +from pandas import ( + NA, + Series, + StringDtype, +) import pandas._testing as tm @@ -63,3 +67,22 @@ def test_mask_inplace(): rs = s.copy() rs.mask(cond, -s, inplace=True) tm.assert_series_equal(rs, s.mask(cond, -s)) + + +def test_mask_stringdtype(): + # GH 40824 + ser = Series( + ["foo", "bar", "baz", NA], + index=["id1", "id2", "id3", "id4"], + dtype=StringDtype(), + ) + filtered_ser = Series(["this", "that"], index=["id2", "id3"], dtype=StringDtype()) + filter_ser = Series([False, True, True, False]) + result = ser.mask(filter_ser, filtered_ser) + + expected = Series( + [NA, "this", "that", NA], + index=["id1", "id2", "id3", "id4"], + dtype=StringDtype(), + ) + tm.assert_series_equal(result, expected)