Skip to content

Commit 95f5dd2

Browse files
ganevgvproost
authored andcommitted
TST: add test for df.where() with category dtype (pandas-dev#29454)
xref pandas-dev#16979
1 parent bcf9674 commit 95f5dd2

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

pandas/tests/frame/test_dtypes.py

+16
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,22 @@ def test_astype_extension_dtypes_duplicate_col(self, dtype):
815815
expected = concat([a1.astype(dtype), a2.astype(dtype)], axis=1)
816816
tm.assert_frame_equal(result, expected)
817817

818+
@pytest.mark.parametrize("kwargs", [dict(), dict(other=None)])
819+
def test_df_where_with_category(self, kwargs):
820+
# GH 16979
821+
df = DataFrame(np.arange(2 * 3).reshape(2, 3), columns=list("ABC"))
822+
mask = np.array([[True, False, True], [False, True, True]])
823+
824+
# change type to category
825+
df.A = df.A.astype("category")
826+
df.B = df.B.astype("category")
827+
df.C = df.C.astype("category")
828+
829+
result = df.A.where(mask[:, 0], **kwargs)
830+
expected = Series(pd.Categorical([0, np.nan], categories=[0, 3]), name="A")
831+
832+
tm.assert_series_equal(result, expected)
833+
818834
@pytest.mark.parametrize(
819835
"dtype", [{100: "float64", 200: "uint64"}, "category", "float64"]
820836
)

0 commit comments

Comments
 (0)