Skip to content

Commit 563dd81

Browse files
authored
ENH: Add sort keyword to stack (#53282)
* ENH: Add sort keyword to stack * Remove commented-out code * Use np.sort
1 parent 6afe4a0 commit 563dd81

File tree

4 files changed

+61
-11
lines changed

4 files changed

+61
-11
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ Other enhancements
9797
- Let :meth:`DataFrame.to_feather` accept a non-default :class:`Index` and non-string column names (:issue:`51787`)
9898
- Performance improvement in :func:`read_csv` (:issue:`52632`) with ``engine="c"``
9999
- :meth:`Categorical.from_codes` has gotten a ``validate`` parameter (:issue:`50975`)
100+
- :meth:`DataFrame.stack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
100101
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
101102
- Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
102103
- Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)

pandas/core/frame.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -8991,7 +8991,7 @@ def pivot_table(
89918991
sort=sort,
89928992
)
89938993

8994-
def stack(self, level: Level = -1, dropna: bool = True):
8994+
def stack(self, level: Level = -1, dropna: bool = True, sort: bool = True):
89958995
"""
89968996
Stack the prescribed level(s) from columns to index.
89978997
@@ -9017,6 +9017,8 @@ def stack(self, level: Level = -1, dropna: bool = True):
90179017
axis can create combinations of index and column values
90189018
that are missing from the original dataframe. See Examples
90199019
section.
9020+
sort : bool, default True
9021+
Whether to sort the levels of the resulting MultiIndex.
90209022
90219023
Returns
90229024
-------
@@ -9160,9 +9162,9 @@ def stack(self, level: Level = -1, dropna: bool = True):
91609162
)
91619163

91629164
if isinstance(level, (tuple, list)):
9163-
result = stack_multiple(self, level, dropna=dropna)
9165+
result = stack_multiple(self, level, dropna=dropna, sort=sort)
91649166
else:
9165-
result = stack(self, level, dropna=dropna)
9167+
result = stack(self, level, dropna=dropna, sort=sort)
91669168

91679169
return result.__finalize__(self, method="stack")
91689170

pandas/core/reshape/reshape.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pandas.core.dtypes.missing import notna
2929

3030
import pandas.core.algorithms as algos
31+
from pandas.core.algorithms import unique
3132
from pandas.core.arrays.categorical import factorize_from_iterable
3233
from pandas.core.construction import ensure_wrapped_if_datetimelike
3334
from pandas.core.frame import DataFrame
@@ -545,7 +546,7 @@ def _unstack_extension_series(series: Series, level, fill_value) -> DataFrame:
545546
return result
546547

547548

548-
def stack(frame: DataFrame, level=-1, dropna: bool = True):
549+
def stack(frame: DataFrame, level=-1, dropna: bool = True, sort: bool = True):
549550
"""
550551
Convert DataFrame to Series with multi-level Index. Columns become the
551552
second level of the resulting hierarchical index
@@ -567,7 +568,9 @@ def factorize(index):
567568
level_num = frame.columns._get_level_number(level)
568569

569570
if isinstance(frame.columns, MultiIndex):
570-
return _stack_multi_columns(frame, level_num=level_num, dropna=dropna)
571+
return _stack_multi_columns(
572+
frame, level_num=level_num, dropna=dropna, sort=sort
573+
)
571574
elif isinstance(frame.index, MultiIndex):
572575
new_levels = list(frame.index.levels)
573576
new_codes = [lab.repeat(K) for lab in frame.index.codes]
@@ -620,13 +623,13 @@ def factorize(index):
620623
return frame._constructor_sliced(new_values, index=new_index)
621624

622625

623-
def stack_multiple(frame: DataFrame, level, dropna: bool = True):
626+
def stack_multiple(frame: DataFrame, level, dropna: bool = True, sort: bool = True):
624627
# If all passed levels match up to column names, no
625628
# ambiguity about what to do
626629
if all(lev in frame.columns.names for lev in level):
627630
result = frame
628631
for lev in level:
629-
result = stack(result, lev, dropna=dropna)
632+
result = stack(result, lev, dropna=dropna, sort=sort)
630633

631634
# Otherwise, level numbers may change as each successive level is stacked
632635
elif all(isinstance(lev, int) for lev in level):
@@ -639,7 +642,7 @@ def stack_multiple(frame: DataFrame, level, dropna: bool = True):
639642

640643
while level:
641644
lev = level.pop(0)
642-
result = stack(result, lev, dropna=dropna)
645+
result = stack(result, lev, dropna=dropna, sort=sort)
643646
# Decrement all level numbers greater than current, as these
644647
# have now shifted down by one
645648
level = [v if v <= lev else v - 1 for v in level]
@@ -681,7 +684,7 @@ def _stack_multi_column_index(columns: MultiIndex) -> MultiIndex:
681684

682685

683686
def _stack_multi_columns(
684-
frame: DataFrame, level_num: int = -1, dropna: bool = True
687+
frame: DataFrame, level_num: int = -1, dropna: bool = True, sort: bool = True
685688
) -> DataFrame:
686689
def _convert_level_number(level_num: int, columns: Index):
687690
"""
@@ -711,7 +714,7 @@ def _convert_level_number(level_num: int, columns: Index):
711714
roll_columns = roll_columns.swaplevel(lev1, lev2)
712715
this.columns = mi_cols = roll_columns
713716

714-
if not mi_cols._is_lexsorted():
717+
if not mi_cols._is_lexsorted() and sort:
715718
# Workaround the edge case where 0 is one of the column names,
716719
# which interferes with trying to sort based on the first
717720
# level
@@ -725,7 +728,9 @@ def _convert_level_number(level_num: int, columns: Index):
725728
# time to ravel the values
726729
new_data = {}
727730
level_vals = mi_cols.levels[-1]
728-
level_codes = sorted(set(mi_cols.codes[-1]))
731+
level_codes = unique(mi_cols.codes[-1])
732+
if sort:
733+
level_codes = np.sort(level_codes)
729734
level_vals_nan = level_vals.insert(len(level_vals), None)
730735

731736
level_vals_used = np.take(level_vals_nan, level_codes)

pandas/tests/frame/test_stack_unstack.py

+42
Original file line numberDiff line numberDiff line change
@@ -1361,6 +1361,48 @@ def test_unstack_non_slice_like_blocks(using_array_manager):
13611361
tm.assert_frame_equal(res, expected)
13621362

13631363

1364+
def test_stack_sort_false():
1365+
# GH 15105
1366+
data = [[1, 2, 3.0, 4.0], [2, 3, 4.0, 5.0], [3, 4, np.nan, np.nan]]
1367+
df = DataFrame(
1368+
data,
1369+
columns=MultiIndex(
1370+
levels=[["B", "A"], ["x", "y"]], codes=[[0, 0, 1, 1], [0, 1, 0, 1]]
1371+
),
1372+
)
1373+
result = df.stack(level=0, sort=False)
1374+
expected = DataFrame(
1375+
{"x": [1.0, 3.0, 2.0, 4.0, 3.0], "y": [2.0, 4.0, 3.0, 5.0, 4.0]},
1376+
index=MultiIndex.from_arrays([[0, 0, 1, 1, 2], ["B", "A", "B", "A", "B"]]),
1377+
)
1378+
tm.assert_frame_equal(result, expected)
1379+
1380+
# Codes sorted in this call
1381+
df = DataFrame(
1382+
data,
1383+
columns=MultiIndex.from_arrays([["B", "B", "A", "A"], ["x", "y", "x", "y"]]),
1384+
)
1385+
result = df.stack(level=0, sort=False)
1386+
tm.assert_frame_equal(result, expected)
1387+
1388+
1389+
def test_stack_sort_false_multi_level():
1390+
# GH 15105
1391+
idx = MultiIndex.from_tuples([("weight", "kg"), ("height", "m")])
1392+
df = DataFrame([[1.0, 2.0], [3.0, 4.0]], index=["cat", "dog"], columns=idx)
1393+
result = df.stack([0, 1], sort=False)
1394+
expected_index = MultiIndex.from_tuples(
1395+
[
1396+
("cat", "weight", "kg"),
1397+
("cat", "height", "m"),
1398+
("dog", "weight", "kg"),
1399+
("dog", "height", "m"),
1400+
]
1401+
)
1402+
expected = Series([1.0, 2.0, 3.0, 4.0], index=expected_index)
1403+
tm.assert_series_equal(result, expected)
1404+
1405+
13641406
class TestStackUnstackMultiLevel:
13651407
def test_unstack(self, multiindex_year_month_day_dataframe_random_data):
13661408
# just check that it works for now

0 commit comments

Comments
 (0)