Skip to content

Commit eace01c

Browse files
committed
Fix melt for multi-index columns support. (#920)
1 parent 0a5e04c commit eace01c

File tree

3 files changed

+106
-17
lines changed

3 files changed

+106
-17
lines changed

databricks/koalas/frame.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6684,7 +6684,7 @@ def _reindex_columns(self, columns):
66846684

66856685
return self._internal.copy(sdf=sdf, data_columns=columns, column_index=idx)
66866686

6687-
def melt(self, id_vars=None, value_vars=None, var_name='variable',
6687+
def melt(self, id_vars=None, value_vars=None, var_name=None,
66886688
value_name='value'):
66896689
"""
66906690
Unpivot a DataFrame from wide format to long format, optionally
@@ -6705,7 +6705,8 @@ def melt(self, id_vars=None, value_vars=None, var_name='variable',
67056705
Column(s) to unpivot. If not specified, uses all columns that
67066706
are not set as `id_vars`.
67076707
var_name : scalar, default 'variable'
6708-
Name to use for the 'variable' column.
6708+
Name to use for the 'variable' column. If None it uses `frame.columns.name` or
6709+
‘variable’.
67096710
value_name : scalar, default 'value'
67106711
Name to use for the 'value' column.
67116712
@@ -6718,7 +6719,8 @@ def melt(self, id_vars=None, value_vars=None, var_name='variable',
67186719
--------
67196720
>>> df = ks.DataFrame({'A': {0: 'a', 1: 'b', 2: 'c'},
67206721
... 'B': {0: 1, 1: 3, 2: 5},
6721-
... 'C': {0: 2, 1: 4, 2: 6}})
6722+
... 'C': {0: 2, 1: 4, 2: 6}},
6723+
... columns=['A', 'B', 'C'])
67226724
>>> df
67236725
A B C
67246726
0 a 1 2
@@ -6769,29 +6771,55 @@ def melt(self, id_vars=None, value_vars=None, var_name='variable',
67696771
"""
67706772
if id_vars is None:
67716773
id_vars = []
6772-
if not isinstance(id_vars, (list, tuple, np.ndarray)):
6773-
id_vars = list(id_vars)
6774+
elif isinstance(id_vars, str):
6775+
id_vars = [(id_vars,)]
6776+
elif isinstance(id_vars, tuple):
6777+
if self._internal.column_index_level == 1:
6778+
id_vars = [idv if isinstance(idv, tuple) else (idv,) for idv in id_vars]
6779+
else:
6780+
raise ValueError('id_vars must be a list of tuples when columns are a MultiIndex')
6781+
else:
6782+
id_vars = [idv if isinstance(idv, tuple) else (idv,) for idv in id_vars]
67746783

6775-
data_columns = self._internal.data_columns
6784+
column_index = self._internal.column_index
67766785

67776786
if value_vars is None:
67786787
value_vars = []
6779-
if not isinstance(value_vars, (list, tuple, np.ndarray)):
6780-
value_vars = list(value_vars)
6788+
elif isinstance(value_vars, str):
6789+
value_vars = [(value_vars,)]
6790+
elif isinstance(value_vars, tuple):
6791+
value_vars = [value_vars]
6792+
else:
6793+
value_vars = [valv if isinstance(valv, tuple) else (valv,) for valv in value_vars]
67816794
if len(value_vars) == 0:
6782-
value_vars = data_columns
6795+
value_vars = column_index
6796+
6797+
column_index = [idx for idx in column_index if idx not in id_vars]
67836798

6784-
data_columns = [data_column for data_column in data_columns if data_column not in id_vars]
67856799
sdf = self._sdf
67866800

6801+
if var_name is None:
6802+
if self._internal.column_index_names is not None:
6803+
var_name = self._internal.column_index_names
6804+
elif self._internal.column_index_level == 1:
6805+
var_name = ['variable']
6806+
else:
6807+
var_name = ['variable_{}'.format(i)
6808+
for i in range(self._internal.column_index_level)]
6809+
elif isinstance(var_name, str):
6810+
var_name = [var_name]
6811+
67876812
pairs = F.explode(F.array(*[
67886813
F.struct(*(
6789-
[F.lit(column).alias(var_name)] +
6790-
[self._internal.scol_for(column).alias(value_name)])
6791-
) for column in data_columns if column in value_vars]))
6792-
6793-
columns = (id_vars +
6794-
[F.col("pairs.%s" % var_name), F.col("pairs.%s" % value_name)])
6814+
[F.lit(c).alias(name) for c, name in zip(idx, var_name)] +
6815+
[self._internal.scol_for(idx).alias(value_name)])
6816+
) for idx in column_index if idx in value_vars]))
6817+
6818+
columns = ([self._internal.scol_for(idx).alias(str(idx) if len(idx) > 1 else idx[0])
6819+
for idx in id_vars] +
6820+
[F.col("pairs.%s" % name)
6821+
for name in var_name[:self._internal.column_index_level]] +
6822+
[F.col("pairs.%s" % value_name)])
67956823
exploded_df = sdf.withColumn("pairs", pairs).select(columns)
67966824

67976825
return DataFrame(exploded_df)

databricks/koalas/namespace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1544,7 +1544,7 @@ def concat(objs, axis=0, join='outer', ignore_index=False):
15441544
return result_kdf
15451545

15461546

1547-
def melt(frame, id_vars=None, value_vars=None, var_name='variable',
1547+
def melt(frame, id_vars=None, value_vars=None, var_name=None,
15481548
value_name='value'):
15491549
return DataFrame.melt(frame, id_vars, value_vars, var_name, value_name)
15501550

databricks/koalas/tests/test_dataframe.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,6 +1755,67 @@ def test_reindex(self):
17551755
self.assertRaises(TypeError, lambda: kdf.reindex(columns=['X']))
17561756
self.assertRaises(ValueError, lambda: kdf.reindex(columns=[('X',)]))
17571757

1758+
def test_melt(self):
1759+
pdf = pd.DataFrame({'A': [1, 3, 5],
1760+
'B': [2, 4, 6],
1761+
'C': [7, 8, 9]})
1762+
kdf = ks.from_pandas(pdf)
1763+
1764+
self.assert_eq(kdf.melt().sort_values(['variable', 'value'])
1765+
.reset_index(drop=True),
1766+
pdf.melt().sort_values(['variable', 'value']))
1767+
self.assert_eq(kdf.melt(id_vars='A').sort_values(['variable', 'value'])
1768+
.reset_index(drop=True),
1769+
pdf.melt(id_vars='A').sort_values(['variable', 'value']))
1770+
self.assert_eq(kdf.melt(id_vars=['A', 'B']).sort_values(['variable', 'value'])
1771+
.reset_index(drop=True),
1772+
pdf.melt(id_vars=['A', 'B']).sort_values(['variable', 'value']))
1773+
self.assert_eq(kdf.melt(id_vars=('A', 'B')).sort_values(['variable', 'value'])
1774+
.reset_index(drop=True),
1775+
pdf.melt(id_vars=('A', 'B')).sort_values(['variable', 'value']))
1776+
self.assert_eq(kdf.melt(id_vars=['A'], value_vars=['C']).sort_values(['variable', 'value'])
1777+
.reset_index(drop=True),
1778+
pdf.melt(id_vars=['A'], value_vars=['C']).sort_values(['variable', 'value']))
1779+
self.assert_eq(kdf.melt(id_vars=['A'], value_vars=['B'],
1780+
var_name='myVarname', value_name='myValname')
1781+
.sort_values(['myVarname', 'myValname']).reset_index(drop=True),
1782+
pdf.melt(id_vars=['A'], value_vars=['B'],
1783+
var_name='myVarname', value_name='myValname')
1784+
.sort_values(['myVarname', 'myValname']))
1785+
1786+
# multi-index columns
1787+
columns = pd.MultiIndex.from_tuples([('X', 'A'), ('X', 'B'), ('Y', 'C')])
1788+
pdf.columns = columns
1789+
kdf.columns = columns
1790+
1791+
self.assert_eq(kdf.melt().sort_values(['variable_0', 'variable_1', 'value'])
1792+
.reset_index(drop=True),
1793+
pdf.melt().sort_values(['variable_0', 'variable_1', 'value']))
1794+
self.assert_eq(kdf.melt(id_vars=[('X', 'A')])
1795+
.sort_values(['variable_0', 'variable_1', 'value']).reset_index(drop=True),
1796+
pdf.melt(id_vars=[('X', 'A')])
1797+
.sort_values(['variable_0', 'variable_1', 'value']), almost=True)
1798+
self.assert_eq(kdf.melt(id_vars=[('X', 'A')], value_vars=[('Y', 'C')])
1799+
.sort_values(['variable_0', 'variable_1', 'value']).reset_index(drop=True),
1800+
pdf.melt(id_vars=[('X', 'A')], value_vars=[('Y', 'C')])
1801+
.sort_values(['variable_0', 'variable_1', 'value']), almost=True)
1802+
self.assert_eq(kdf.melt(id_vars=[('X', 'A')], value_vars=[('X', 'B')],
1803+
var_name=['myV1', 'myV2'], value_name='myValname')
1804+
.sort_values(['myV1', 'myV2', 'myValname']).reset_index(drop=True),
1805+
pdf.melt(id_vars=[('X', 'A')], value_vars=[('X', 'B')],
1806+
var_name=['myV1', 'myV2'], value_name='myValname')
1807+
.sort_values(['myV1', 'myV2', 'myValname']), almost=True)
1808+
1809+
columns.names = ['v0', 'v1']
1810+
pdf.columns = columns
1811+
kdf.columns = columns
1812+
1813+
self.assert_eq(kdf.melt().sort_values(['v0', 'v1', 'value'])
1814+
.reset_index(drop=True),
1815+
pdf.melt().sort_values(['v0', 'v1', 'value']))
1816+
1817+
self.assertRaises(ValueError, lambda: kdf.melt(id_vars=('X', 'A')))
1818+
17581819
def test_all(self):
17591820
pdf = pd.DataFrame({
17601821
'col1': [False, False, False],

0 commit comments

Comments
 (0)