@@ -6684,7 +6684,7 @@ def _reindex_columns(self, columns):
6684
6684
6685
6685
return self ._internal .copy (sdf = sdf , data_columns = columns , column_index = idx )
6686
6686
6687
- def melt (self , id_vars = None , value_vars = None , var_name = 'variable' ,
6687
+ def melt (self , id_vars = None , value_vars = None , var_name = None ,
6688
6688
value_name = 'value' ):
6689
6689
"""
6690
6690
Unpivot a DataFrame from wide format to long format, optionally
@@ -6705,7 +6705,8 @@ def melt(self, id_vars=None, value_vars=None, var_name='variable',
6705
6705
Column(s) to unpivot. If not specified, uses all columns that
6706
6706
are not set as `id_vars`.
6707
6707
var_name : scalar, default None
6708
- Name to use for the 'variable' column.
6708
+ Name to use for the 'variable' column. If None it uses `frame.columns.name` or
6709
+ 'variable'.
6709
6710
value_name : scalar, default 'value'
6710
6711
Name to use for the 'value' column.
6711
6712
@@ -6718,7 +6719,8 @@ def melt(self, id_vars=None, value_vars=None, var_name='variable',
6718
6719
--------
6719
6720
>>> df = ks.DataFrame({'A': {0: 'a', 1: 'b', 2: 'c'},
6720
6721
... 'B': {0: 1, 1: 3, 2: 5},
6721
- ... 'C': {0: 2, 1: 4, 2: 6}})
6722
+ ... 'C': {0: 2, 1: 4, 2: 6}},
6723
+ ... columns=['A', 'B', 'C'])
6722
6724
>>> df
6723
6725
A B C
6724
6726
0 a 1 2
@@ -6769,29 +6771,55 @@ def melt(self, id_vars=None, value_vars=None, var_name='variable',
6769
6771
"""
6770
6772
if id_vars is None :
6771
6773
id_vars = []
6772
- if not isinstance (id_vars , (list , tuple , np .ndarray )):
6773
- id_vars = list (id_vars )
6774
+ elif isinstance (id_vars , str ):
6775
+ id_vars = [(id_vars ,)]
6776
+ elif isinstance (id_vars , tuple ):
6777
+ if self ._internal .column_index_level == 1 :
6778
+ id_vars = [idv if isinstance (idv , tuple ) else (idv ,) for idv in id_vars ]
6779
+ else :
6780
+ raise ValueError ('id_vars must be a list of tuples when columns are a MultiIndex' )
6781
+ else :
6782
+ id_vars = [idv if isinstance (idv , tuple ) else (idv ,) for idv in id_vars ]
6774
6783
6775
- data_columns = self ._internal .data_columns
6784
+ column_index = self ._internal .column_index
6776
6785
6777
6786
if value_vars is None :
6778
6787
value_vars = []
6779
- if not isinstance (value_vars , (list , tuple , np .ndarray )):
6780
- value_vars = list (value_vars )
6788
+ elif isinstance (value_vars , str ):
6789
+ value_vars = [(value_vars ,)]
6790
+ elif isinstance (value_vars , tuple ):
6791
+ value_vars = [value_vars ]
6792
+ else :
6793
+ value_vars = [valv if isinstance (valv , tuple ) else (valv ,) for valv in value_vars ]
6781
6794
if len (value_vars ) == 0 :
6782
- value_vars = data_columns
6795
+ value_vars = column_index
6796
+
6797
+ column_index = [idx for idx in column_index if idx not in id_vars ]
6783
6798
6784
- data_columns = [data_column for data_column in data_columns if data_column not in id_vars ]
6785
6799
sdf = self ._sdf
6786
6800
6801
+ if var_name is None :
6802
+ if self ._internal .column_index_names is not None :
6803
+ var_name = self ._internal .column_index_names
6804
+ elif self ._internal .column_index_level == 1 :
6805
+ var_name = ['variable' ]
6806
+ else :
6807
+ var_name = ['variable_{}' .format (i )
6808
+ for i in range (self ._internal .column_index_level )]
6809
+ elif isinstance (var_name , str ):
6810
+ var_name = [var_name ]
6811
+
6787
6812
pairs = F .explode (F .array (* [
6788
6813
F .struct (* (
6789
- [F .lit (column ).alias (var_name )] +
6790
- [self ._internal .scol_for (column ).alias (value_name )])
6791
- ) for column in data_columns if column in value_vars ]))
6792
-
6793
- columns = (id_vars +
6794
- [F .col ("pairs.%s" % var_name ), F .col ("pairs.%s" % value_name )])
6814
+ [F .lit (c ).alias (name ) for c , name in zip (idx , var_name )] +
6815
+ [self ._internal .scol_for (idx ).alias (value_name )])
6816
+ ) for idx in column_index if idx in value_vars ]))
6817
+
6818
+ columns = ([self ._internal .scol_for (idx ).alias (str (idx ) if len (idx ) > 1 else idx [0 ])
6819
+ for idx in id_vars ] +
6820
+ [F .col ("pairs.%s" % name )
6821
+ for name in var_name [:self ._internal .column_index_level ]] +
6822
+ [F .col ("pairs.%s" % value_name )])
6795
6823
exploded_df = sdf .withColumn ("pairs" , pairs ).select (columns )
6796
6824
6797
6825
return DataFrame (exploded_df )
0 commit comments