Skip to content

Commit 787a4f8

Browse files
authored
ENH: Add sort keyword to unstack (#53298)
* Start adding sort to unstack
* Fix compressor
* Add testing for dataframe
* Add whatsnew
* remove sort for now
* Fix param
* Add sort=sort
1 parent c6700d5 commit 787a4f8

File tree

5 files changed

+96
-25
lines changed

5 files changed

+96
-25
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ Other enhancements
9898
- Performance improvement in :func:`read_csv` (:issue:`52632`) with ``engine="c"``
9999
- :meth:`Categorical.from_codes` has gotten a ``validate`` parameter (:issue:`50975`)
100100
- :meth:`DataFrame.stack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
101+
- :meth:`DataFrame.unstack` and :meth:`Series.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
101102
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
102103
- Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
103104
- Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)

pandas/core/frame.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -9296,7 +9296,7 @@ def explode(
92969296

92979297
return result.__finalize__(self, method="explode")
92989298

9299-
def unstack(self, level: Level = -1, fill_value=None):
9299+
def unstack(self, level: Level = -1, fill_value=None, sort: bool = True):
93009300
"""
93019301
Pivot a level of the (necessarily hierarchical) index labels.
93029302
@@ -9312,6 +9312,8 @@ def unstack(self, level: Level = -1, fill_value=None):
93129312
Level(s) of index to unstack, can pass level name.
93139313
fill_value : int, str or dict
93149314
Replace NaN with this value if the unstack produces missing values.
9315+
sort : bool, default True
9316+
Sort the level(s) in the resulting MultiIndex columns.
93159317
93169318
Returns
93179319
-------
@@ -9359,7 +9361,7 @@ def unstack(self, level: Level = -1, fill_value=None):
93599361
"""
93609362
from pandas.core.reshape.reshape import unstack
93619363

9362-
result = unstack(self, level, fill_value)
9364+
result = unstack(self, level, fill_value, sort)
93639365

93649366
return result.__finalize__(self, method="unstack")
93659367

pandas/core/reshape/reshape.py

+47-21
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,11 @@ class _Unstacker:
102102
unstacked : DataFrame
103103
"""
104104

105-
def __init__(self, index: MultiIndex, level: Level, constructor) -> None:
105+
def __init__(
106+
self, index: MultiIndex, level: Level, constructor, sort: bool = True
107+
) -> None:
106108
self.constructor = constructor
109+
self.sort = sort
107110

108111
self.index = index.remove_unused_levels()
109112

@@ -119,11 +122,15 @@ def __init__(self, index: MultiIndex, level: Level, constructor) -> None:
119122
self.removed_name = self.new_index_names.pop(self.level)
120123
self.removed_level = self.new_index_levels.pop(self.level)
121124
self.removed_level_full = index.levels[self.level]
125+
if not self.sort:
126+
unique_codes = unique(self.index.codes[self.level])
127+
self.removed_level = self.removed_level.take(unique_codes)
128+
self.removed_level_full = self.removed_level_full.take(unique_codes)
122129

123130
# Bug fix GH 20601
124131
# If the data frame is too big, the number of unique index combination
125132
# will cause int32 overflow on windows environments.
126-
# We want to check and raise an error before this happens
133+
# We want to check and raise a warning before this happens
127134
num_rows = np.max([index_level.size for index_level in self.new_index_levels])
128135
num_columns = self.removed_level.size
129136

@@ -164,13 +171,17 @@ def _indexer_and_to_sort(
164171
@cache_readonly
165172
def sorted_labels(self) -> list[np.ndarray]:
166173
indexer, to_sort = self._indexer_and_to_sort
167-
return [line.take(indexer) for line in to_sort]
174+
if self.sort:
175+
return [line.take(indexer) for line in to_sort]
176+
return to_sort
168177

169178
def _make_sorted_values(self, values: np.ndarray) -> np.ndarray:
170-
indexer, _ = self._indexer_and_to_sort
179+
if self.sort:
180+
indexer, _ = self._indexer_and_to_sort
171181

172-
sorted_values = algos.take_nd(values, indexer, axis=0)
173-
return sorted_values
182+
sorted_values = algos.take_nd(values, indexer, axis=0)
183+
return sorted_values
184+
return values
174185

175186
def _make_selectors(self):
176187
new_levels = self.new_index_levels
@@ -195,7 +206,10 @@ def _make_selectors(self):
195206

196207
self.group_index = comp_index
197208
self.mask = mask
198-
self.compressor = comp_index.searchsorted(np.arange(ngroups))
209+
if self.sort:
210+
self.compressor = comp_index.searchsorted(np.arange(ngroups))
211+
else:
212+
self.compressor = np.sort(np.unique(comp_index, return_index=True)[1])
199213

200214
@cache_readonly
201215
def mask_all(self) -> bool:
@@ -376,7 +390,9 @@ def new_index(self) -> MultiIndex:
376390
)
377391

378392

379-
def _unstack_multiple(data: Series | DataFrame, clocs, fill_value=None):
393+
def _unstack_multiple(
394+
data: Series | DataFrame, clocs, fill_value=None, sort: bool = True
395+
):
380396
if len(clocs) == 0:
381397
return data
382398

@@ -421,7 +437,7 @@ def _unstack_multiple(data: Series | DataFrame, clocs, fill_value=None):
421437
dummy = data.copy()
422438
dummy.index = dummy_index
423439

424-
unstacked = dummy.unstack("__placeholder__", fill_value=fill_value)
440+
unstacked = dummy.unstack("__placeholder__", fill_value=fill_value, sort=sort)
425441
new_levels = clevels
426442
new_names = cnames
427443
new_codes = recons_codes
@@ -430,7 +446,7 @@ def _unstack_multiple(data: Series | DataFrame, clocs, fill_value=None):
430446
result = data
431447
while clocs:
432448
val = clocs.pop(0)
433-
result = result.unstack(val, fill_value=fill_value)
449+
result = result.unstack(val, fill_value=fill_value, sort=sort)
434450
clocs = [v if v < val else v - 1 for v in clocs]
435451

436452
return result
@@ -439,7 +455,9 @@ def _unstack_multiple(data: Series | DataFrame, clocs, fill_value=None):
439455
dummy_df = data.copy(deep=False)
440456
dummy_df.index = dummy_index
441457

442-
unstacked = dummy_df.unstack("__placeholder__", fill_value=fill_value)
458+
unstacked = dummy_df.unstack(
459+
"__placeholder__", fill_value=fill_value, sort=sort
460+
)
443461
if isinstance(unstacked, Series):
444462
unstcols = unstacked.index
445463
else:
@@ -464,12 +482,12 @@ def _unstack_multiple(data: Series | DataFrame, clocs, fill_value=None):
464482
return unstacked
465483

466484

467-
def unstack(obj: Series | DataFrame, level, fill_value=None):
485+
def unstack(obj: Series | DataFrame, level, fill_value=None, sort: bool = True):
468486
if isinstance(level, (tuple, list)):
469487
if len(level) != 1:
470488
# _unstack_multiple only handles MultiIndexes,
471489
# and isn't needed for a single level
472-
return _unstack_multiple(obj, level, fill_value=fill_value)
490+
return _unstack_multiple(obj, level, fill_value=fill_value, sort=sort)
473491
else:
474492
level = level[0]
475493

@@ -479,9 +497,9 @@ def unstack(obj: Series | DataFrame, level, fill_value=None):
479497

480498
if isinstance(obj, DataFrame):
481499
if isinstance(obj.index, MultiIndex):
482-
return _unstack_frame(obj, level, fill_value=fill_value)
500+
return _unstack_frame(obj, level, fill_value=fill_value, sort=sort)
483501
else:
484-
return obj.T.stack(dropna=False)
502+
return obj.T.stack(dropna=False, sort=sort)
485503
elif not isinstance(obj.index, MultiIndex):
486504
# GH 36113
487505
# Give nicer error messages when unstack a Series whose
@@ -491,18 +509,22 @@ def unstack(obj: Series | DataFrame, level, fill_value=None):
491509
)
492510
else:
493511
if is_1d_only_ea_dtype(obj.dtype):
494-
return _unstack_extension_series(obj, level, fill_value)
512+
return _unstack_extension_series(obj, level, fill_value, sort=sort)
495513
unstacker = _Unstacker(
496-
obj.index, level=level, constructor=obj._constructor_expanddim
514+
obj.index, level=level, constructor=obj._constructor_expanddim, sort=sort
497515
)
498516
return unstacker.get_result(
499517
obj._values, value_columns=None, fill_value=fill_value
500518
)
501519

502520

503-
def _unstack_frame(obj: DataFrame, level, fill_value=None) -> DataFrame:
521+
def _unstack_frame(
522+
obj: DataFrame, level, fill_value=None, sort: bool = True
523+
) -> DataFrame:
504524
assert isinstance(obj.index, MultiIndex) # checked by caller
505-
unstacker = _Unstacker(obj.index, level=level, constructor=obj._constructor)
525+
unstacker = _Unstacker(
526+
obj.index, level=level, constructor=obj._constructor, sort=sort
527+
)
506528

507529
if not obj._can_fast_transpose:
508530
mgr = obj._mgr.unstack(unstacker, fill_value=fill_value)
@@ -513,7 +535,9 @@ def _unstack_frame(obj: DataFrame, level, fill_value=None) -> DataFrame:
513535
)
514536

515537

516-
def _unstack_extension_series(series: Series, level, fill_value) -> DataFrame:
538+
def _unstack_extension_series(
539+
series: Series, level, fill_value, sort: bool
540+
) -> DataFrame:
517541
"""
518542
Unstack an ExtensionArray-backed Series.
519543
@@ -529,6 +553,8 @@ def _unstack_extension_series(series: Series, level, fill_value) -> DataFrame:
529553
The user-level (not physical storage) fill value to use for
530554
missing values introduced by the reshape. Passed to
531555
``series.values.take``.
556+
sort : bool
557+
Whether to sort the resulting MultiIndex levels.
532558
533559
Returns
534560
-------
@@ -538,7 +564,7 @@ def _unstack_extension_series(series: Series, level, fill_value) -> DataFrame:
538564
"""
539565
# Defer to the logic in ExtensionBlock._unstack
540566
df = series.to_frame()
541-
result = df.unstack(level=level, fill_value=fill_value)
567+
result = df.unstack(level=level, fill_value=fill_value, sort=sort)
542568

543569
# equiv: result.droplevel(level=0, axis=1)
544570
# but this avoids an extra copy

pandas/core/series.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -4274,7 +4274,9 @@ def explode(self, ignore_index: bool = False) -> Series:
42744274

42754275
return self._constructor(values, index=index, name=self.name, copy=False)
42764276

4277-
def unstack(self, level: IndexLabel = -1, fill_value: Hashable = None) -> DataFrame:
4277+
def unstack(
4278+
self, level: IndexLabel = -1, fill_value: Hashable = None, sort: bool = True
4279+
) -> DataFrame:
42784280
"""
42794281
Unstack, also known as pivot, Series with MultiIndex to produce DataFrame.
42804282
@@ -4284,6 +4286,8 @@ def unstack(self, level: IndexLabel = -1, fill_value: Hashable = None) -> DataFr
42844286
Level(s) to unstack, can pass level name.
42854287
fill_value : scalar value, default None
42864288
Value to use when replacing NaN values.
4289+
sort : bool, default True
4290+
Sort the level(s) in the resulting MultiIndex columns.
42874291
42884292
Returns
42894293
-------
@@ -4318,7 +4322,7 @@ def unstack(self, level: IndexLabel = -1, fill_value: Hashable = None) -> DataFr
43184322
"""
43194323
from pandas.core.reshape.reshape import unstack
43204324

4321-
return unstack(self, level, fill_value)
4325+
return unstack(self, level, fill_value, sort)
43224326

43234327
# ----------------------------------------------------------------------
43244328
# function application

pandas/tests/frame/test_stack_unstack.py

+38
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,44 @@ def test_unstack_swaplevel_sortlevel(self, level):
12101210
tm.assert_frame_equal(result, expected)
12111211

12121212

1213+
@pytest.mark.parametrize("dtype", ["float64", "Float64"])
1214+
def test_unstack_sort_false(frame_or_series, dtype):
1215+
# GH 15105
1216+
index = MultiIndex.from_tuples(
1217+
[("two", "z", "b"), ("two", "y", "a"), ("one", "z", "b"), ("one", "y", "a")]
1218+
)
1219+
obj = frame_or_series(np.arange(1.0, 5.0), index=index, dtype=dtype)
1220+
result = obj.unstack(level=-1, sort=False)
1221+
1222+
if frame_or_series is DataFrame:
1223+
expected_columns = MultiIndex.from_tuples([(0, "b"), (0, "a")])
1224+
else:
1225+
expected_columns = ["b", "a"]
1226+
expected = DataFrame(
1227+
[[1.0, np.nan], [np.nan, 2.0], [3.0, np.nan], [np.nan, 4.0]],
1228+
columns=expected_columns,
1229+
index=MultiIndex.from_tuples(
1230+
[("two", "z"), ("two", "y"), ("one", "z"), ("one", "y")]
1231+
),
1232+
dtype=dtype,
1233+
)
1234+
tm.assert_frame_equal(result, expected)
1235+
1236+
result = obj.unstack(level=[1, 2], sort=False)
1237+
1238+
if frame_or_series is DataFrame:
1239+
expected_columns = MultiIndex.from_tuples([(0, "z", "b"), (0, "y", "a")])
1240+
else:
1241+
expected_columns = MultiIndex.from_tuples([("z", "b"), ("y", "a")])
1242+
expected = DataFrame(
1243+
[[1.0, 2.0], [3.0, 4.0]],
1244+
index=["two", "one"],
1245+
columns=expected_columns,
1246+
dtype=dtype,
1247+
)
1248+
tm.assert_frame_equal(result, expected)
1249+
1250+
12131251
def test_unstack_fill_frame_object():
12141252
# GH12815 Test unstacking with object.
12151253
data = Series(["a", "b", "c", "a"], dtype="object")

0 commit comments

Comments
 (0)