Skip to content

Commit d99c448

Browse files
authored
REF: lreshape, wide_to_long (#55976)
* Refactor lreshape * Refactor wide_to_long validation * Refactor wide_to_long * Annotation
1 parent 484ec01 commit d99c448

File tree

1 file changed

+23
-35
lines changed

1 file changed

+23
-35
lines changed

pandas/core/reshape/melt.py

+23-35
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from pandas.core.dtypes.missing import notna
1313

1414
import pandas.core.algorithms as algos
15-
from pandas.core.arrays import Categorical
1615
from pandas.core.indexes.api import MultiIndex
1716
from pandas.core.reshape.concat import concat
1817
from pandas.core.reshape.util import tile_compat
@@ -139,7 +138,7 @@ def melt(
139138
return result
140139

141140

142-
def lreshape(data: DataFrame, groups, dropna: bool = True) -> DataFrame:
141+
def lreshape(data: DataFrame, groups: dict, dropna: bool = True) -> DataFrame:
143142
"""
144143
Reshape wide-format data to long. Generalized inverse of DataFrame.pivot.
145144
@@ -192,30 +191,20 @@ def lreshape(data: DataFrame, groups, dropna: bool = True) -> DataFrame:
192191
2 Red Sox 2008 545
193192
3 Yankees 2008 526
194193
"""
195-
if isinstance(groups, dict):
196-
keys = list(groups.keys())
197-
values = list(groups.values())
198-
else:
199-
keys, values = zip(*groups)
200-
201-
all_cols = list(set.union(*(set(x) for x in values)))
202-
id_cols = list(data.columns.difference(all_cols))
203-
204-
K = len(values[0])
205-
206-
for seq in values:
207-
if len(seq) != K:
208-
raise ValueError("All column lists must be same length")
209-
210194
mdata = {}
211195
pivot_cols = []
212-
213-
for target, names in zip(keys, values):
196+
all_cols: set[Hashable] = set()
197+
K = len(next(iter(groups.values())))
198+
for target, names in groups.items():
199+
if len(names) != K:
200+
raise ValueError("All column lists must be same length")
214201
to_concat = [data[col]._values for col in names]
215202

216203
mdata[target] = concat_compat(to_concat)
217204
pivot_cols.append(target)
205+
all_cols = all_cols.union(names)
218206

207+
id_cols = list(data.columns.difference(all_cols))
219208
for col in id_cols:
220209
mdata[col] = np.tile(data[col]._values, K)
221210

@@ -467,10 +456,10 @@ def wide_to_long(
467456
two 2.9
468457
"""
469458

470-
def get_var_names(df, stub: str, sep: str, suffix: str) -> list[str]:
459+
def get_var_names(df, stub: str, sep: str, suffix: str):
471460
regex = rf"^{re.escape(stub)}{re.escape(sep)}{suffix}$"
472461
pattern = re.compile(regex)
473-
return [col for col in df.columns if pattern.match(col)]
462+
return df.columns[df.columns.str.match(pattern)]
474463

475464
def melt_stub(df, stub: str, i, j, value_vars, sep: str):
476465
newdf = melt(
@@ -480,7 +469,6 @@ def melt_stub(df, stub: str, i, j, value_vars, sep: str):
480469
value_name=stub.rstrip(sep),
481470
var_name=j,
482471
)
483-
newdf[j] = Categorical(newdf[j])
484472
newdf[j] = newdf[j].str.replace(re.escape(stub + sep), "", regex=True)
485473

486474
# GH17627 Cast numerics suffixes to int/float
@@ -497,7 +485,7 @@ def melt_stub(df, stub: str, i, j, value_vars, sep: str):
497485
else:
498486
stubnames = list(stubnames)
499487

500-
if any(col in stubnames for col in df.columns):
488+
if df.columns.isin(stubnames).any():
501489
raise ValueError("stubname can't be identical to a column name")
502490

503491
if not is_list_like(i):
@@ -508,18 +496,18 @@ def melt_stub(df, stub: str, i, j, value_vars, sep: str):
508496
if df[i].duplicated().any():
509497
raise ValueError("the id variables need to uniquely identify each row")
510498

511-
value_vars = [get_var_names(df, stub, sep, suffix) for stub in stubnames]
512-
513-
value_vars_flattened = [e for sublist in value_vars for e in sublist]
514-
id_vars = list(set(df.columns.tolist()).difference(value_vars_flattened))
499+
_melted = []
500+
value_vars_flattened = []
501+
for stub in stubnames:
502+
value_var = get_var_names(df, stub, sep, suffix)
503+
value_vars_flattened.extend(value_var)
504+
_melted.append(melt_stub(df, stub, i, j, value_var, sep))
515505

516-
_melted = [melt_stub(df, s, i, j, v, sep) for s, v in zip(stubnames, value_vars)]
517-
melted = _melted[0].join(_melted[1:], how="outer")
506+
melted = concat(_melted, axis=1)
507+
id_vars = df.columns.difference(value_vars_flattened)
508+
new = df[id_vars]
518509

519510
if len(i) == 1:
520-
new = df[id_vars].set_index(i).join(melted)
521-
return new
522-
523-
new = df[id_vars].merge(melted.reset_index(), on=i).set_index(i + [j])
524-
525-
return new
511+
return new.set_index(i).join(melted)
512+
else:
513+
return new.merge(melted.reset_index(), on=i).set_index(i + [j])

0 commit comments

Comments
 (0)