Skip to content

Commit 344d52a

Browse files
DriesSchaumontyeshsurya
authored andcommitted
BUG: Dataframe mask method does not work properly with pd.StringDtype() (pandas-dev#40941)
1 parent b6cfeef commit 344d52a

File tree

5 files changed

+70
-3
lines changed

5 files changed

+70
-3
lines changed

doc/source/whatsnew/v1.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,7 @@ ExtensionArray
834834
- Bug in :meth:`DataFrame.where` when ``other`` is a :class:`Series` with :class:`ExtensionArray` dtype (:issue:`38729`)
835835
- Fixed bug where :meth:`Series.idxmax`, :meth:`Series.idxmin` and ``argmax/min`` fail when the underlying data is :class:`ExtensionArray` (:issue:`32749`, :issue:`33719`, :issue:`36566`)
836836
- Fixed a bug where some properties of subclasses of :class:`PandasExtensionDtype` where improperly cached (:issue:`40329`)
837-
-
837+
- Bug in :meth:`DataFrame.mask` where masking a :class:`Dataframe` with an :class:`ExtensionArray` dtype raises ``ValueError`` (:issue:`40941`)
838838

839839
Styler
840840
^^^^^^

pandas/core/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8958,7 +8958,7 @@ def _where(
89588958
join="left",
89598959
axis=axis,
89608960
level=level,
8961-
fill_value=np.nan,
8961+
fill_value=None,
89628962
copy=False,
89638963
)
89648964

pandas/tests/frame/indexing/test_mask.py

+24
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import numpy as np
66

77
from pandas import (
8+
NA,
89
DataFrame,
10+
Series,
11+
StringDtype,
912
isna,
1013
)
1114
import pandas._testing as tm
@@ -99,3 +102,24 @@ def test_mask_try_cast_deprecated(frame_or_series):
99102
with tm.assert_produces_warning(FutureWarning):
100103
# try_cast keyword deprecated
101104
obj.mask(mask, -1, try_cast=True)
105+
106+
107+
def test_mask_stringdtype():
108+
# GH 40824
109+
df = DataFrame(
110+
{"A": ["foo", "bar", "baz", NA]},
111+
index=["id1", "id2", "id3", "id4"],
112+
dtype=StringDtype(),
113+
)
114+
filtered_df = DataFrame(
115+
{"A": ["this", "that"]}, index=["id2", "id3"], dtype=StringDtype()
116+
)
117+
filter_ser = Series([False, True, True, False])
118+
result = df.mask(filter_ser, filtered_df)
119+
120+
expected = DataFrame(
121+
{"A": [NA, "this", "that", NA]},
122+
index=["id1", "id2", "id3", "id4"],
123+
dtype=StringDtype(),
124+
)
125+
tm.assert_frame_equal(result, expected)

pandas/tests/frame/indexing/test_where.py

+20
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
DataFrame,
1111
DatetimeIndex,
1212
Series,
13+
StringDtype,
1314
Timestamp,
1415
date_range,
1516
isna,
@@ -709,3 +710,22 @@ def test_where_copies_with_noop(frame_or_series):
709710
where_res *= 2
710711

711712
tm.assert_equal(result, expected)
713+
714+
715+
def test_where_string_dtype(frame_or_series):
716+
# GH40824
717+
obj = frame_or_series(
718+
["a", "b", "c", "d"], index=["id1", "id2", "id3", "id4"], dtype=StringDtype()
719+
)
720+
filtered_obj = frame_or_series(
721+
["b", "c"], index=["id2", "id3"], dtype=StringDtype()
722+
)
723+
filter_ser = Series([False, True, True, False])
724+
725+
result = obj.where(filter_ser, filtered_obj)
726+
expected = frame_or_series(
727+
[pd.NA, "b", "c", pd.NA],
728+
index=["id1", "id2", "id3", "id4"],
729+
dtype=StringDtype(),
730+
)
731+
tm.assert_equal(result, expected)

pandas/tests/series/indexing/test_mask.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import numpy as np
22
import pytest
33

4-
from pandas import Series
4+
from pandas import (
5+
NA,
6+
Series,
7+
StringDtype,
8+
)
59
import pandas._testing as tm
610

711

@@ -63,3 +67,22 @@ def test_mask_inplace():
6367
rs = s.copy()
6468
rs.mask(cond, -s, inplace=True)
6569
tm.assert_series_equal(rs, s.mask(cond, -s))
70+
71+
72+
def test_mask_stringdtype():
73+
# GH 40824
74+
ser = Series(
75+
["foo", "bar", "baz", NA],
76+
index=["id1", "id2", "id3", "id4"],
77+
dtype=StringDtype(),
78+
)
79+
filtered_ser = Series(["this", "that"], index=["id2", "id3"], dtype=StringDtype())
80+
filter_ser = Series([False, True, True, False])
81+
result = ser.mask(filter_ser, filtered_ser)
82+
83+
expected = Series(
84+
[NA, "this", "that", NA],
85+
index=["id1", "id2", "id3", "id4"],
86+
dtype=StringDtype(),
87+
)
88+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)