@@ -751,6 +751,138 @@ def _convert_level_number(level_num, columns):
751
751
return result
752
752
753
753
754
def from_dummies(data, columns=None, prefix_sep="_", dtype="category", fill_first=None):
    """
    Decode dummy (one-hot) columns back into one column per variable.

    This is the inverse transformation of ``pandas.get_dummies``.

    Parameters
    ----------
    data : DataFrame
        Frame containing dummy columns named ``<prefix><prefix_sep><label>``.
        It is never modified; a new frame is returned.
    columns : list-like, default None
        Prefixes (original column names) to decode.
        If `columns` is None, every string column label containing
        `prefix_sep` is treated as a dummy column and its prefix (the text
        before the first separator) is decoded, in order of first appearance.
    prefix_sep : str, default '_'
        Separator between original column name and dummy variable.
    dtype : dtype, default 'category'
        Data dtype for new columns - only a single data type is allowed.
        ``None`` is treated as 'category'.
    fill_first : str, list, or dict, default None
        Label used for rows in which all the dummy variables of a prefix
        are 0 (e.g. when the first level was dropped while encoding).
        A str applies to every prefix, a list is paired positionally with
        the prefixes, and a dict maps prefix to label (prefixes missing
        from the dict get no fill label).

    Returns
    -------
    transformed : DataFrame
        Copy of `data` with each group of dummy columns removed and the
        decoded column appended under the prefix name.

    Raises
    ------
    ValueError
        If, for any decoded prefix, the dummy columns contain values other
        than 0 and 1 or a row contains more than one 1.

    Notes
    -----
    When no fill label applies to a prefix, a row whose dummies are all 0
    is decoded as the label of the first dummy column of that prefix.

    Examples
    --------
    Say we have a dataframe where some variables have been dummified:

    >>> df = pd.DataFrame(
    ...     {
    ...         "animal_baboon": [0, 0, 1],
    ...         "animal_lemur": [0, 1, 0],
    ...         "animal_zebra": [1, 0, 0],
    ...         "other_col": ["a", "b", "c"],
    ...     }
    ... )
    >>> df
       animal_baboon  animal_lemur  animal_zebra other_col
    0              0             0             1         a
    1              0             1             0         b
    2              1             0             0         c

    We can recover the original dataframe using `from_dummies`:

    >>> pd.from_dummies(df, columns=['animal'])
      other_col  animal
    0         a   zebra
    1         b   lemur
    2         c  baboon

    Suppose our dataframe has one column from each dummified column
    dropped:

    >>> df = df.drop('animal_zebra', axis=1)
    >>> df
       animal_baboon  animal_lemur other_col
    0              0             0         a
    1              0             1         b
    2              1             0         c

    We can still recover the original dataframe, by using the argument
    `fill_first`:

    >>> pd.from_dummies(df, columns=["animal"], fill_first=["zebra"])
      other_col  animal
    0         a   zebra
    1         b   lemur
    2         c  baboon
    """
    if dtype is None:
        dtype = "category"

    if columns is None:
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # output column order (and the pairing with a list-like fill_first)
        # is deterministic; a set would make it hash-order dependent.
        columns = list(
            dict.fromkeys(
                name.split(prefix_sep)[0]
                for name in data.columns
                if isinstance(name, str) and prefix_sep in name
            )
        )
    else:
        columns = list(columns)

    if fill_first is None:
        fill_values = itertools.repeat(None)
    elif isinstance(fill_first, str):
        fill_values = itertools.repeat(fill_first)
    elif isinstance(fill_first, dict):
        fill_values = [fill_first.get(col) for col in columns]
    else:
        fill_values = fill_first

    out = data.copy()
    for column, fill_value in zip(columns, fill_values):
        prefix = column + prefix_sep
        # Anchor on the start of the label; a substring test would also pick
        # up unrelated columns that merely contain the prefix.
        cols = [
            name
            for name in data.columns
            if isinstance(name, str) and name.startswith(prefix)
        ]
        if not cols:
            continue
        # Strip only the leading prefix; str.replace would also delete later
        # occurrences of the prefix inside the label itself.
        labels = [name[len(prefix) :] for name in cols]

        values = data[cols].to_numpy()
        is_one = values == 1
        # Validate per prefix: summing across all decoded columns at once
        # would reject any frame holding more than one dummified variable.
        if not (is_one | (values == 0)).all() or (is_one.sum(axis=1) > 1).any():
            raise ValueError(
                "Data cannot be decoded! Each row must contain only 0s and"
                " 1s, and each row may have at most one 1"
            )

        if fill_value:
            # Prepend a synthetic indicator for "no dummy set" locally,
            # rather than writing a helper column into the caller's frame.
            is_one = np.column_stack([~is_one.any(axis=1), is_one])
            labels = [fill_value] + labels

        out = out.drop(cols, axis=1)
        # Pass the original index so the assignment aligns row-for-row even
        # when `data` does not use the default RangeIndex.
        out[column] = Series(
            np.array(labels, dtype=object)[np.argmax(is_one, axis=1)],
            index=data.index,
            dtype=dtype,
        )
    return out
869
+
870
+
871
def _check_len(item, name, data_to_encode):
    """
    Validate prefixes and separator to avoid silently dropping cols.

    Parameters
    ----------
    item : object
        Value to check; only list-like values are validated.
    name : str
        Argument name reported in the error message.
    data_to_encode : DataFrame
        Frame whose number of columns ``item`` must match.

    Raises
    ------
    ValueError
        If ``item`` is list-like and its length differs from the number of
        columns in ``data_to_encode``.
    """
    # Scalars (e.g. a single str prefix) carry no length constraint.
    if not is_list_like(item):
        return

    n_encoded = data_to_encode.shape[1]
    n_items = len(item)
    if n_items == n_encoded:
        return

    template = (
        "Length of '{name}' ({len_item}) did not match the "
        "length of the columns being encoded ({len_enc})."
    )
    raise ValueError(template.format(name=name, len_item=n_items, len_enc=n_encoded))
884
+
885
+
754
886
def get_dummies (
755
887
data ,
756
888
prefix = None ,
@@ -871,20 +1003,8 @@ def get_dummies(
871
1003
else :
872
1004
data_to_encode = data [columns ]
873
1005
874
- # validate prefixes and separator to avoid silently dropping cols
875
- def check_len (item , name ):
876
-
877
- if is_list_like (item ):
878
- if not len (item ) == data_to_encode .shape [1 ]:
879
- len_msg = (
880
- f"Length of '{ name } ' ({ len (item )} ) did not match the "
881
- "length of the columns being encoded "
882
- f"({ data_to_encode .shape [1 ]} )."
883
- )
884
- raise ValueError (len_msg )
885
-
886
- check_len (prefix , "prefix" )
887
- check_len (prefix_sep , "prefix_sep" )
1006
+ _check_len (prefix , "prefix" , data_to_encode )
1007
+ _check_len (prefix_sep , "prefix_sep" , data_to_encode )
888
1008
889
1009
if isinstance (prefix , str ):
890
1010
prefix = itertools .cycle ([prefix ])
0 commit comments