diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index 861d802d3ba62..f1a6ea1d01727 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -828,6 +828,8 @@ Other - Bug in :func:`api.interchange.from_dataframe` when converting an empty DataFrame object (:issue:`53155`) - Bug in :func:`assert_almost_equal` now throwing assertion error for two unequal sets (:issue:`51727`) - Bug in :func:`assert_frame_equal` checks category dtypes even when asked not to check index type (:issue:`52126`) +- Bug in :func:`from_dummies` where the resulting :class:`Index` did not match the original :class:`Index` (:issue:`54300`) +- Bug in :func:`from_dummies` where the resulting data would always be ``object`` dtype instead of the dtype of the columns (:issue:`54300`) - Bug in :meth:`DataFrame.pivot_table` with casting the mean of ints back to an int (:issue:`16676`) - Bug in :meth:`DataFrame.reindex` with a ``fill_value`` that should be inferred with a :class:`ExtensionDtype` incorrectly inferring ``object`` dtype (:issue:`52586`) - Bug in :meth:`DataFrame.shift` and :meth:`Series.shift` and :meth:`DataFrameGroupBy.shift` when passing both "freq" and "fill_value" silently ignoring "fill_value" instead of raising ``ValueError`` (:issue:`53832`) diff --git a/pandas/core/reshape/encoding.py b/pandas/core/reshape/encoding.py index e30881e1a79c6..9ebce3a71c966 100644 --- a/pandas/core/reshape/encoding.py +++ b/pandas/core/reshape/encoding.py @@ -534,8 +534,13 @@ def from_dummies( ) else: data_slice = data_to_decode.loc[:, prefix_slice] - cats_array = np.array(cats, dtype="object") + cats_array = data._constructor_sliced(cats, dtype=data.columns.dtype) # get indices of True entries along axis=1 - cat_data[prefix] = cats_array[data_slice.to_numpy().nonzero()[1]] + true_values = data_slice.idxmax(axis=1) + indexer = data_slice.columns.get_indexer_for(true_values) + cat_data[prefix] = cats_array.take(indexer).set_axis(data.index) - return DataFrame(cat_data) + result = DataFrame(cat_data) + if sep is not None: + result.columns = result.columns.astype(data.columns.dtype) + return result diff --git a/pandas/tests/reshape/test_from_dummies.py b/pandas/tests/reshape/test_from_dummies.py index aab49538a2dcf..0074a90d7a51e 100644 --- a/pandas/tests/reshape/test_from_dummies.py +++ b/pandas/tests/reshape/test_from_dummies.py @@ -257,7 +257,7 @@ def test_no_prefix_int_cats_basic(): dummies = DataFrame( {1: [1, 0, 0, 0], 25: [0, 1, 0, 0], 2: [0, 0, 1, 0], 5: [0, 0, 0, 1]} ) - expected = DataFrame({"": [1, 25, 2, 5]}, dtype="object") + expected = DataFrame({"": [1, 25, 2, 5]}) result = from_dummies(dummies) tm.assert_frame_equal(result, expected) @@ -266,7 +266,7 @@ def test_no_prefix_float_cats_basic(): dummies = DataFrame( {1.0: [1, 0, 0, 0], 25.0: [0, 1, 0, 0], 2.5: [0, 0, 1, 0], 5.84: [0, 0, 0, 1]} ) - expected = DataFrame({"": [1.0, 25.0, 2.5, 5.84]}, dtype="object") + expected = DataFrame({"": [1.0, 25.0, 2.5, 5.84]}) result = from_dummies(dummies) tm.assert_frame_equal(result, expected) @@ -399,3 +399,45 @@ def test_with_prefix_default_category( dummies_with_unassigned, sep="_", default_category=default_category ) tm.assert_frame_equal(result, expected) + + +def test_ea_categories(): + # GH 54300 + df = DataFrame({"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}) + df.columns = df.columns.astype("string[python]") + result = from_dummies(df) + expected = DataFrame({"": Series(list("abca"), dtype="string[python]")}) + tm.assert_frame_equal(result, expected) + + +def test_ea_categories_with_sep(): + # GH 54300 + df = DataFrame( + { + "col1_a": [1, 0, 1], + "col1_b": [0, 1, 0], + "col2_a": [0, 1, 0], + "col2_b": [1, 0, 0], + "col2_c": [0, 0, 1], + } + ) + df.columns = df.columns.astype("string[python]") + result = from_dummies(df, sep="_") + expected = DataFrame( + { + "col1": Series(list("aba"), dtype="string[python]"), + "col2": Series(list("bac"), dtype="string[python]"), + } + ) + expected.columns = expected.columns.astype("string[python]") + tm.assert_frame_equal(result, expected) + + +def test_maintain_original_index(): + # GH 54300 + df = DataFrame( + {"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}, index=list("abcd") + ) + result = from_dummies(df) + expected = DataFrame({"": list("abca")}, index=list("abcd")) + tm.assert_frame_equal(result, expected)