Skip to content

Commit 476fcb1

Browse files
authored
BUG: from_dummies always returning object data (#54300)
* BUG: from_dummies always returning object data * Fix misspelling * Add GH issue
1 parent 4cf63ea commit 476fcb1

File tree

3 files changed

+54
-5
lines changed

3 files changed

+54
-5
lines changed

doc/source/whatsnew/v2.1.0.rst

+2
Original file line number | Diff line number | Diff line change
@@ -835,6 +835,8 @@ Other
835835
- Bug in :func:`api.interchange.from_dataframe` when converting an empty DataFrame object (:issue:`53155`)
836836
- Bug in :func:`assert_almost_equal` now throwing assertion error for two unequal sets (:issue:`51727`)
837837
- Bug in :func:`assert_frame_equal` checking category dtypes even when asked not to check index type (:issue:`52126`)
838+
- Bug in :func:`from_dummies` where the resulting :class:`Index` did not match the original :class:`Index` (:issue:`54300`)
839+
- Bug in :func:`from_dummies` where the resulting data would always be ``object`` dtype instead of the dtype of the columns (:issue:`54300`)
838840
- Bug in :meth:`DataFrame.pivot_table` with casting the mean of ints back to an int (:issue:`16676`)
839841
- Bug in :meth:`DataFrame.reindex` with a ``fill_value`` that should be inferred with a :class:`ExtensionDtype` incorrectly inferring ``object`` dtype (:issue:`52586`)
840842
- Bug in :meth:`DataFrame.shift`, :meth:`Series.shift`, and :meth:`DataFrameGroupBy.shift` when passing both ``freq`` and ``fill_value`` silently ignoring ``fill_value`` instead of raising ``ValueError`` (:issue:`53832`)

pandas/core/reshape/encoding.py

+8-3
Original file line number | Diff line number | Diff line change
@@ -534,8 +534,13 @@ def from_dummies(
534534
)
535535
else:
536536
data_slice = data_to_decode.loc[:, prefix_slice]
537-
cats_array = np.array(cats, dtype="object")
537+
cats_array = data._constructor_sliced(cats, dtype=data.columns.dtype)
538538
# get indices of True entries along axis=1
539-
cat_data[prefix] = cats_array[data_slice.to_numpy().nonzero()[1]]
539+
true_values = data_slice.idxmax(axis=1)
540+
indexer = data_slice.columns.get_indexer_for(true_values)
541+
cat_data[prefix] = cats_array.take(indexer).set_axis(data.index)
540542

541-
return DataFrame(cat_data)
543+
result = DataFrame(cat_data)
544+
if sep is not None:
545+
result.columns = result.columns.astype(data.columns.dtype)
546+
return result

pandas/tests/reshape/test_from_dummies.py

+44-2
Original file line number | Diff line number | Diff line change
@@ -257,7 +257,7 @@ def test_no_prefix_int_cats_basic():
257257
dummies = DataFrame(
258258
{1: [1, 0, 0, 0], 25: [0, 1, 0, 0], 2: [0, 0, 1, 0], 5: [0, 0, 0, 1]}
259259
)
260-
expected = DataFrame({"": [1, 25, 2, 5]}, dtype="object")
260+
expected = DataFrame({"": [1, 25, 2, 5]})
261261
result = from_dummies(dummies)
262262
tm.assert_frame_equal(result, expected)
263263

@@ -266,7 +266,7 @@ def test_no_prefix_float_cats_basic():
266266
dummies = DataFrame(
267267
{1.0: [1, 0, 0, 0], 25.0: [0, 1, 0, 0], 2.5: [0, 0, 1, 0], 5.84: [0, 0, 0, 1]}
268268
)
269-
expected = DataFrame({"": [1.0, 25.0, 2.5, 5.84]}, dtype="object")
269+
expected = DataFrame({"": [1.0, 25.0, 2.5, 5.84]})
270270
result = from_dummies(dummies)
271271
tm.assert_frame_equal(result, expected)
272272

@@ -399,3 +399,45 @@ def test_with_prefix_default_category(
399399
dummies_with_unassigned, sep="_", default_category=default_category
400400
)
401401
tm.assert_frame_equal(result, expected)
402+
403+
404+
def test_ea_categories():
405+
# GH 54300
406+
df = DataFrame({"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]})
407+
df.columns = df.columns.astype("string[python]")
408+
result = from_dummies(df)
409+
expected = DataFrame({"": Series(list("abca"), dtype="string[python]")})
410+
tm.assert_frame_equal(result, expected)
411+
412+
413+
def test_ea_categories_with_sep():
414+
# GH 54300
415+
df = DataFrame(
416+
{
417+
"col1_a": [1, 0, 1],
418+
"col1_b": [0, 1, 0],
419+
"col2_a": [0, 1, 0],
420+
"col2_b": [1, 0, 0],
421+
"col2_c": [0, 0, 1],
422+
}
423+
)
424+
df.columns = df.columns.astype("string[python]")
425+
result = from_dummies(df, sep="_")
426+
expected = DataFrame(
427+
{
428+
"col1": Series(list("aba"), dtype="string[python]"),
429+
"col2": Series(list("bac"), dtype="string[python]"),
430+
}
431+
)
432+
expected.columns = expected.columns.astype("string[python]")
433+
tm.assert_frame_equal(result, expected)
434+
435+
436+
def test_maintain_original_index():
437+
# GH 54300
438+
df = DataFrame(
439+
{"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}, index=list("abcd")
440+
)
441+
result = from_dummies(df)
442+
expected = DataFrame({"": list("abca")}, index=list("abcd"))
443+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)