Skip to content

Commit 4aa160e

Browse files
test_get_dummies_with_str_dtype
1 parent fc39d86 commit 4aa160e

File tree

2 files changed

+76
-27
lines changed

2 files changed

+76
-27
lines changed

pandas/conftest.py

+4
Original file line numberDiff line numberDiff line change
@@ -1435,6 +1435,10 @@ def any_string_dtype(request):
14351435
return pd.StringDtype(storage, na_value)
14361436

14371437

1438+
# Generate cartesian product of any_string_dtype:
1439+
any_string_dtype2 = any_string_dtype
1440+
1441+
14381442
@pytest.fixture(params=tm.DATETIME64_DTYPES)
14391443
def datetime64_dtype(request):
14401444
"""

pandas/tests/strings/test_get_dummies.py

+72-27
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import numpy as np
22
import pytest
33

4-
from pandas._config import using_string_dtype
5-
64
import pandas.util._test_decorators as td
75

86
from pandas import (
@@ -98,30 +96,77 @@ def test_get_dummies_with_pyarrow_dtype(any_string_dtype, dtype):
9896

9997

10098
# GH#47872
101-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
102-
def test_get_dummies_with_str_dtype(any_string_dtype):
99+
@pytest.mark.parametrize("use_string_repr", [True, False])
100+
def test_get_dummies_with_any_string_dtype(
101+
request, any_string_dtype, any_string_dtype2, use_string_repr, using_infer_string
102+
):
103103
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
104-
result = s.str.get_dummies("|", dtype=str)
105-
expected = DataFrame(
106-
[["T", "T", "F"], ["T", "F", "T"], ["F", "F", "F"]],
107-
columns=list("abc"),
108-
dtype=str,
109-
)
110-
tm.assert_frame_equal(result, expected)
111-
112-
113-
# GH#47872
114-
@td.skip_if_no("pyarrow")
115-
def test_get_dummies_with_pa_str_dtype(any_string_dtype):
116-
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
117-
result = s.str.get_dummies("|", dtype="str[pyarrow]")
118-
expected = DataFrame(
119-
[
120-
["true", "true", "false"],
121-
["true", "false", "true"],
122-
["false", "false", "false"],
123-
],
124-
columns=list("abc"),
125-
dtype="str[pyarrow]",
126-
)
104+
test_ids = request.node.callspec.id.split("-")
105+
series_dtype_id = test_ids[0][7:]
106+
expected_dtype_id = test_ids[1][7:]
107+
if expected_dtype_id == "object":
108+
if "pyarrow" in series_dtype_id:
109+
request.applymarker(
110+
pytest.mark.xfail(
111+
reason=("pyarrow.lib.ArrowTypeError: Expected integer, got bool"),
112+
strict=True,
113+
)
114+
)
115+
expected = DataFrame(
116+
[
117+
[True, True, False],
118+
[True, False, True],
119+
[False, False, False],
120+
],
121+
columns=list("abc"),
122+
dtype=np.bool_,
123+
)
124+
elif expected_dtype_id == "str[pyarrow]" and use_string_repr:
125+
# data type 'str[pyarrow]' uses pandas.ArrowDtype instead
126+
expected = DataFrame(
127+
[
128+
["true", "true", "false"],
129+
["true", "false", "true"],
130+
["false", "false", "false"],
131+
],
132+
columns=list("abc"),
133+
dtype="str[pyarrow]",
134+
)
135+
elif expected_dtype_id == "str[python]" and use_string_repr:
136+
# data type 'str[python]' not understood"
137+
expected_dtype_id = str
138+
if using_infer_string:
139+
expected = DataFrame(
140+
[
141+
["True", "True", "False"],
142+
["True", "False", "True"],
143+
["False", "False", "False"],
144+
],
145+
columns=list("abc"),
146+
dtype=expected_dtype_id,
147+
)
148+
else:
149+
expected = DataFrame(
150+
[
151+
["T", "T", "F"],
152+
["T", "F", "T"],
153+
["F", "F", "F"],
154+
],
155+
columns=list("abc"),
156+
dtype=expected_dtype_id,
157+
)
158+
else:
159+
expected = DataFrame(
160+
[
161+
["True", "True", "False"],
162+
["True", "False", "True"],
163+
["False", "False", "False"],
164+
],
165+
columns=list("abc"),
166+
dtype=any_string_dtype2,
167+
)
168+
if use_string_repr:
169+
result = s.str.get_dummies("|", dtype=expected_dtype_id)
170+
else:
171+
result = s.str.get_dummies("|", dtype=any_string_dtype2)
127172
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)