@@ -752,7 +752,7 @@ def _convert_level_number(level_num, columns):
752
752
753
753
754
754
def from_dummies(
    data, prefix=None, prefix_sep="_", dtype="category"
) -> "DataFrame":
    """
    The inverse transformation of ``pandas.get_dummies``.

    Parameters
    ----------
    data : DataFrame
        Data which contains dummy indicators.
    prefix : str or list-like of str, default None
        How to name the decoded groups of columns. If there are columns
        containing `prefix_sep`, then the part of their name preceding
        the first `prefix_sep` is used as the group name, and `prefix`
        (if given) selects which of those groups are decoded. If no
        column contains `prefix_sep`, every column is treated as one
        group and `prefix` must be a single name for it.
    prefix_sep : str, default '_'
        Separator between original column name and dummy variable.
    dtype : dtype, default 'category'
        Data dtype for new columns - only a single data type is allowed.
        ``None`` is treated as ``'category'``.

    Returns
    -------
    DataFrame
        A copy of `data` in which each decoded group of dummy columns is
        replaced by a single column named after its prefix. Rows whose
        dummy columns are all 0 decode to a missing value. Columns that
        are not decoded are passed through unchanged.

    Raises
    ------
    ValueError
        If no column contains `prefix_sep` and `prefix` is None or is not
        a single name, or if a group of dummy columns holds a value other
        than 0/1 or has more than one 1 in a row.

    Examples
    --------
    Say we have a dataframe where some variables have been dummified:

    >>> df = pd.DataFrame(
    ...     {
    ...         "baboon": [0, 0, 1],
    ...         "lemur": [0, 1, 0],
    ...         "zebra": [1, 0, 0],
    ...     }
    ... )
    >>> df
       baboon  lemur  zebra
    0       0      0      1
    1       0      1      0
    2       1      0      0

    We can recover the original dataframe using `from_dummies`:

    >>> pd.from_dummies(df, prefix='animal')
       animal
    0   zebra
    1   lemur
    2  baboon

    If our dataframe already has columns with `prefix_sep` in them,
    we don't need to pass in the `prefix` argument:

    >>> df = pd.DataFrame(
    ...     {
    ...         "animal_baboon": [0, 0, 1],
    ...         "animal_lemur": [0, 1, 0],
    ...         "animal_zebra": [1, 0, 0],
    ...         "other": ['a', 'b', 'c'],
    ...     }
    ... )
    >>> df
       animal_baboon  animal_lemur  animal_zebra other
    0              0             0             1     a
    1              0             1             0     b
    2              1             0             0     c

    >>> pd.from_dummies(df)
      other  animal
    0     a   zebra
    1     b   lemur
    2     c  baboon
    """
    if dtype is None:
        dtype = "category"

    # Only string labels can carry the separator; any other label (ints,
    # tuples, ...) is passed through instead of raising on the `in` test.
    columns_to_decode = [
        col for col in data.columns if isinstance(col, str) and prefix_sep in col
    ]

    if not columns_to_decode:
        if prefix is None:
            raise ValueError(
                "If no columns contain `prefix_sep`, you must"
                " pass a value to `prefix` with which to name"
                " the decoded columns."
            )
        # Without a separator every column belongs to one group, so exactly
        # one name is needed. Accept a bare string or a one-element list-like;
        # interpolating a list directly would embed its repr in the labels.
        if isinstance(prefix, str):
            group_name = prefix
        else:
            candidates = list(prefix)
            if len(candidates) != 1:
                raise ValueError(
                    "If no columns contain `prefix_sep`, `prefix` must be"
                    " a single name for the decoded column."
                )
            group_name = candidates[0]
        # Prepend `<prefix><prefix_sep>` to each column so the generic
        # decoding loop below can handle this case too. `rename` returns a
        # new object, so `data` itself is left untouched.
        out = data.rename(
            columns=lambda label: f"{group_name}{prefix_sep}{label}"
        )
        columns_to_decode = list(out.columns)
        prefix = [group_name]
    else:
        out = data.copy()
        if prefix is None:
            # No prefix passed: take the text before the first separator.
            prefix = [col.split(prefix_sep)[0] for col in columns_to_decode]
        elif isinstance(prefix, str):
            prefix = [prefix]

    # Keep first-seen order but drop repeats, otherwise the second pass over
    # a repeated prefix would try to drop columns that are already gone.
    prefix = list(dict.fromkeys(prefix))

    # Snapshot the indicator columns before `out` is modified in the loop.
    data_to_decode = out[columns_to_decode]

    for prefix_ in prefix:
        marker = f"{prefix_}{prefix_sep}"
        # Anchor the match at the start of the name: a substring test would
        # also pull e.g. "ba_x" into the group for prefix "a".
        cols = [col for col in columns_to_decode if col.startswith(marker)]
        if not cols:
            continue
        # Strip only the leading marker; the label may legitimately contain
        # the same text again further on.
        labels = [col[len(marker):] for col in cols]

        raw = data_to_decode[cols].to_numpy()
        hot = np.asarray(raw == 1, dtype=bool)
        cold = np.asarray(raw == 0, dtype=bool)
        # Check the cells themselves, not only the row sums: a row such as
        # [2, -1] sums to 1 but is not a valid one-hot encoding.
        if not (hot | cold).all() or (hot.sum(axis=1) > 1).any():
            raise ValueError(
                "Data cannot be decoded! Each row must contain only 0s and"
                " 1s, and each row may have at most one 1."
            )

        decoded = np.array(labels, dtype=object)[hot.argmax(axis=1)]
        # argmax reports position 0 for an all-zero row; such rows carry no
        # category, so mark them missing instead of using the first label.
        decoded[~hot.any(axis=1)] = np.nan

        out = out.drop(cols, axis=1)
        # Build the Series on the frame's own index: column assignment
        # aligns by label, so a default RangeIndex would misplace (or blank
        # out) the values whenever `data` is not indexed 0..n-1.
        out[prefix_] = Series(decoded, index=out.index, dtype=dtype)
    return out
871
885
0 commit comments