Skip to content

Commit ddad4a9

Browse files
committed
Fix stubname equal suffix in wide to long function
1 parent 38086f1 commit ddad4a9

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

pandas/core/reshape/melt.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,9 @@ def get_var_names(df, stub: str, sep: str, suffix: str):
603603
return df.columns[df.columns.str.match(regex)]
604604

605605
def melt_stub(df, stub: str, i, j, value_vars, sep: str):
606+
# Ensure value_name and var_name are different when passing to melt
607+
j_original = j
608+
j = f"{j}_1" if stub == j else j
606609
newdf = melt(
607610
df,
608611
id_vars=i,
@@ -619,7 +622,10 @@ def melt_stub(df, stub: str, i, j, value_vars, sep: str):
619622
# TODO: anything else to catch?
620623
pass
621624

622-
return newdf.set_index(i + [j])
625+
newdf = newdf.set_index(i + [j])
626+
if j != j_original:
627+
newdf.index = newdf.index.set_names(j_original, level=-1)
628+
return newdf
623629

624630
if not is_list_like(stubnames):
625631
stubnames = [stubnames]

pandas/tests/reshape/test_melt.py

+27
Original file line numberDiff line numberDiff line change
@@ -1218,6 +1218,33 @@ def test_missing_stubname(self, dtype):
12181218
expected.index = expected.index.set_levels(new_level, level=0)
12191219
tm.assert_frame_equal(result, expected)
12201220

1221+
@pytest.mark.parametrize("stubnames", ["year", ["year"]])
1222+
def test_stubname_equal_suffix(self, stubnames):
1223+
# https://github.com/pandas-dev/pandas/issues/46939
1224+
df = DataFrame(
1225+
{
1226+
"year1": {0: 4.5, 1: 1.7},
1227+
"year2": {0: 2.5, 1: 1.2},
1228+
"X": dict(zip(range(2), range(2, 4))),
1229+
}
1230+
)
1231+
df["id"] = df.index
1232+
result = wide_to_long(
1233+
df,
1234+
stubnames=stubnames,
1235+
i="id",
1236+
j="year",
1237+
)
1238+
expected = DataFrame(
1239+
[[2, 4.5], [3, 1.7], [2, 2.5], [3, 1.2]],
1240+
columns=["X", "year"],
1241+
index=pd.MultiIndex.from_arrays(
1242+
[[0, 1, 0, 1], [1, 1, 2, 2]],
1243+
names=["id", "year"],
1244+
),
1245+
)
1246+
tm.assert_frame_equal(result, expected)
1247+
12211248

12221249
def test_wide_to_long_pyarrow_string_columns():
12231250
# GH 57066

0 commit comments

Comments
 (0)