Skip to content

ENH: Add sort keyword to unstack #53298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 1, 2023
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ Other enhancements
- Performance improvement in :func:`read_csv` (:issue:`52632`) with ``engine="c"``
- :meth:`Categorical.from_codes` has gotten a ``validate`` parameter (:issue:`50975`)
- :meth:`DataFrame.stack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
- :meth:`DataFrame.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
- Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
- Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)
Expand Down
6 changes: 4 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9296,7 +9296,7 @@ def explode(

return result.__finalize__(self, method="explode")

def unstack(self, level: Level = -1, fill_value=None):
def unstack(self, level: Level = -1, fill_value=None, sort: bool = True):
"""
Pivot a level of the (necessarily hierarchical) index labels.

Expand All @@ -9312,6 +9312,8 @@ def unstack(self, level: Level = -1, fill_value=None):
Level(s) of index to unstack, can pass level name.
fill_value : int, str or dict
Replace NaN with this value if the unstack produces missing values.
sort : bool, default True
Sort the level(s) in the resulting MultiIndex columns.

Returns
-------
Expand Down Expand Up @@ -9359,7 +9361,7 @@ def unstack(self, level: Level = -1, fill_value=None):
"""
from pandas.core.reshape.reshape import unstack

result = unstack(self, level, fill_value)
result = unstack(self, level, fill_value, sort)

return result.__finalize__(self, method="unstack")

Expand Down
68 changes: 47 additions & 21 deletions pandas/core/reshape/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,11 @@ class _Unstacker:
unstacked : DataFrame
"""

def __init__(self, index: MultiIndex, level: Level, constructor) -> None:
def __init__(
self, index: MultiIndex, level: Level, constructor, sort: bool = True
) -> None:
self.constructor = constructor
self.sort = sort

self.index = index.remove_unused_levels()

Expand All @@ -119,11 +122,15 @@ def __init__(self, index: MultiIndex, level: Level, constructor) -> None:
self.removed_name = self.new_index_names.pop(self.level)
self.removed_level = self.new_index_levels.pop(self.level)
self.removed_level_full = index.levels[self.level]
if not self.sort:
unique_codes = unique(self.index.codes[self.level])
self.removed_level = self.removed_level.take(unique_codes)
self.removed_level_full = self.removed_level_full.take(unique_codes)

# Bug fix GH 20601
# If the data frame is too big, the number of unique index combination
# will cause int32 overflow on windows environments.
# We want to check and raise an error before this happens
# We want to check and raise an warning before this happens
num_rows = np.max([index_level.size for index_level in self.new_index_levels])
num_columns = self.removed_level.size

Expand Down Expand Up @@ -164,13 +171,17 @@ def _indexer_and_to_sort(
@cache_readonly
def sorted_labels(self) -> list[np.ndarray]:
indexer, to_sort = self._indexer_and_to_sort
return [line.take(indexer) for line in to_sort]
if self.sort:
return [line.take(indexer) for line in to_sort]
return to_sort

def _make_sorted_values(self, values: np.ndarray) -> np.ndarray:
indexer, _ = self._indexer_and_to_sort
if self.sort:
indexer, _ = self._indexer_and_to_sort

sorted_values = algos.take_nd(values, indexer, axis=0)
return sorted_values
sorted_values = algos.take_nd(values, indexer, axis=0)
return sorted_values
return values

def _make_selectors(self):
new_levels = self.new_index_levels
Expand All @@ -195,7 +206,10 @@ def _make_selectors(self):

self.group_index = comp_index
self.mask = mask
self.compressor = comp_index.searchsorted(np.arange(ngroups))
if self.sort:
self.compressor = comp_index.searchsorted(np.arange(ngroups))
else:
self.compressor = np.sort(np.unique(comp_index, return_index=True)[1])

@cache_readonly
def mask_all(self) -> bool:
Expand Down Expand Up @@ -376,7 +390,9 @@ def new_index(self) -> MultiIndex:
)


def _unstack_multiple(data: Series | DataFrame, clocs, fill_value=None):
def _unstack_multiple(
data: Series | DataFrame, clocs, fill_value=None, sort: bool = True
):
if len(clocs) == 0:
return data

Expand Down Expand Up @@ -421,7 +437,7 @@ def _unstack_multiple(data: Series | DataFrame, clocs, fill_value=None):
dummy = data.copy()
dummy.index = dummy_index

unstacked = dummy.unstack("__placeholder__", fill_value=fill_value)
unstacked = dummy.unstack("__placeholder__", fill_value=fill_value, sort=sort)
new_levels = clevels
new_names = cnames
new_codes = recons_codes
Expand All @@ -430,7 +446,7 @@ def _unstack_multiple(data: Series | DataFrame, clocs, fill_value=None):
result = data
while clocs:
val = clocs.pop(0)
result = result.unstack(val, fill_value=fill_value)
result = result.unstack(val, fill_value=fill_value, sort=sort)
clocs = [v if v < val else v - 1 for v in clocs]

return result
Expand All @@ -439,7 +455,9 @@ def _unstack_multiple(data: Series | DataFrame, clocs, fill_value=None):
dummy_df = data.copy(deep=False)
dummy_df.index = dummy_index

unstacked = dummy_df.unstack("__placeholder__", fill_value=fill_value)
unstacked = dummy_df.unstack(
"__placeholder__", fill_value=fill_value, sort=sort
)
if isinstance(unstacked, Series):
unstcols = unstacked.index
else:
Expand All @@ -464,12 +482,12 @@ def _unstack_multiple(data: Series | DataFrame, clocs, fill_value=None):
return unstacked


def unstack(obj: Series | DataFrame, level, fill_value=None):
def unstack(obj: Series | DataFrame, level, fill_value=None, sort: bool = True):
if isinstance(level, (tuple, list)):
if len(level) != 1:
# _unstack_multiple only handles MultiIndexes,
# and isn't needed for a single level
return _unstack_multiple(obj, level, fill_value=fill_value)
return _unstack_multiple(obj, level, fill_value=fill_value, sort=sort)
else:
level = level[0]

Expand All @@ -479,9 +497,9 @@ def unstack(obj: Series | DataFrame, level, fill_value=None):

if isinstance(obj, DataFrame):
if isinstance(obj.index, MultiIndex):
return _unstack_frame(obj, level, fill_value=fill_value)
return _unstack_frame(obj, level, fill_value=fill_value, sort=sort)
else:
return obj.T.stack(dropna=False)
return obj.T.stack(dropna=False, sort=sort)
elif not isinstance(obj.index, MultiIndex):
# GH 36113
# Give nicer error messages when unstack a Series whose
Expand All @@ -491,18 +509,22 @@ def unstack(obj: Series | DataFrame, level, fill_value=None):
)
else:
if is_1d_only_ea_dtype(obj.dtype):
return _unstack_extension_series(obj, level, fill_value)
return _unstack_extension_series(obj, level, fill_value, sort=sort)
unstacker = _Unstacker(
obj.index, level=level, constructor=obj._constructor_expanddim
obj.index, level=level, constructor=obj._constructor_expanddim, sort=sort
)
return unstacker.get_result(
obj._values, value_columns=None, fill_value=fill_value
)


def _unstack_frame(obj: DataFrame, level, fill_value=None) -> DataFrame:
def _unstack_frame(
obj: DataFrame, level, fill_value=None, sort: bool = True
) -> DataFrame:
assert isinstance(obj.index, MultiIndex) # checked by caller
unstacker = _Unstacker(obj.index, level=level, constructor=obj._constructor)
unstacker = _Unstacker(
obj.index, level=level, constructor=obj._constructor, sort=sort
)

if not obj._can_fast_transpose:
mgr = obj._mgr.unstack(unstacker, fill_value=fill_value)
Expand All @@ -513,7 +535,9 @@ def _unstack_frame(obj: DataFrame, level, fill_value=None) -> DataFrame:
)


def _unstack_extension_series(series: Series, level, fill_value) -> DataFrame:
def _unstack_extension_series(
series: Series, level, fill_value, sort: bool
) -> DataFrame:
"""
Unstack an ExtensionArray-backed Series.

Expand All @@ -529,6 +553,8 @@ def _unstack_extension_series(series: Series, level, fill_value) -> DataFrame:
The user-level (not physical storage) fill value to use for
missing values introduced by the reshape. Passed to
``series.values.take``.
sort : bool
Whether to sort the resulting MuliIndex levels

Returns
-------
Expand All @@ -538,7 +564,7 @@ def _unstack_extension_series(series: Series, level, fill_value) -> DataFrame:
"""
# Defer to the logic in ExtensionBlock._unstack
df = series.to_frame()
result = df.unstack(level=level, fill_value=fill_value)
result = df.unstack(level=level, fill_value=fill_value, sort=sort)

# equiv: result.droplevel(level=0, axis=1)
# but this avoids an extra copy
Expand Down
8 changes: 6 additions & 2 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4274,7 +4274,9 @@ def explode(self, ignore_index: bool = False) -> Series:

return self._constructor(values, index=index, name=self.name, copy=False)

def unstack(self, level: IndexLabel = -1, fill_value: Hashable = None) -> DataFrame:
def unstack(
self, level: IndexLabel = -1, fill_value: Hashable = None, sort: bool = True
) -> DataFrame:
"""
Unstack, also known as pivot, Series with MultiIndex to produce DataFrame.

Expand All @@ -4284,6 +4286,8 @@ def unstack(self, level: IndexLabel = -1, fill_value: Hashable = None) -> DataFr
Level(s) to unstack, can pass level name.
fill_value : scalar value, default None
Value to use when replacing NaN values.
sort : bool, default True
Sort the level(s) in the resulting MultiIndex columns.

Returns
-------
Expand Down Expand Up @@ -4318,7 +4322,7 @@ def unstack(self, level: IndexLabel = -1, fill_value: Hashable = None) -> DataFr
"""
from pandas.core.reshape.reshape import unstack

return unstack(self, level, fill_value)
return unstack(self, level, fill_value, sort)

# ----------------------------------------------------------------------
# function application
Expand Down
38 changes: 38 additions & 0 deletions pandas/tests/frame/test_stack_unstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,44 @@ def test_unstack_swaplevel_sortlevel(self, level):
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("dtype", ["float64", "Float64"])
def test_unstack_sort_false(frame_or_series, dtype):
# GH 15105
index = MultiIndex.from_tuples(
[("two", "z", "b"), ("two", "y", "a"), ("one", "z", "b"), ("one", "y", "a")]
)
obj = frame_or_series(np.arange(1.0, 5.0), index=index, dtype=dtype)
result = obj.unstack(level=-1, sort=False)

if frame_or_series is DataFrame:
expected_columns = MultiIndex.from_tuples([(0, "b"), (0, "a")])
else:
expected_columns = ["b", "a"]
expected = DataFrame(
[[1.0, np.nan], [np.nan, 2.0], [3.0, np.nan], [np.nan, 4.0]],
columns=expected_columns,
index=MultiIndex.from_tuples(
[("two", "z"), ("two", "y"), ("one", "z"), ("one", "y")]
),
dtype=dtype,
)
tm.assert_frame_equal(result, expected)

result = obj.unstack(level=[1, 2], sort=False)

if frame_or_series is DataFrame:
expected_columns = MultiIndex.from_tuples([(0, "z", "b"), (0, "y", "a")])
else:
expected_columns = MultiIndex.from_tuples([("z", "b"), ("y", "a")])
expected = DataFrame(
[[1.0, 2.0], [3.0, 4.0]],
index=["two", "one"],
columns=expected_columns,
dtype=dtype,
)
tm.assert_frame_equal(result, expected)


def test_unstack_fill_frame_object():
# GH12815 Test unstacking with object.
data = Series(["a", "b", "c", "a"], dtype="object")
Expand Down