Skip to content

Commit d30a8c3

Browse files
ENH: added support for Index.sort_values(key=...)
1 parent eee9120 commit d30a8c3

File tree

7 files changed

+121
-39
lines changed

7 files changed

+121
-39
lines changed

pandas/core/frame.py

+7-16
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import itertools
1616
import sys
1717
from textwrap import dedent
18-
from typing import FrozenSet, List, Optional, Set, Type, Union
18+
from typing import Callable, FrozenSet, List, Optional, Set, Type, Union
1919
import warnings
2020

2121
import numpy as np
@@ -4977,7 +4977,7 @@ def sort_values(
49774977
inplace=False,
49784978
kind="quicksort",
49794979
na_position="last",
4980-
key=None
4980+
key : Union[Callable, None] = None
49814981
):
49824982
inplace = validate_bool_kwarg(inplace, "inplace")
49834983
axis = self._get_axis_number(axis)
@@ -4992,29 +4992,20 @@ def sort_values(
49924992
if len(by) > 1:
49934993
from pandas.core.sorting import lexsort_indexer
49944994

4995-
if key is not None:
4996-
key_func = np.vectorize(key)
4997-
keys = [key_func(self._get_label_or_level_values(x, axis=axis)) for x in by]
4998-
else:
4999-
keys = [self._get_label_or_level_values(x, axis=axis) for x in by]
5000-
5001-
indexer = lexsort_indexer(keys, orders=ascending, na_position=na_position)
4995+
keys = [self._get_label_or_level_values(x, axis=axis) for x in by]
4996+
indexer = lexsort_indexer(keys, orders=ascending, na_position=na_position, key=key)
50024997
indexer = ensure_platform_int(indexer)
50034998
else:
50044999
from pandas.core.sorting import nargsort
50055000

50065001
by = by[0]
50075002
k = self._get_label_or_level_values(by, axis=axis)
50085003

5009-
if key is not None:
5010-
key_func = np.vectorize(key)
5011-
k = key_func(k)
5012-
50135004
if isinstance(ascending, (tuple, list)):
50145005
ascending = ascending[0]
50155006

50165007
indexer = nargsort(
5017-
k, kind=kind, ascending=ascending, na_position=na_position
5008+
k, kind=kind, ascending=ascending, na_position=na_position, key=key
50185009
)
50195010

50205011
new_data = self._data.take(
@@ -5038,7 +5029,7 @@ def sort_index(
50385029
na_position="last",
50395030
sort_remaining=True,
50405031
by=None,
5041-
key=None
5032+
key : Union[Callable, None] = None
50425033
):
50435034

50445035
# TODO: this can be combined with Series.sort_index impl as
@@ -5060,7 +5051,7 @@ def sort_index(
50605051
axis = self._get_axis_number(axis)
50615052
labels = self._get_axis(axis)
50625053
if key is not None:
5063-
labels = labels.map(key)
5054+
labels = labels.map(key, na_action="ignore")
50645055

50655056
# make sure that the axis is lexsorted to start
50665057
# if not we need to reconstruct to get the correct indexer

pandas/core/indexes/base.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from datetime import datetime, timedelta
22
import operator
33
from textwrap import dedent
4-
from typing import Union
4+
from typing import Callable, Union
55
import warnings
66

77
import numpy as np
@@ -4524,7 +4524,7 @@ def asof_locs(self, where, mask):
45244524

45254525
return result
45264526

4527-
def sort_values(self, return_indexer=False, ascending=True):
4527+
def sort_values(self, return_indexer=False, ascending=True, key : Callable = None):
45284528
"""
45294529
Return a sorted copy of the index.
45304530
@@ -4537,6 +4537,9 @@ def sort_values(self, return_indexer=False, ascending=True):
45374537
Should the indices that would sort the index be returned.
45384538
ascending : bool, default True
45394539
Should the index values be sorted in an ascending order.
4540+
key : Callable, default None
4541+
Apply a key function to the indices before sorting, like
4542+
built-in sorted function.
45404543
45414544
Returns
45424545
-------
@@ -4567,7 +4570,12 @@ def sort_values(self, return_indexer=False, ascending=True):
45674570
>>> idx.sort_values(ascending=False, return_indexer=True)
45684571
(Int64Index([1000, 100, 10, 1], dtype='int64'), array([3, 1, 0, 2]))
45694572
"""
4570-
_as = self.argsort()
4573+
if key:
4574+
idx = self.map(key, na_action="ignore")
4575+
else:
4576+
idx = self
4577+
4578+
_as = idx.argsort()
45714579
if not ascending:
45724580
_as = _as[::-1]
45734581

@@ -4679,9 +4687,12 @@ def argsort(self, *args, **kwargs):
46794687
>>> idx[order]
46804688
Index(['a', 'b', 'c', 'd'], dtype='object')
46814689
"""
4690+
46824691
result = self.asi8
4692+
46834693
if result is None:
46844694
result = np.array(self)
4695+
46854696
return result.argsort(*args, **kwargs)
46864697

46874698
_index_shared_docs[

pandas/core/series.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from io import StringIO
66
from shutil import get_terminal_size
77
from textwrap import dedent
8-
from typing import Any, Callable
8+
from typing import Any, Callable, Union
99
import warnings
1010

1111
import numpy as np
@@ -3019,7 +3019,7 @@ def sort_values(
30193019
inplace=False,
30203020
kind="quicksort",
30213021
na_position="last",
3022-
key=None
3022+
key: Callable = None
30233023
):
30243024
"""
30253025
Sort by the values.
@@ -3042,7 +3042,7 @@ def sort_values(
30423042
na_position : {'first' or 'last'}, default 'last'
30433043
Argument 'first' puts NaNs at the beginning, 'last' puts NaNs at
30443044
the end.
3045-
key : function, default None
3045+
key : Callable, default None
30463046
If not None, apply the key function to every value before
30473047
sorting. Identical to key argument in built-in sorted function.
30483048
@@ -3220,7 +3220,7 @@ def sort_index(
32203220
kind="quicksort",
32213221
na_position="last",
32223222
sort_remaining=True,
3223-
key=None
3223+
key : Callable = None
32243224
):
32253225
"""
32263226
Sort Series by index labels.
@@ -3249,7 +3249,7 @@ def sort_index(
32493249
sort_remaining : bool, default True
32503250
If True and sorting by level and index is multilevel, sort by other
32513251
levels too (in order) after sorting by specified level.
3252-
key : function, default None
3252+
key : Callable, default None
32533253
If not None, apply the key function to every index element before
32543254
sorting. Identical to key argument in built-in sorted function.
32553255
@@ -3356,7 +3356,7 @@ def sort_index(
33563356
index = self.index
33573357
true_index = index
33583358
if key is not None:
3359-
index = index.map(key)
3359+
index = index.map(key, na_action="ignore")
33603360

33613361
if level is not None:
33623362
new_index, indexer = index.sortlevel(

pandas/core/sorting.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
""" miscellaneous sorting / groupby utilities """
2+
from typing import Callable, Union
3+
24
import numpy as np
35

46
from pandas._libs import algos, hashtable, lib
@@ -188,7 +190,7 @@ def indexer_from_factorized(labels, shape, compress=True):
188190
return get_group_index_sorter(ids, ngroups)
189191

190192

191-
def lexsort_indexer(keys, orders=None, na_position="last"):
193+
def lexsort_indexer(keys, orders=None, na_position="last", key : Union[Callable, None] = None):
192194
from pandas.core.arrays import Categorical
193195

194196
labels = []
@@ -198,6 +200,10 @@ def lexsort_indexer(keys, orders=None, na_position="last"):
198200
elif orders is None:
199201
orders = [True] * len(keys)
200202

203+
if key:
204+
key_func = np.vectorize(key)
205+
keys = [key_func(entry) if entry.size != 0 else entry for entry in keys]
206+
201207
for key, order in zip(keys, orders):
202208

203209
# we are already a Categorical
@@ -234,7 +240,7 @@ def lexsort_indexer(keys, orders=None, na_position="last"):
234240
return indexer_from_factorized(labels, shape)
235241

236242

237-
def nargsort(items, kind="quicksort", ascending=True, na_position="last"):
243+
def nargsort(items, kind="quicksort", ascending=True, na_position="last", key: Union[Callable, None] = None):
238244
"""
239245
This is intended to be a drop-in replacement for np.argsort which
240246
handles NaNs. It adds ascending and na_position parameters.
@@ -250,6 +256,17 @@ def nargsort(items, kind="quicksort", ascending=True, na_position="last"):
250256
else:
251257
items = np.asanyarray(items)
252258

259+
if key is not None:
260+
key_func = np.vectorize(key)
261+
masked = np.ma.MaskedArray(items, mask)
262+
263+
if masked.size == 0:
264+
vals = np.array([]) # vectorize fails on empty object arrays
265+
else:
266+
vals = np.asarray(key_func(masked)) # revert from masked
267+
268+
return nargsort(vals, kind=kind, ascending=ascending, na_position=na_position, key=None)
269+
253270
idx = np.arange(len(items))
254271
non_nans = items[~mask]
255272
non_nan_idx = idx[~mask]

pandas/tests/frame/test_sorting.py

+37-9
Original file line numberDiff line numberDiff line change
@@ -83,29 +83,33 @@ def test_sort_values(self):
8383
with pytest.raises(ValueError, match=msg):
8484
frame.sort_values(by=["A", "B"], axis=0, ascending=[True] * 5)
8585

86-
def test_sort_values_inplace(self):
86+
@pytest.fixture(params=[None, lambda x : x])
87+
def key(self, request):
88+
return request.param
89+
90+
def test_sort_values_inplace(self, key):
8791
frame = DataFrame(
8892
np.random.randn(4, 4), index=[1, 2, 3, 4], columns=["A", "B", "C", "D"]
8993
)
9094

9195
sorted_df = frame.copy()
92-
sorted_df.sort_values(by="A", inplace=True)
93-
expected = frame.sort_values(by="A")
96+
sorted_df.sort_values(by="A", inplace=True, key=key)
97+
expected = frame.sort_values(by="A", key=key)
9498
assert_frame_equal(sorted_df, expected)
9599

96100
sorted_df = frame.copy()
97-
sorted_df.sort_values(by=1, axis=1, inplace=True)
98-
expected = frame.sort_values(by=1, axis=1)
101+
sorted_df.sort_values(by=1, axis=1, inplace=True, key=key)
102+
expected = frame.sort_values(by=1, axis=1, key=key)
99103
assert_frame_equal(sorted_df, expected)
100104

101105
sorted_df = frame.copy()
102-
sorted_df.sort_values(by="A", ascending=False, inplace=True)
103-
expected = frame.sort_values(by="A", ascending=False)
106+
sorted_df.sort_values(by="A", ascending=False, inplace=True, key=key)
107+
expected = frame.sort_values(by="A", ascending=False, key=key)
104108
assert_frame_equal(sorted_df, expected)
105109

106110
sorted_df = frame.copy()
107-
sorted_df.sort_values(by=["A", "B"], ascending=False, inplace=True)
108-
expected = frame.sort_values(by=["A", "B"], ascending=False)
111+
sorted_df.sort_values(by=["A", "B"], ascending=False, inplace=True, key=key)
112+
expected = frame.sort_values(by=["A", "B"], ascending=False, key=key)
109113
assert_frame_equal(sorted_df, expected)
110114

111115
def test_sort_nan(self):
@@ -249,6 +253,23 @@ def test_sort_multi_index(self):
249253

250254
tm.assert_frame_equal(result, expected)
251255

256+
def test_sort_multi_index_key(self):
257+
# GH 25775, testing that sorting by index works with a multi-index.
258+
df = DataFrame(
259+
{"a": [3, 1, 2], "b": [0, 0, 0], "c": [0, 1, 2], "d": list("abc")}
260+
)
261+
result = df.set_index(list("abc")).sort_index(level=list("ba"), key=lambda x : x[0])
262+
263+
expected = DataFrame(
264+
{"a": [1, 2, 3], "b": [0, 0, 0], "c": [1, 2, 0], "d": list("bca")}
265+
)
266+
expected = expected.set_index(list("abc"))
267+
tm.assert_frame_equal(result, expected)
268+
269+
result = df.set_index(list("abc")).sort_index(level=list("ba"), key=lambda x : x[2])
270+
expected = df.set_index(list("abc"))
271+
tm.assert_frame_equal(result, expected)
272+
252273
def test_stable_categorial(self):
253274
# GH 16793
254275
df = DataFrame({"x": pd.Categorical(np.repeat([1, 2, 3, 4], 5), ordered=True)})
@@ -628,6 +649,13 @@ def test_sort_value_key_nan(self):
628649
expected = df.sort_values(1, key=str.lower, ascending=False)
629650
assert_frame_equal(result, expected)
630651

652+
@pytest.mark.parametrize('key', [None, lambda x : x])
653+
def test_sort_value_key_empty(self, key):
654+
df = DataFrame(np.array([]))
655+
656+
df.sort_values(0, key=key)
657+
df.sort_index(key=key)
658+
631659
def test_sort_index(self):
632660
# GH13496
633661

pandas/tests/indexing/multiindex/test_sorted.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from pandas import DataFrame, MultiIndex, Series
55
from pandas.util import testing as tm
66

7+
import pytest
8+
79

810
class TestMultiIndexSorted:
911
def test_getitem_multilevel_index_tuple_not_sorted(self):
@@ -28,7 +30,8 @@ def test_getitem_slice_not_sorted(self, multiindex_dataframe_random_data):
2830
expected = df.reindex(columns=df.columns[:3])
2931
tm.assert_frame_equal(result, expected)
3032

31-
def test_frame_getitem_not_sorted2(self):
33+
@pytest.mark.parametrize('key', [None, lambda x : x])
34+
def test_frame_getitem_not_sorted2(self, key):
3235
# 13431
3336
df = DataFrame(
3437
{
@@ -47,15 +50,35 @@ def test_frame_getitem_not_sorted2(self):
4750
assert not df2.index.is_monotonic
4851

4952
assert df2_original.index.equals(df2.index)
50-
expected = df2.sort_index()
53+
expected = df2.sort_index(key=key)
5154
assert expected.index.is_lexsorted()
5255
assert expected.index.is_monotonic
5356

54-
result = df2.sort_index(level=0)
57+
result = df2.sort_index(level=0, key=key)
5558
assert result.index.is_lexsorted()
5659
assert result.index.is_monotonic
5760
tm.assert_frame_equal(result, expected)
5861

62+
def test_sort_values_key(self, multiindex_dataframe_random_data):
63+
arrays = [
64+
["bar", "bar", "baz", "baz", "qux", "qux", "foo", "foo"],
65+
["one", "two", "one", "two", "one", "two", "one", "two"],
66+
]
67+
tuples = zip(*arrays)
68+
index = MultiIndex.from_tuples(tuples)
69+
index = index.sort_values(key=lambda x: (x[0][2], x[1][2]))
70+
result = DataFrame(range(8), index=index)
71+
72+
arrays = [
73+
["foo", "foo", "bar", "bar", "qux", "qux", "baz", "baz"],
74+
["one", "two", "one", "two", "one", "two", "one", "two"],
75+
]
76+
tuples = zip(*arrays)
77+
index = MultiIndex.from_tuples(tuples)
78+
expected = DataFrame(range(8), index=index)
79+
80+
tm.assert_frame_equal(result, expected)
81+
5982
def test_frame_getitem_not_sorted(self, multiindex_dataframe_random_data):
6083
frame = multiindex_dataframe_random_data
6184
df = frame.T

pandas/tests/series/test_sorting.py

+12
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,18 @@ def test_sort_index_multiindex(self, level):
152152
res = s.sort_index(level=level, sort_remaining=False)
153153
assert_series_equal(s, res)
154154

155+
def test_sort_index_multiindex_key(self):
156+
157+
mi = MultiIndex.from_tuples([[1, 1, 3], [1, 1, 1]], names=list("ABC"))
158+
s = Series([1, 2], mi)
159+
backwards = s.iloc[[1, 0]]
160+
161+
res = s.sort_index(key=lambda x : x[2])
162+
assert_series_equal(backwards, res)
163+
164+
res = s.sort_index(key=lambda x : x[1]) # nothing happens
165+
assert_series_equal(s, res)
166+
155167
def test_sort_index_kind(self):
156168
# GH #14444 & #13589: Add support for sort algo choosing
157169
series = Series(index=[3, 2, 1, 4, 3])

0 commit comments

Comments
 (0)