Skip to content

REF: melt #55948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,8 @@ Reshaping
- Bug in :func:`concat` ignoring ``sort`` parameter when passed :class:`DatetimeIndex` indexes (:issue:`54769`)
- Bug in :func:`merge_asof` raising ``TypeError`` when ``by`` dtype is not ``object``, ``int64``, or ``uint64`` (:issue:`22794`)
- Bug in :func:`merge` returning columns in incorrect order when left and/or right is empty (:issue:`51929`)
- Bug in :meth:`pandas.DataFrame.melt` where an exception was raised if ``var_name`` was not a string (:issue:`55948`)
- Bug in :meth:`pandas.DataFrame.melt` where it would not preserve the datetime (:issue:`55254`)
-

Sparse
^^^^^^
Expand Down
104 changes: 45 additions & 59 deletions pandas/core/reshape/melt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@

import pandas.core.algorithms as algos
from pandas.core.arrays import Categorical
import pandas.core.common as com
from pandas.core.indexes.api import (
Index,
MultiIndex,
)
from pandas.core.indexes.api import MultiIndex
from pandas.core.reshape.concat import concat
from pandas.core.reshape.util import tile_compat
from pandas.core.shared_docs import _shared_docs
Expand All @@ -31,6 +27,20 @@
from pandas import DataFrame


def ensure_list_vars(arg_vars, variable: str, columns) -> list:
    """
    Normalize an ``id_vars`` / ``value_vars`` style argument to a list.

    Parameters
    ----------
    arg_vars : scalar, list-like, or None
        The labels as supplied by the caller.
    variable : str
        Name of the argument being normalized (e.g. ``"id_vars"``). It is
        only used to build the error message.
    columns : Index
        Columns of the frame the labels refer to. Only its type is
        inspected here; label membership is not checked.

    Returns
    -------
    list
        ``[]`` when ``arg_vars`` is None, ``[arg_vars]`` when it is not
        list-like, otherwise ``list(arg_vars)``.

    Raises
    ------
    ValueError
        If ``columns`` is a :class:`MultiIndex` and ``arg_vars`` is
        list-like but not an actual ``list``.
    """
    if arg_vars is None:
        return []
    if not is_list_like(arg_vars):
        # A lone label is wrapped so callers can always concatenate lists.
        return [arg_vars]
    if isinstance(columns, MultiIndex) and not isinstance(arg_vars, list):
        # NOTE(review): presumably a bare tuple is rejected because it would
        # be ambiguous with a single MultiIndex column key -- confirm.
        raise ValueError(
            f"{variable} must be a list of tuples when columns are a MultiIndex"
        )
    return list(arg_vars)


@Appender(_shared_docs["melt"] % {"caller": "pd.melt(df, ", "other": "DataFrame.melt"})
def melt(
frame: DataFrame,
Expand All @@ -41,61 +51,35 @@ def melt(
col_level=None,
ignore_index: bool = True,
) -> DataFrame:
# If multiindex, gather names of columns on all level for checking presence
# of `id_vars` and `value_vars`
if isinstance(frame.columns, MultiIndex):
cols = [x for c in frame.columns for x in c]
else:
cols = list(frame.columns)

if value_name in frame.columns:
raise ValueError(
f"value_name ({value_name}) cannot match an element in "
"the DataFrame columns."
)
id_vars = ensure_list_vars(id_vars, "id_vars", frame.columns)
value_vars_was_not_none = value_vars is not None
value_vars = ensure_list_vars(value_vars, "value_vars", frame.columns)

if id_vars is not None:
if not is_list_like(id_vars):
id_vars = [id_vars]
elif isinstance(frame.columns, MultiIndex) and not isinstance(id_vars, list):
raise ValueError(
"id_vars must be a list of tuples when columns are a MultiIndex"
)
else:
# Check that `id_vars` are in frame
id_vars = list(id_vars)
missing = Index(com.flatten(id_vars)).difference(cols)
if not missing.empty:
raise KeyError(
"The following 'id_vars' are not present "
f"in the DataFrame: {list(missing)}"
)
else:
id_vars = []

if value_vars is not None:
if not is_list_like(value_vars):
value_vars = [value_vars]
elif isinstance(frame.columns, MultiIndex) and not isinstance(value_vars, list):
raise ValueError(
"value_vars must be a list of tuples when columns are a MultiIndex"
)
else:
value_vars = list(value_vars)
# Check that `value_vars` are in frame
missing = Index(com.flatten(value_vars)).difference(cols)
if not missing.empty:
raise KeyError(
"The following 'value_vars' are not present in "
f"the DataFrame: {list(missing)}"
)
if id_vars or value_vars:
if col_level is not None:
idx = frame.columns.get_level_values(col_level).get_indexer(
id_vars + value_vars
level = frame.columns.get_level_values(col_level)
else:
level = frame.columns
labels = id_vars + value_vars
idx = level.get_indexer_for(labels)
missing = idx == -1
if missing.any():
missing_labels = [
lab for lab, not_found in zip(labels, missing) if not_found
]
raise KeyError(
"The following id_vars or value_vars are not present in "
f"the DataFrame: {missing_labels}"
)
if value_vars_was_not_none:
frame = frame.iloc[:, algos.unique(idx)]
else:
idx = algos.unique(frame.columns.get_indexer_for(id_vars + value_vars))
frame = frame.iloc[:, idx]
frame = frame.copy()
else:
frame = frame.copy()

Expand All @@ -113,24 +97,26 @@ def melt(
var_name = [
frame.columns.name if frame.columns.name is not None else "variable"
]
if isinstance(var_name, str):
elif is_list_like(var_name):
raise ValueError(f"{var_name=} must be a scalar.")
else:
var_name = [var_name]

N, K = frame.shape
K -= len(id_vars)
num_rows, K = frame.shape
num_cols_adjusted = K - len(id_vars)

mdata: dict[Hashable, AnyArrayLike] = {}
for col in id_vars:
id_data = frame.pop(col)
if not isinstance(id_data.dtype, np.dtype):
# i.e. ExtensionDtype
if K > 0:
mdata[col] = concat([id_data] * K, ignore_index=True)
if num_cols_adjusted > 0:
mdata[col] = concat([id_data] * num_cols_adjusted, ignore_index=True)
else:
# We can't concat empty list. (GH 46044)
mdata[col] = type(id_data)([], name=id_data.name, dtype=id_data.dtype)
else:
mdata[col] = np.tile(id_data._values, K)
mdata[col] = np.tile(id_data._values, num_cols_adjusted)

mcolumns = id_vars + var_name + [value_name]

Expand All @@ -143,12 +129,12 @@ def melt(
else:
mdata[value_name] = frame._values.ravel("F")
for i, col in enumerate(var_name):
mdata[col] = frame.columns._get_level_values(i).repeat(N)
mdata[col] = frame.columns._get_level_values(i).repeat(num_rows)

result = frame._constructor(mdata, columns=mcolumns)

if not ignore_index:
result.index = tile_compat(frame.index, K)
result.index = tile_compat(frame.index, num_cols_adjusted)

return result

Expand Down
8 changes: 4 additions & 4 deletions pandas/core/shared_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,17 +208,17 @@

Parameters
----------
id_vars : tuple, list, or ndarray, optional
id_vars : scalar, tuple, list, or ndarray, optional
Column(s) to use as identifier variables.
value_vars : tuple, list, or ndarray, optional
value_vars : scalar, tuple, list, or ndarray, optional
Column(s) to unpivot. If not specified, uses all columns that
are not set as `id_vars`.
var_name : scalar
var_name : scalar, default None
Name to use for the 'variable' column. If None it uses
``frame.columns.name`` or 'variable'.
value_name : scalar, default 'value'
Name to use for the 'value' column, can't be an existing column label.
col_level : int or str, optional
col_level : scalar, optional
If columns are a MultiIndex then use this level to melt.
ignore_index : bool, default True
If True, original index is ignored. If False, the original index is retained.
Expand Down
50 changes: 40 additions & 10 deletions pandas/tests/reshape/test_melt.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,32 +327,28 @@ def test_melt_missing_columns_raises(self):
)

# Try to melt with missing `value_vars` column name
msg = "The following '{Var}' are not present in the DataFrame: {Col}"
with pytest.raises(
KeyError, match=msg.format(Var="value_vars", Col="\\['C'\\]")
):
msg = "The following id_vars or value_vars are not present in the DataFrame:"
with pytest.raises(KeyError, match=msg):
df.melt(["a", "b"], ["C", "d"])

# Try to melt with missing `id_vars` column name
with pytest.raises(KeyError, match=msg.format(Var="id_vars", Col="\\['A'\\]")):
with pytest.raises(KeyError, match=msg):
df.melt(["A", "b"], ["c", "d"])

# Multiple missing
with pytest.raises(
KeyError,
match=msg.format(Var="id_vars", Col="\\['not_here', 'or_there'\\]"),
match=msg,
):
df.melt(["a", "b", "not_here", "or_there"], ["c", "d"])

# Multiindex melt fails if column is missing from multilevel melt
multi = df.copy()
multi.columns = [list("ABCD"), list("abcd")]
with pytest.raises(KeyError, match=msg.format(Var="id_vars", Col="\\['E'\\]")):
with pytest.raises(KeyError, match=msg):
multi.melt([("E", "a")], [("B", "b")])
# Multiindex fails if column is missing from single level melt
with pytest.raises(
KeyError, match=msg.format(Var="value_vars", Col="\\['F'\\]")
):
with pytest.raises(KeyError, match=msg):
multi.melt(["A"], ["F"], col_level=0)

def test_melt_mixed_int_str_id_vars(self):
Expand Down Expand Up @@ -500,6 +496,40 @@ def test_melt_preserves_datetime(self):
)
tm.assert_frame_equal(result, expected)

def test_melt_allows_non_scalar_id_vars(self):
# Melt with ``id_vars`` given as a bare label ("a") rather than a list.
# NOTE(review): the test name says "non_scalar", but the value passed for
# ``id_vars`` here is a scalar string -- consider renaming.
df = DataFrame(
data={"a": [1, 2, 3], "b": [4, 5, 6]},
index=["11", "22", "33"],
)
# Integer ``var_name`` / ``value_name`` become the output column labels.
result = df.melt(
id_vars="a",
var_name=0,
value_name=1,
)
# ``ignore_index`` is left at its default, so the expected frame is built
# without the string index given to ``df``.
expected = DataFrame({"a": [1, 2, 3], 0: ["b"] * 3, 1: [4, 5, 6]})
tm.assert_frame_equal(result, expected)

def test_melt_allows_non_string_var_name(self):
# Melt with an integer ``var_name`` (0) and integer ``value_name`` (1);
# ``id_vars`` is given as a one-element list.
df = DataFrame(
data={"a": [1, 2, 3], "b": [4, 5, 6]},
index=["11", "22", "33"],
)
result = df.melt(
id_vars=["a"],
var_name=0,
value_name=1,
)
# The single melted column "b" supplies the variable label for every row;
# the expected frame is built without the string index given to ``df``.
expected = DataFrame({"a": [1, 2, 3], 0: ["b"] * 3, 1: [4, 5, 6]})
tm.assert_frame_equal(result, expected)

def test_melt_non_scalar_var_name_raises(self):
# A list-like ``var_name`` must be rejected with a ValueError.
df = DataFrame(
data={"a": [1, 2, 3], "b": [4, 5, 6]},
index=["11", "22", "33"],
)
# Only the tail of the message is matched; the leading ``.*`` absorbs
# whatever precedes " must be a scalar." in the raised text.
with pytest.raises(ValueError, match=r".* must be a scalar."):
df.melt(id_vars=["a"], var_name=[1, 2])


class TestLreshape:
def test_pairs(self):
Expand Down