|
4 | 4 | import numpy as np
|
5 | 5 | import pytest
|
6 | 6 |
|
| 7 | +import pandas.util._test_decorators as td |
| 8 | + |
7 | 9 | from pandas.core.dtypes.common import is_integer_dtype
|
8 | 10 |
|
9 | 11 | import pandas as pd
|
10 | 12 | from pandas import (
|
| 13 | + ArrowDtype, |
11 | 14 | Categorical,
|
| 15 | + CategoricalDtype, |
12 | 16 | CategoricalIndex,
|
13 | 17 | DataFrame,
|
| 18 | + Index, |
14 | 19 | RangeIndex,
|
15 | 20 | Series,
|
16 | 21 | SparseDtype,
|
|
19 | 24 | import pandas._testing as tm
|
20 | 25 | from pandas.core.arrays.sparse import SparseArray
|
21 | 26 |
|
| 27 | +try: |
| 28 | + import pyarrow as pa |
| 29 | +except ImportError: |
| 30 | + pa = None |
| 31 | + |
22 | 32 |
|
23 | 33 | class TestGetDummies:
|
24 | 34 | @pytest.fixture
|
@@ -217,6 +227,7 @@ def test_dataframe_dummies_string_dtype(self, df):
|
217 | 227 | },
|
218 | 228 | dtype=bool,
|
219 | 229 | )
|
| 230 | + expected[["B_b", "B_c"]] = expected[["B_b", "B_c"]].astype("boolean") |
220 | 231 | tm.assert_frame_equal(result, expected)
|
221 | 232 |
|
222 | 233 | def test_dataframe_dummies_mix_default(self, df, sparse, dtype):
|
@@ -693,3 +704,37 @@ def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype):
|
693 | 704 | dtype=any_numeric_ea_and_arrow_dtype,
|
694 | 705 | )
|
695 | 706 | tm.assert_frame_equal(result, expected)
|
| 707 | + |
| 708 | + @td.skip_if_no("pyarrow") |
| 709 | + def test_get_dummies_ea_dtype(self): |
| 710 | + # GH#56273 |
| 711 | + for dtype, exp_dtype in [ |
| 712 | + ("string[pyarrow]", "boolean"), |
| 713 | + ("string[pyarrow_numpy]", "bool"), |
| 714 | + (CategoricalDtype(Index(["a"], dtype="string[pyarrow]")), "boolean"), |
| 715 | + (CategoricalDtype(Index(["a"], dtype="string[pyarrow_numpy]")), "bool"), |
| 716 | + ]: |
| 717 | + df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1}) |
| 718 | + result = get_dummies(df) |
| 719 | + expected = DataFrame({"x": 1, "name_a": Series([True], dtype=exp_dtype)}) |
| 720 | + tm.assert_frame_equal(result, expected) |
| 721 | + |
| 722 | + @td.skip_if_no("pyarrow") |
| 723 | + def test_get_dummies_arrow_dtype(self): |
| 724 | + # GH#56273 |
| 725 | + df = DataFrame({"name": Series(["a"], dtype=ArrowDtype(pa.string())), "x": 1}) |
| 726 | + result = get_dummies(df) |
| 727 | + expected = DataFrame({"x": 1, "name_a": Series([True], dtype="bool[pyarrow]")}) |
| 728 | + tm.assert_frame_equal(result, expected) |
| 729 | + |
| 730 | + df = DataFrame( |
| 731 | + { |
| 732 | + "name": Series( |
| 733 | + ["a"], |
| 734 | + dtype=CategoricalDtype(Index(["a"], dtype=ArrowDtype(pa.string()))), |
| 735 | + ), |
| 736 | + "x": 1, |
| 737 | + } |
| 738 | + ) |
| 739 | + result = get_dummies(df) |
| 740 | + tm.assert_frame_equal(result, expected) |
0 commit comments