Skip to content

Commit 13a40ff

Browse files
test_with_prefix_default_category
1 parent 4aa160e commit 13a40ff

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

pandas/core/reshape/encoding.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
from pandas._libs import missing as libmissing
1414
from pandas._libs.sparse import IntIndex
1515

16+
from pandas.core.dtypes.cast import (
17+
find_common_type,
18+
infer_dtype_from_scalar,
19+
)
1620
from pandas.core.dtypes.common import (
1721
is_integer_dtype,
1822
is_list_like,
@@ -567,7 +571,13 @@ def from_dummies(
567571
)
568572
else:
569573
data_slice = data_to_decode.loc[:, prefix_slice]
570-
cats_array = data._constructor_sliced(cats, dtype=data.columns.dtype)
574+
dtype = data.columns.dtype
575+
if default_category:
576+
default_category_dtype = infer_dtype_from_scalar(default_category[prefix])[
577+
0
578+
]
579+
dtype = find_common_type([dtype, default_category_dtype])
580+
cats_array = data._constructor_sliced(cats, dtype=dtype)
571581
# get indices of True entries along axis=1
572582
true_values = data_slice.idxmax(axis=1)
573583
indexer = data_slice.columns.get_indexer_for(true_values)

pandas/tests/reshape/test_from_dummies.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import numpy as np
22
import pytest
33

4-
from pandas._config import using_string_dtype
5-
64
from pandas import (
75
DataFrame,
86
Series,
@@ -330,14 +328,10 @@ def test_no_prefix_string_cats_contains_get_dummies_NaN_column():
330328
),
331329
],
332330
)
333-
def test_no_prefix_string_cats_default_category(
334-
default_category, expected, using_infer_string
335-
):
331+
def test_no_prefix_string_cats_default_category(default_category, expected):
336332
dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, 0]})
337333
result = from_dummies(dummies, default_category=default_category)
338334
expected = DataFrame(expected)
339-
if using_infer_string:
340-
expected[""] = expected[""].astype("str")
341335
tm.assert_frame_equal(result, expected)
342336

343337

@@ -364,7 +358,6 @@ def test_with_prefix_contains_get_dummies_NaN_column():
364358
tm.assert_frame_equal(result, expected)
365359

366360

367-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
368361
@pytest.mark.parametrize(
369362
"default_category, expected",
370363
[
@@ -390,7 +383,7 @@ def test_with_prefix_contains_get_dummies_NaN_column():
390383
),
391384
pytest.param(
392385
{"col2": None, "col1": False},
393-
{"col1": ["a", "b", False], "col2": [None, "a", "c"]},
386+
{"col1": ["a", "b", False], "col2": Series([None, "a", "c"], dtype=object)},
394387
id="default_category is a dict with bool and None values",
395388
),
396389
pytest.param(

0 commit comments

Comments
 (0)