Skip to content

Commit f5d9088

Browse files
author
Marco Gorelli
committed
infer prefixes
1 parent 9eed071 commit f5d9088

File tree

2 files changed

+88
-79
lines changed

2 files changed

+88
-79
lines changed

pandas/core/reshape/reshape.py

+76-62
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def _convert_level_number(level_num, columns):
752752

753753

754754
def from_dummies(
755-
data, prefix=None, prefix_sep="_", dtype="category", fill_first=None
755+
data, prefix=None, prefix_sep="_", dtype="category"
756756
) -> "DataFrame":
757757
"""
758758
The inverse transformation of ``pandas.get_dummies``.
@@ -762,14 +762,13 @@ def from_dummies(
762762
data : DataFrame
763763
Data which contains dummy indicators.
764764
prefix : str or list-like, default None
765-
Prefixes of the columns in the DataFrame to be decoded.
766-
If `prefix` is None then all the columns will be decoded.
765+
How to name the decoded groups of columns. If there are columns
766+
containing `prefix_sep`, then the part of their name preceding
767+
`prefix_sep` will be used (see examples below).
767768
prefix_sep : str, default '_'
768769
Separator between original column name and dummy variable.
769770
dtype : dtype, default 'category'
770771
Data dtype for new columns - only a single data type is allowed.
771-
fill_first : str, list, or dict, default None
772-
Used to fill rows for which all the dummy variables are 0.
773772
774773
Returns
775774
-------
@@ -782,90 +781,105 @@ def from_dummies(
782781
783782
>>> df = pd.DataFrame(
784783
... {
785-
... "animal_baboon": [0, 0, 1],
786-
... "animal_lemur": [0, 1, 0],
787-
... "animal_zebra": [1, 0, 0],
788-
... "other_col": ["a", "b", "c"],
784+
... "baboon": [0, 0, 1],
785+
... "lemur": [0, 1, 0],
786+
... "zebra": [1, 0, 0],
789787
... }
790788
... )
791789
>>> df
792-
animal_baboon animal_lemur animal_zebra other_col
793-
0 0 0 1 a
794-
1 0 1 0 b
795-
2 1 0 0 c
790+
baboon lemur zebra
791+
0 0 0 1
792+
1 0 1 0
793+
2 1 0 0
796794
797795
We can recover the original dataframe using `from_dummies`:
798796
799-
>>> pd.from_dummies(df, prefix=['animal'])
800-
other_col animal
801-
0 a zebra
802-
1 b lemur
803-
2 c baboon
797+
>>> pd.from_dummies(df, prefix='animal')
798+
animal
799+
0 zebra
800+
1 lemur
801+
2 baboon
804802
805-
Suppose our dataframe has one column from each dummified column
806-
dropped:
803+
If our dataframe already has columns with `prefix_sep` in them,
804+
we don't need to pass in the `prefix` argument:
807805
808-
>>> df = df.drop('animal_zebra', axis=1)
806+
>>> df = pd.DataFrame(
807+
... {
808+
... "animal_baboon": [0, 0, 1],
809+
... "animal_lemur": [0, 1, 0],
810+
... "animal_zebra": [1, 0, 0],
811+
... "other": ['a', 'b', 'c'],
812+
... }
813+
... )
809814
>>> df
810-
animal_baboon animal_lemur other_col
811-
0 0 0 a
812-
1 0 1 b
813-
2 1 0 c
814-
815-
We can still recover the original dataframe, by using the argument
816-
`fill_first`:
817-
818-
>>> pd.from_dummies(df, prefix=["animal"], fill_first=["zebra"])
819-
other_col animal
820-
0 a zebra
821-
1 b lemur
822-
2 c baboon
815+
animal_baboon animal_lemur animal_zebra other
816+
0 0 0 1 a
817+
1 0 1 0 b
818+
2 1 0 0 c
819+
820+
>>> pd.from_dummies(df)
821+
other animal
822+
0 a zebra
823+
1 b lemur
824+
2 c baboon
823825
"""
824826
if dtype is None:
825827
dtype = "category"
826828

827-
if prefix is None:
828-
data_to_decode = data.copy()
829-
prefix = data.columns.tolist()
830-
prefix = list({i.split(prefix_sep)[0] for i in data.columns if prefix_sep in i})
829+
columns_to_decode = [i for i in data.columns if prefix_sep in i]
830+
if not columns_to_decode:
831+
if prefix is None:
832+
raise ValueError(
833+
"If no columns contain `prefix_sep`, you must"
834+
" pass a value to `prefix` with which to name"
835+
" the decoded columns."
836+
)
837+
# If no column contains `prefix_sep`, we prepend `prefix` + `prefix_sep` to
838+
# each column.
839+
out = data.rename(columns = lambda x: f'{prefix}{prefix_sep}{x}').copy()
840+
columns_to_decode = out.columns
841+
else:
842+
out = data.copy()
831843

832-
data_to_decode = data[
833-
[i for i in data.columns for p in prefix if i.startswith(p + prefix_sep)]
834-
]
844+
data_to_decode = out[columns_to_decode]
835845

836-
# Check each row sums to 1 or 0
837-
if not all(i in [0, 1] for i in data_to_decode.sum(axis=1).unique().tolist()):
838-
raise ValueError(
839-
"Data cannot be decoded! Each row must contain only 0s and"
840-
" 1s, and each row may have at most one 1"
841-
)
846+
if prefix is None:
847+
# If no prefix has been passed, extract it from columns containing
848+
# `prefix_sep`
849+
seen = set()
850+
prefix = []
851+
for i in columns_to_decode:
852+
i = i.split(prefix_sep)[0]
853+
if i in seen:
854+
continue
855+
seen.add(i)
856+
prefix.append(i)
857+
elif isinstance(prefix, str):
858+
prefix = [prefix]
842859

843-
if fill_first is None:
844-
fill_first = [None] * len(prefix)
845-
elif isinstance(fill_first, str):
846-
fill_first = itertools.cycle([fill_first])
847-
elif isinstance(fill_first, dict):
848-
fill_first = [fill_first[p] for p in prefix]
860+
# Check each row sums to 1 or 0
861+
def _validate_values(data):
862+
if not all(i in [0, 1] for i in data.sum(axis=1).unique().tolist()):
863+
raise ValueError(
864+
"Data cannot be decoded! Each row must contain only 0s and"
865+
" 1s, and each row may have at most one 1."
866+
)
849867

850-
out = data.copy()
851-
for prefix_, fill_first_ in zip(prefix, fill_first):
852-
cols, labels = [
868+
for prefix_ in prefix:
869+
cols, labels = (
853870
[
854871
i.replace(x, "")
855872
for i in data_to_decode.columns
856873
if prefix_ + prefix_sep in i
857874
]
858875
for x in ["", prefix_ + prefix_sep]
859-
]
876+
)
860877
if not cols:
861878
continue
879+
_validate_values(data_to_decode[cols])
862880
out = out.drop(cols, axis=1)
863-
if fill_first_:
864-
cols = [prefix_ + prefix_sep + fill_first_] + cols
865-
labels = [fill_first_] + labels
866-
data[cols[0]] = (1 - data[cols[1:]]).all(axis=1)
867881
out[prefix_] = Series(
868-
np.array(labels)[np.argmax(data[cols].to_numpy(), axis=1)], dtype=dtype
882+
np.array(labels)[np.argmax(data_to_decode[cols].to_numpy(), axis=1)], dtype=dtype
869883
)
870884
return out
871885

pandas/tests/reshape/test_from_dummies.py

+12-17
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,6 @@ def test_dtype(dtype, expected_dict):
2121
tm.assert_frame_equal(result, expected)
2222

2323

24-
@pytest.mark.parametrize(
25-
"fill_first, expected_dict",
26-
[
27-
("a", {"col1": ["a", "a", "b"]}),
28-
(["a"], {"col1": ["a", "a", "b"]}),
29-
({"col1": "a"}, {"col1": ["a", "a", "b"]}),
30-
],
31-
)
32-
def test_fill_first(fill_first, expected_dict):
33-
df = pd.DataFrame({"col1_b": [0, 0, 1]})
34-
result = pd.from_dummies(df, fill_first=fill_first)
35-
# get_dummies changes the ordering of columns,
36-
# see https://github.com/pandas-dev/pandas/issues/17612
37-
expected = pd.DataFrame(expected_dict, dtype="category")
38-
tm.assert_frame_equal(result, expected)
39-
40-
4124
def test_malformed():
4225
df = pd.DataFrame({"col1_a": [1, 1, 0], "col1_b": [1, 0, 1]})
4326
msg = (
@@ -61,3 +44,15 @@ def test_prefix_sep(prefix_sep, input_dict):
6144
result = pd.from_dummies(df, prefix_sep=prefix_sep)
6245
expected = pd.DataFrame({"col1": ["a", "a", "b"]}, dtype="category")
6346
tm.assert_frame_equal(result, expected)
47+
48+
def test_no_prefix():
49+
df = pd.DataFrame({"a": [1, 1, 0], "b": [0, 0, 1]})
50+
result = pd.from_dummies(df, prefix='letter')
51+
expected = pd.DataFrame({'letter': ['a', 'a', 'b']}, dtype='category')
52+
tm.assert_frame_equal(result, expected)
53+
54+
def test_multiple_columns():
55+
df = pd.DataFrame({"col1_a": [1, 0], "col1_b": [0, 1], "col2_a": [0, 0], "col2_c": [1, 1]})
56+
result = pd.from_dummies(df)
57+
expected = pd.DataFrame({'col1': ['a', 'b'], 'col2': ['c', 'c']}, dtype='category')
58+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)