Skip to content

Commit 4da5926

Browse files
authored
REF: melt (#55948)
* REF: melt
* Fix case, generalize test
* Add whatsnew
* Update issue number
* Revert typing
1 parent 8a814a0 commit 4da5926

File tree

4 files changed

+90
-74
lines changed

4 files changed

+90
-74
lines changed

doc/source/whatsnew/v2.2.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -442,8 +442,8 @@ Reshaping
442442
- Bug in :func:`concat` ignoring ``sort`` parameter when passed :class:`DatetimeIndex` indexes (:issue:`54769`)
443443
- Bug in :func:`merge_asof` raising ``TypeError`` when ``by`` dtype is not ``object``, ``int64``, or ``uint64`` (:issue:`22794`)
444444
- Bug in :func:`merge` returning columns in incorrect order when left and/or right is empty (:issue:`51929`)
445+
- Bug in :meth:`pandas.DataFrame.melt` where an exception was raised if ``var_name`` was not a string (:issue:`55948`)
445446
- Bug in :meth:`pandas.DataFrame.melt` where it would not preserve the datetime (:issue:`55254`)
446-
-
447447

448448
Sparse
449449
^^^^^^

pandas/core/reshape/melt.py

+45-59
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,7 @@
1313

1414
import pandas.core.algorithms as algos
1515
from pandas.core.arrays import Categorical
16-
import pandas.core.common as com
17-
from pandas.core.indexes.api import (
18-
Index,
19-
MultiIndex,
20-
)
16+
from pandas.core.indexes.api import MultiIndex
2117
from pandas.core.reshape.concat import concat
2218
from pandas.core.reshape.util import tile_compat
2319
from pandas.core.shared_docs import _shared_docs
@@ -31,6 +27,20 @@
3127
from pandas import DataFrame
3228

3329

30+
def ensure_list_vars(arg_vars, variable: str, columns) -> list:
    """
    Normalize an ``id_vars`` / ``value_vars`` style argument to a list.

    Parameters
    ----------
    arg_vars : None, scalar, or list-like
        The value to normalize.
    variable : str
        Name of the argument being normalized; used only to build the
        error message.
    columns : object
        The columns the labels refer to. Only inspected to see whether
        it is a ``MultiIndex``.

    Returns
    -------
    list
        ``[]`` when ``arg_vars`` is None, ``[arg_vars]`` when it is not
        list-like, otherwise ``list(arg_vars)``.

    Raises
    ------
    ValueError
        If ``columns`` is a ``MultiIndex`` and ``arg_vars`` is list-like
        but not a ``list``.
    """
    if arg_vars is None:
        return []

    # A non-list-like value is wrapped as a single label, before the
    # MultiIndex check below is ever reached.
    if not is_list_like(arg_vars):
        return [arg_vars]

    # With MultiIndex columns only a real ``list`` is accepted; any other
    # list-like (e.g. a tuple) is rejected.
    if isinstance(columns, MultiIndex) and not isinstance(arg_vars, list):
        raise ValueError(
            f"{variable} must be a list of tuples when columns are a MultiIndex"
        )

    return list(arg_vars)
42+
43+
3444
@Appender(_shared_docs["melt"] % {"caller": "pd.melt(df, ", "other": "DataFrame.melt"})
3545
def melt(
3646
frame: DataFrame,
@@ -41,61 +51,35 @@ def melt(
4151
col_level=None,
4252
ignore_index: bool = True,
4353
) -> DataFrame:
44-
# If multiindex, gather names of columns on all level for checking presence
45-
# of `id_vars` and `value_vars`
46-
if isinstance(frame.columns, MultiIndex):
47-
cols = [x for c in frame.columns for x in c]
48-
else:
49-
cols = list(frame.columns)
50-
5154
if value_name in frame.columns:
5255
raise ValueError(
5356
f"value_name ({value_name}) cannot match an element in "
5457
"the DataFrame columns."
5558
)
59+
id_vars = ensure_list_vars(id_vars, "id_vars", frame.columns)
60+
value_vars_was_not_none = value_vars is not None
61+
value_vars = ensure_list_vars(value_vars, "value_vars", frame.columns)
5662

57-
if id_vars is not None:
58-
if not is_list_like(id_vars):
59-
id_vars = [id_vars]
60-
elif isinstance(frame.columns, MultiIndex) and not isinstance(id_vars, list):
61-
raise ValueError(
62-
"id_vars must be a list of tuples when columns are a MultiIndex"
63-
)
64-
else:
65-
# Check that `id_vars` are in frame
66-
id_vars = list(id_vars)
67-
missing = Index(com.flatten(id_vars)).difference(cols)
68-
if not missing.empty:
69-
raise KeyError(
70-
"The following 'id_vars' are not present "
71-
f"in the DataFrame: {list(missing)}"
72-
)
73-
else:
74-
id_vars = []
75-
76-
if value_vars is not None:
77-
if not is_list_like(value_vars):
78-
value_vars = [value_vars]
79-
elif isinstance(frame.columns, MultiIndex) and not isinstance(value_vars, list):
80-
raise ValueError(
81-
"value_vars must be a list of tuples when columns are a MultiIndex"
82-
)
83-
else:
84-
value_vars = list(value_vars)
85-
# Check that `value_vars` are in frame
86-
missing = Index(com.flatten(value_vars)).difference(cols)
87-
if not missing.empty:
88-
raise KeyError(
89-
"The following 'value_vars' are not present in "
90-
f"the DataFrame: {list(missing)}"
91-
)
63+
if id_vars or value_vars:
9264
if col_level is not None:
93-
idx = frame.columns.get_level_values(col_level).get_indexer(
94-
id_vars + value_vars
65+
level = frame.columns.get_level_values(col_level)
66+
else:
67+
level = frame.columns
68+
labels = id_vars + value_vars
69+
idx = level.get_indexer_for(labels)
70+
missing = idx == -1
71+
if missing.any():
72+
missing_labels = [
73+
lab for lab, not_found in zip(labels, missing) if not_found
74+
]
75+
raise KeyError(
76+
"The following id_vars or value_vars are not present in "
77+
f"the DataFrame: {missing_labels}"
9578
)
79+
if value_vars_was_not_none:
80+
frame = frame.iloc[:, algos.unique(idx)]
9681
else:
97-
idx = algos.unique(frame.columns.get_indexer_for(id_vars + value_vars))
98-
frame = frame.iloc[:, idx]
82+
frame = frame.copy()
9983
else:
10084
frame = frame.copy()
10185

@@ -113,24 +97,26 @@ def melt(
11397
var_name = [
11498
frame.columns.name if frame.columns.name is not None else "variable"
11599
]
116-
if isinstance(var_name, str):
100+
elif is_list_like(var_name):
101+
raise ValueError(f"{var_name=} must be a scalar.")
102+
else:
117103
var_name = [var_name]
118104

119-
N, K = frame.shape
120-
K -= len(id_vars)
105+
num_rows, K = frame.shape
106+
num_cols_adjusted = K - len(id_vars)
121107

122108
mdata: dict[Hashable, AnyArrayLike] = {}
123109
for col in id_vars:
124110
id_data = frame.pop(col)
125111
if not isinstance(id_data.dtype, np.dtype):
126112
# i.e. ExtensionDtype
127-
if K > 0:
128-
mdata[col] = concat([id_data] * K, ignore_index=True)
113+
if num_cols_adjusted > 0:
114+
mdata[col] = concat([id_data] * num_cols_adjusted, ignore_index=True)
129115
else:
130116
# We can't concat empty list. (GH 46044)
131117
mdata[col] = type(id_data)([], name=id_data.name, dtype=id_data.dtype)
132118
else:
133-
mdata[col] = np.tile(id_data._values, K)
119+
mdata[col] = np.tile(id_data._values, num_cols_adjusted)
134120

135121
mcolumns = id_vars + var_name + [value_name]
136122

@@ -143,12 +129,12 @@ def melt(
143129
else:
144130
mdata[value_name] = frame._values.ravel("F")
145131
for i, col in enumerate(var_name):
146-
mdata[col] = frame.columns._get_level_values(i).repeat(N)
132+
mdata[col] = frame.columns._get_level_values(i).repeat(num_rows)
147133

148134
result = frame._constructor(mdata, columns=mcolumns)
149135

150136
if not ignore_index:
151-
result.index = tile_compat(frame.index, K)
137+
result.index = tile_compat(frame.index, num_cols_adjusted)
152138

153139
return result
154140

pandas/core/shared_docs.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -208,17 +208,17 @@
208208
209209
Parameters
210210
----------
211-
id_vars : tuple, list, or ndarray, optional
211+
id_vars : scalar, tuple, list, or ndarray, optional
212212
Column(s) to use as identifier variables.
213-
value_vars : tuple, list, or ndarray, optional
213+
value_vars : scalar, tuple, list, or ndarray, optional
214214
Column(s) to unpivot. If not specified, uses all columns that
215215
are not set as `id_vars`.
216-
var_name : scalar
216+
var_name : scalar, default None
217217
Name to use for the 'variable' column. If None it uses
218218
``frame.columns.name`` or 'variable'.
219219
value_name : scalar, default 'value'
220220
Name to use for the 'value' column, can't be an existing column label.
221-
col_level : int or str, optional
221+
col_level : scalar, optional
222222
If columns are a MultiIndex then use this level to melt.
223223
ignore_index : bool, default True
224224
If True, original index is ignored. If False, the original index is retained.

pandas/tests/reshape/test_melt.py

+40-10
Original file line numberDiff line numberDiff line change
@@ -327,32 +327,28 @@ def test_melt_missing_columns_raises(self):
327327
)
328328

329329
# Try to melt with missing `value_vars` column name
330-
msg = "The following '{Var}' are not present in the DataFrame: {Col}"
331-
with pytest.raises(
332-
KeyError, match=msg.format(Var="value_vars", Col="\\['C'\\]")
333-
):
330+
msg = "The following id_vars or value_vars are not present in the DataFrame:"
331+
with pytest.raises(KeyError, match=msg):
334332
df.melt(["a", "b"], ["C", "d"])
335333

336334
# Try to melt with missing `id_vars` column name
337-
with pytest.raises(KeyError, match=msg.format(Var="id_vars", Col="\\['A'\\]")):
335+
with pytest.raises(KeyError, match=msg):
338336
df.melt(["A", "b"], ["c", "d"])
339337

340338
# Multiple missing
341339
with pytest.raises(
342340
KeyError,
343-
match=msg.format(Var="id_vars", Col="\\['not_here', 'or_there'\\]"),
341+
match=msg,
344342
):
345343
df.melt(["a", "b", "not_here", "or_there"], ["c", "d"])
346344

347345
# Multiindex melt fails if column is missing from multilevel melt
348346
multi = df.copy()
349347
multi.columns = [list("ABCD"), list("abcd")]
350-
with pytest.raises(KeyError, match=msg.format(Var="id_vars", Col="\\['E'\\]")):
348+
with pytest.raises(KeyError, match=msg):
351349
multi.melt([("E", "a")], [("B", "b")])
352350
# Multiindex fails if column is missing from single level melt
353-
with pytest.raises(
354-
KeyError, match=msg.format(Var="value_vars", Col="\\['F'\\]")
355-
):
351+
with pytest.raises(KeyError, match=msg):
356352
multi.melt(["A"], ["F"], col_level=0)
357353

358354
def test_melt_mixed_int_str_id_vars(self):
@@ -500,6 +496,40 @@ def test_melt_preserves_datetime(self):
500496
)
501497
tm.assert_frame_equal(result, expected)
502498

499+
def test_melt_allows_non_scalar_id_vars(self):
500+
df = DataFrame(
501+
data={"a": [1, 2, 3], "b": [4, 5, 6]},
502+
index=["11", "22", "33"],
503+
)
504+
result = df.melt(
505+
id_vars="a",
506+
var_name=0,
507+
value_name=1,
508+
)
509+
expected = DataFrame({"a": [1, 2, 3], 0: ["b"] * 3, 1: [4, 5, 6]})
510+
tm.assert_frame_equal(result, expected)
511+
512+
def test_melt_allows_non_string_var_name(self):
513+
df = DataFrame(
514+
data={"a": [1, 2, 3], "b": [4, 5, 6]},
515+
index=["11", "22", "33"],
516+
)
517+
result = df.melt(
518+
id_vars=["a"],
519+
var_name=0,
520+
value_name=1,
521+
)
522+
expected = DataFrame({"a": [1, 2, 3], 0: ["b"] * 3, 1: [4, 5, 6]})
523+
tm.assert_frame_equal(result, expected)
524+
525+
def test_melt_non_scalar_var_name_raises(self):
    """
    Melt must reject a list-like ``var_name``.

    Passing ``var_name=[1, 2]`` is expected to raise ``ValueError`` with
    a message matching ``".* must be a scalar."``.
    """
    frame = DataFrame(
        {"a": [1, 2, 3], "b": [4, 5, 6]},
        index=["11", "22", "33"],
    )
    expected_msg = r".* must be a scalar."

    with pytest.raises(ValueError, match=expected_msg):
        frame.melt(id_vars=["a"], var_name=[1, 2])
532+
503533

504534
class TestLreshape:
505535
def test_pairs(self):

0 commit comments

Comments
 (0)