Skip to content

Commit 9b51ab2

Browse files
phofl and lopof authored
ENH: Make get_dummies return ea booleans for ea inputs (#56291)
* ENH: Make get_dummies return ea booleans for ea inputs
* ENH: Make get_dummies return ea booleans for ea inputs
* Update
* Update pandas/tests/reshape/test_get_dummies.py Co-authored-by: Thomas Baumann <[email protected]>
* Update test_get_dummies.py
* Update test_get_dummies.py
* Fixup
---------
Co-authored-by: Thomas Baumann <[email protected]>
1 parent 8aa7a96 commit 9b51ab2

File tree

3 files changed

+69
-1
lines changed

3 files changed

+69
-1
lines changed

doc/source/whatsnew/v2.2.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -218,6 +218,7 @@ Other enhancements
218218

219219
- :meth:`~DataFrame.to_sql` with method parameter set to ``multi`` works with Oracle on the backend
220220
- :attr:`Series.attrs` / :attr:`DataFrame.attrs` now uses a deepcopy for propagating ``attrs`` (:issue:`54134`).
221+
- :func:`get_dummies` now returns extension dtypes ``boolean`` or ``bool[pyarrow]`` that are compatible with the input dtype (:issue:`56273`)
221222
- :func:`read_csv` now supports ``on_bad_lines`` parameter with ``engine="pyarrow"``. (:issue:`54480`)
222223
- :func:`read_sas` returns ``datetime64`` dtypes with resolutions better matching those stored natively in SAS, and avoids returning object-dtype in cases that cannot be stored with ``datetime64[ns]`` dtype (:issue:`56127`)
223224
- :func:`read_spss` now returns a :class:`DataFrame` that stores the metadata in :attr:`DataFrame.attrs`. (:issue:`54264`)

pandas/core/reshape/encoding.py

+23-1
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,14 @@
2121
is_object_dtype,
2222
pandas_dtype,
2323
)
24+
from pandas.core.dtypes.dtypes import (
25+
ArrowDtype,
26+
CategoricalDtype,
27+
)
2428

2529
from pandas.core.arrays import SparseArray
2630
from pandas.core.arrays.categorical import factorize_from_iterable
31+
from pandas.core.arrays.string_ import StringDtype
2732
from pandas.core.frame import DataFrame
2833
from pandas.core.indexes.api import (
2934
Index,
@@ -244,8 +249,25 @@ def _get_dummies_1d(
244249
# Series avoids inconsistent NaN handling
245250
codes, levels = factorize_from_iterable(Series(data, copy=False))
246251

247-
if dtype is None:
252+
if dtype is None and hasattr(data, "dtype"):
253+
input_dtype = data.dtype
254+
if isinstance(input_dtype, CategoricalDtype):
255+
input_dtype = input_dtype.categories.dtype
256+
257+
if isinstance(input_dtype, ArrowDtype):
258+
import pyarrow as pa
259+
260+
dtype = ArrowDtype(pa.bool_()) # type: ignore[assignment]
261+
elif (
262+
isinstance(input_dtype, StringDtype)
263+
and input_dtype.storage != "pyarrow_numpy"
264+
):
265+
dtype = pandas_dtype("boolean") # type: ignore[assignment]
266+
else:
267+
dtype = np.dtype(bool)
268+
elif dtype is None:
248269
dtype = np.dtype(bool)
270+
249271
_dtype = pandas_dtype(dtype)
250272

251273
if is_object_dtype(_dtype):

pandas/tests/reshape/test_get_dummies.py

+45
Original file line number | Diff line number | Diff line change
@@ -4,13 +4,18 @@
44
import numpy as np
55
import pytest
66

7+
import pandas.util._test_decorators as td
8+
79
from pandas.core.dtypes.common import is_integer_dtype
810

911
import pandas as pd
1012
from pandas import (
13+
ArrowDtype,
1114
Categorical,
15+
CategoricalDtype,
1216
CategoricalIndex,
1317
DataFrame,
18+
Index,
1419
RangeIndex,
1520
Series,
1621
SparseDtype,
@@ -19,6 +24,11 @@
1924
import pandas._testing as tm
2025
from pandas.core.arrays.sparse import SparseArray
2126

27+
try:
28+
import pyarrow as pa
29+
except ImportError:
30+
pa = None
31+
2232

2333
class TestGetDummies:
2434
@pytest.fixture
@@ -217,6 +227,7 @@ def test_dataframe_dummies_string_dtype(self, df):
217227
},
218228
dtype=bool,
219229
)
230+
expected[["B_b", "B_c"]] = expected[["B_b", "B_c"]].astype("boolean")
220231
tm.assert_frame_equal(result, expected)
221232

222233
def test_dataframe_dummies_mix_default(self, df, sparse, dtype):
@@ -693,3 +704,37 @@ def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype):
693704
dtype=any_numeric_ea_and_arrow_dtype,
694705
)
695706
tm.assert_frame_equal(result, expected)
707+
708+
@td.skip_if_no("pyarrow")
709+
def test_get_dummies_ea_dtype(self):
710+
# GH#56273
711+
for dtype, exp_dtype in [
712+
("string[pyarrow]", "boolean"),
713+
("string[pyarrow_numpy]", "bool"),
714+
(CategoricalDtype(Index(["a"], dtype="string[pyarrow]")), "boolean"),
715+
(CategoricalDtype(Index(["a"], dtype="string[pyarrow_numpy]")), "bool"),
716+
]:
717+
df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1})
718+
result = get_dummies(df)
719+
expected = DataFrame({"x": 1, "name_a": Series([True], dtype=exp_dtype)})
720+
tm.assert_frame_equal(result, expected)
721+
722+
@td.skip_if_no("pyarrow")
723+
def test_get_dummies_arrow_dtype(self):
724+
# GH#56273
725+
df = DataFrame({"name": Series(["a"], dtype=ArrowDtype(pa.string())), "x": 1})
726+
result = get_dummies(df)
727+
expected = DataFrame({"x": 1, "name_a": Series([True], dtype="bool[pyarrow]")})
728+
tm.assert_frame_equal(result, expected)
729+
730+
df = DataFrame(
731+
{
732+
"name": Series(
733+
["a"],
734+
dtype=CategoricalDtype(Index(["a"], dtype=ArrowDtype(pa.string()))),
735+
),
736+
"x": 1,
737+
}
738+
)
739+
result = get_dummies(df)
740+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)