From 7ecafa1440b3ea62ad90acf1330b6108ec6b61de Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Sat, 2 Dec 2023 01:02:23 +0100 Subject: [PATCH 1/7] ENH: Make get_dummies return ea booleans for ea inputs --- doc/source/whatsnew/v2.2.0.rst | 1 + pandas/core/reshape/encoding.py | 21 ++++++++++++++- pandas/tests/reshape/test_get_dummies.py | 33 ++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index ade87c4215a38..0a143293a8f97 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -218,6 +218,7 @@ Other enhancements - :meth:`to_sql` with method parameter set to ``multi`` works with Oracle on the backend - :attr:`Series.attrs` / :attr:`DataFrame.attrs` now uses a deepcopy for propagating ``attrs`` (:issue:`54134`). +- :func:`get_dummies` now returns extension dtypes ``boolean`` or ``bool[pyarrow]`` that are compatible with the input dtype (:issue:`56273`) - :func:`read_csv` now supports ``on_bad_lines`` parameter with ``engine="pyarrow"``. (:issue:`54480`) - :func:`read_spss` now returns a :class:`DataFrame` that stores the metadata in :attr:`DataFrame.attrs`. 
(:issue:`54264`) - :func:`tseries.api.guess_datetime_format` is now part of the public API (:issue:`54727`) diff --git a/pandas/core/reshape/encoding.py b/pandas/core/reshape/encoding.py index 6963bf677bcfb..1cd4ecf961853 100644 --- a/pandas/core/reshape/encoding.py +++ b/pandas/core/reshape/encoding.py @@ -21,9 +21,14 @@ is_object_dtype, pandas_dtype, ) +from pandas.core.dtypes.dtypes import ( + ArrowDtype, + CategoricalDtype, +) from pandas.core.arrays import SparseArray from pandas.core.arrays.categorical import factorize_from_iterable +from pandas.core.arrays.string_ import StringDtype from pandas.core.frame import DataFrame from pandas.core.indexes.api import ( Index, @@ -245,7 +250,21 @@ def _get_dummies_1d( codes, levels = factorize_from_iterable(Series(data, copy=False)) if dtype is None: - dtype = np.dtype(bool) + input_dtype = data.dtype + if isinstance(input_dtype, CategoricalDtype): + input_dtype = input_dtype.categories.dtype + + if isinstance(input_dtype, ArrowDtype): + import pyarrow as pa + + dtype = ArrowDtype(pa.bool_()) + elif ( + isinstance(input_dtype, StringDtype) + and input_dtype.storage != "pyarrow_numpy" + ): + dtype = pandas_dtype("boolean") + else: + dtype = np.dtype(bool) _dtype = pandas_dtype(dtype) if is_object_dtype(_dtype): diff --git a/pandas/tests/reshape/test_get_dummies.py b/pandas/tests/reshape/test_get_dummies.py index 3bfff56cfedf2..5bd93edc8ffc0 100644 --- a/pandas/tests/reshape/test_get_dummies.py +++ b/pandas/tests/reshape/test_get_dummies.py @@ -4,13 +4,18 @@ import numpy as np import pytest +import pandas.util._test_decorators as td + from pandas.core.dtypes.common import is_integer_dtype import pandas as pd from pandas import ( + ArrowDtype, Categorical, + CategoricalDtype, CategoricalIndex, DataFrame, + Index, RangeIndex, Series, SparseDtype, @@ -19,6 +24,11 @@ import pandas._testing as tm from pandas.core.arrays.sparse import SparseArray +try: + import pyarrow as pa +except ImportError: + pa = None + class 
TestGetDummies: @pytest.fixture @@ -693,3 +703,26 @@ def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype): dtype=any_numeric_ea_and_arrow_dtype, ) tm.assert_frame_equal(result, expected) + + @td.skip_if_no("pyarrow") + @pytest.mark.parametrize( + "dtype, exp_dtype", + [ + (ArrowDtype(pa.string()), "bool[pyarrow]"), + ("string[pyarrow]", "boolean"), + ("string[pyarrow_numpy]", "bool"), + ( + CategoricalDtype(Index(["a"], dtype=ArrowDtype(pa.string()))), + "bool[pyarrow]", + ), + (CategoricalDtype(Index(["a"], dtype="string[pyarrow]")), "boolean"), + (CategoricalDtype(Index(["a"], dtype="string[pyarrow_numpy]")), "bool"), + ], + ) + def test_get_dummies_arrow_dtype(self, dtype, exp_dtype): + # GH#56273 + + df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1}) + result = get_dummies(df) + expected = DataFrame({"x": 1, "name_a": Series([True], dtype=exp_dtype)}) + tm.assert_frame_equal(result, expected) From 043779b928547142bdf52b71af2762add08064c0 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Sat, 2 Dec 2023 01:04:26 +0100 Subject: [PATCH 2/7] ENH: Make get_dummies return ea booleans for ea inputs --- pandas/core/reshape/encoding.py | 5 ++++- pandas/tests/reshape/test_get_dummies.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pandas/core/reshape/encoding.py b/pandas/core/reshape/encoding.py index 1cd4ecf961853..83a929e6fbe08 100644 --- a/pandas/core/reshape/encoding.py +++ b/pandas/core/reshape/encoding.py @@ -249,7 +249,7 @@ def _get_dummies_1d( # Series avoids inconsistent NaN handling codes, levels = factorize_from_iterable(Series(data, copy=False)) - if dtype is None: + if dtype is None and hasattr(data, "dtype"): input_dtype = data.dtype if isinstance(input_dtype, CategoricalDtype): input_dtype = input_dtype.categories.dtype @@ -265,6 +265,9 @@ def _get_dummies_1d( dtype = pandas_dtype("boolean") else: dtype = np.dtype(bool) + elif dtype is None: + dtype = 
np.dtype(bool) + _dtype = pandas_dtype(dtype) if is_object_dtype(_dtype): diff --git a/pandas/tests/reshape/test_get_dummies.py b/pandas/tests/reshape/test_get_dummies.py index 5bd93edc8ffc0..c426141967d09 100644 --- a/pandas/tests/reshape/test_get_dummies.py +++ b/pandas/tests/reshape/test_get_dummies.py @@ -227,6 +227,7 @@ def test_dataframe_dummies_string_dtype(self, df): }, dtype=bool, ) + expected[["B_b", "B_c"]] = expected[["B_b", "B_c"]].astype("boolean") tm.assert_frame_equal(result, expected) def test_dataframe_dummies_mix_default(self, df, sparse, dtype): From 8775218607520d28dc42fe741310aaad6ef54adb Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Sat, 2 Dec 2023 20:15:58 +0100 Subject: [PATCH 3/7] Update --- pandas/core/reshape/encoding.py | 4 ++-- pandas/tests/reshape/test_get_dummies.py | 28 ++++++++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/pandas/core/reshape/encoding.py b/pandas/core/reshape/encoding.py index 83a929e6fbe08..3ed67bb7b7c02 100644 --- a/pandas/core/reshape/encoding.py +++ b/pandas/core/reshape/encoding.py @@ -257,12 +257,12 @@ def _get_dummies_1d( if isinstance(input_dtype, ArrowDtype): import pyarrow as pa - dtype = ArrowDtype(pa.bool_()) + dtype = ArrowDtype(pa.bool_()) # type: ignore[assignment] elif ( isinstance(input_dtype, StringDtype) and input_dtype.storage != "pyarrow_numpy" ): - dtype = pandas_dtype("boolean") + dtype = pandas_dtype("boolean") # type: ignore[assignment] else: dtype = np.dtype(bool) elif dtype is None: diff --git a/pandas/tests/reshape/test_get_dummies.py b/pandas/tests/reshape/test_get_dummies.py index c426141967d09..eb7bfb1f1810e 100644 --- a/pandas/tests/reshape/test_get_dummies.py +++ b/pandas/tests/reshape/test_get_dummies.py @@ -709,21 +709,35 @@ def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype): @pytest.mark.parametrize( "dtype, exp_dtype", [ - (ArrowDtype(pa.string()), "bool[pyarrow]"), 
("string[pyarrow]", "boolean"), ("string[pyarrow_numpy]", "bool"), - ( - CategoricalDtype(Index(["a"], dtype=ArrowDtype(pa.string()))), - "bool[pyarrow]", - ), (CategoricalDtype(Index(["a"], dtype="string[pyarrow]")), "boolean"), (CategoricalDtype(Index(["a"], dtype="string[pyarrow_numpy]")), "bool"), ], ) - def test_get_dummies_arrow_dtype(self, dtype, exp_dtype): + def test_get_dummies_ea_dtyoe(self, dtype, exp_dtype): # GH#56273 - df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1}) result = get_dummies(df) expected = DataFrame({"x": 1, "name_a": Series([True], dtype=exp_dtype)}) tm.assert_frame_equal(result, expected) + + @td.skip_if_no("pyarrow") + def test_get_dummies_arrow_dtype(self): + # GH#56273 + df = DataFrame({"name": Series(["a"], dtype=ArrowDtype(pa.string())), "x": 1}) + result = get_dummies(df) + expected = DataFrame({"x": 1, "name_a": Series([True], dtype="bool[pyarrow]")}) + tm.assert_frame_equal(result, expected) + + df = DataFrame( + { + "name": Series( + ["a"], + dtype=CategoricalDtype(Index(["a"], dtype=ArrowDtype(pa.string()))), + ), + "x": 1, + } + ) + result = get_dummies(df) + tm.assert_frame_equal(result, expected) From 8863d53300cfa24a14bc46edfc238fb9cdebef34 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Mon, 4 Dec 2023 12:06:08 +0100 Subject: [PATCH 4/7] Update pandas/tests/reshape/test_get_dummies.py Co-authored-by: Thomas Baumann --- pandas/tests/reshape/test_get_dummies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/reshape/test_get_dummies.py b/pandas/tests/reshape/test_get_dummies.py index eb7bfb1f1810e..b3873b0242cfb 100644 --- a/pandas/tests/reshape/test_get_dummies.py +++ b/pandas/tests/reshape/test_get_dummies.py @@ -715,7 +715,7 @@ def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype): (CategoricalDtype(Index(["a"], dtype="string[pyarrow_numpy]")), "bool"), ], ) - def test_get_dummies_ea_dtyoe(self, dtype, 
exp_dtype): + def test_get_dummies_ea_dtype(self, dtype, exp_dtype): # GH#56273 df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1}) result = get_dummies(df) From 546bdd1677490b0814ce97feba683c669b8fda5e Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 8 Dec 2023 00:00:37 +0100 Subject: [PATCH 5/7] Update test_get_dummies.py --- pandas/tests/reshape/test_get_dummies.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/pandas/tests/reshape/test_get_dummies.py b/pandas/tests/reshape/test_get_dummies.py index b3873b0242cfb..7b4cb9366af72 100644 --- a/pandas/tests/reshape/test_get_dummies.py +++ b/pandas/tests/reshape/test_get_dummies.py @@ -706,21 +706,19 @@ def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype): tm.assert_frame_equal(result, expected) @td.skip_if_no("pyarrow") - @pytest.mark.parametrize( - "dtype, exp_dtype", - [ + def test_get_dummies_ea_dtype(self): + # GH#56273 + for dtype, exp_dtype in [ ("string[pyarrow]", "boolean"), ("string[pyarrow_numpy]", "bool"), (CategoricalDtype(Index(["a"], dtype="string[pyarrow]")), "boolean"), (CategoricalDtype(Index(["a"], dtype="string[pyarrow_numpy]")), "bool"), - ], - ) - def test_get_dummies_ea_dtype(self, dtype, exp_dtype): - # GH#56273 - df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1}) - result = get_dummies(df) - expected = DataFrame({"x": 1, "name_a": Series([True], dtype=exp_dtype)}) - tm.assert_frame_equal(result, expected) + ]: + + df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1}) + result = get_dummies(df) + expected = DataFrame({"x": 1, "name_a": Series([True], dtype=exp_dtype)}) + tm.assert_frame_equal(result, expected) @td.skip_if_no("pyarrow") def test_get_dummies_arrow_dtype(self): From 713ada400b6bf54cb5e1a7accff29f2d77a0454b Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 8 Dec 2023 23:09:00 +0100 Subject: 
[PATCH 6/7] Update test_get_dummies.py --- pandas/tests/reshape/test_get_dummies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/reshape/test_get_dummies.py b/pandas/tests/reshape/test_get_dummies.py index 7b4cb9366af72..2da61b6d58b30 100644 --- a/pandas/tests/reshape/test_get_dummies.py +++ b/pandas/tests/reshape/test_get_dummies.py @@ -714,7 +714,7 @@ def test_get_dummies_ea_dtype(self): (CategoricalDtype(Index(["a"], dtype="string[pyarrow]")), "boolean"), (CategoricalDtype(Index(["a"], dtype="string[pyarrow_numpy]")), "bool"), ]: - + df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1}) result = get_dummies(df) expected = DataFrame({"x": 1, "name_a": Series([True], dtype=exp_dtype)}) From 5848f93f3ce349cbeee62174e1ba0903dd04fd21 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Sun, 10 Dec 2023 00:10:44 +0100 Subject: [PATCH 7/7] Fixup --- pandas/tests/reshape/test_get_dummies.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandas/tests/reshape/test_get_dummies.py b/pandas/tests/reshape/test_get_dummies.py index 2da61b6d58b30..9b7aefac60969 100644 --- a/pandas/tests/reshape/test_get_dummies.py +++ b/pandas/tests/reshape/test_get_dummies.py @@ -714,7 +714,6 @@ def test_get_dummies_ea_dtype(self): (CategoricalDtype(Index(["a"], dtype="string[pyarrow]")), "boolean"), (CategoricalDtype(Index(["a"], dtype="string[pyarrow_numpy]")), "bool"), ]: - df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1}) result = get_dummies(df) expected = DataFrame({"x": 1, "name_a": Series([True], dtype=exp_dtype)})