Skip to content

Commit 2be9e73

Browse files
ENH: added support for Index.sort_values(key=...)
1 parent 2a8ae69 commit 2be9e73

File tree

7 files changed

+124
-40
lines changed

7 files changed

+124
-40
lines changed

pandas/core/frame.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Tuple,
2626
Type,
2727
Union,
28+
Callable
2829
)
2930
import warnings
3031

@@ -4715,7 +4716,7 @@ def sort_values(
47154716
inplace=False,
47164717
kind="quicksort",
47174718
na_position="last",
4718-
key=None
4719+
key : Union[Callable, None] = None
47194720
):
47204721
inplace = validate_bool_kwarg(inplace, "inplace")
47214722
axis = self._get_axis_number(axis)
@@ -4729,29 +4730,20 @@ def sort_values(
47294730
if len(by) > 1:
47304731
from pandas.core.sorting import lexsort_indexer
47314732

4732-
if key is not None:
4733-
key_func = np.vectorize(key)
4734-
keys = [key_func(self._get_label_or_level_values(x, axis=axis)) for x in by]
4735-
else:
4736-
keys = [self._get_label_or_level_values(x, axis=axis) for x in by]
4737-
4738-
indexer = lexsort_indexer(keys, orders=ascending, na_position=na_position)
4733+
keys = [self._get_label_or_level_values(x, axis=axis) for x in by]
4734+
indexer = lexsort_indexer(keys, orders=ascending, na_position=na_position, key=key)
47394735
indexer = ensure_platform_int(indexer)
47404736
else:
47414737
from pandas.core.sorting import nargsort
47424738

47434739
by = by[0]
47444740
k = self._get_label_or_level_values(by, axis=axis)
47454741

4746-
if key is not None:
4747-
key_func = np.vectorize(key)
4748-
k = key_func(k)
4749-
47504742
if isinstance(ascending, (tuple, list)):
47514743
ascending = ascending[0]
47524744

47534745
indexer = nargsort(
4754-
k, kind=kind, ascending=ascending, na_position=na_position
4746+
k, kind=kind, ascending=ascending, na_position=na_position, key=key
47554747
)
47564748

47574749
new_data = self._data.take(
@@ -4785,7 +4777,7 @@ def sort_index(
47854777
axis = self._get_axis_number(axis)
47864778
labels = self._get_axis(axis)
47874779
if key is not None:
4788-
labels = labels.map(key)
4780+
labels = labels.map(key, na_action="ignore")
47894781

47904782
# make sure that the axis is lexsorted to start
47914783
# if not we need to reconstruct to get the correct indexer

pandas/core/indexes/base.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from datetime import datetime
22
import operator
33
from textwrap import dedent
4-
from typing import FrozenSet, Union
4+
5+
from typing import FrozenSet, Union, Callable
56
import warnings
67

78
import numpy as np
@@ -4400,7 +4401,7 @@ def asof_locs(self, where, mask):
44004401

44014402
return result
44024403

4403-
def sort_values(self, return_indexer=False, ascending=True):
4404+
def sort_values(self, return_indexer=False, ascending=True, key : Callable = None):
44044405
"""
44054406
Return a sorted copy of the index.
44064407
@@ -4413,6 +4414,9 @@ def sort_values(self, return_indexer=False, ascending=True):
44134414
Should the indices that would sort the index be returned.
44144415
ascending : bool, default True
44154416
Should the index values be sorted in an ascending order.
4417+
key : Callable, default None
4418+
If not None, apply the key function to each index value before sorting,
4419+
like the ``key`` argument of the built-in ``sorted`` function.
44164420
44174421
Returns
44184422
-------
@@ -4443,7 +4447,12 @@ def sort_values(self, return_indexer=False, ascending=True):
44434447
>>> idx.sort_values(ascending=False, return_indexer=True)
44444448
(Int64Index([1000, 100, 10, 1], dtype='int64'), array([3, 1, 0, 2]))
44454449
"""
4446-
_as = self.argsort()
4450+
if key:
4451+
idx = self.map(key, na_action="ignore")
4452+
else:
4453+
idx = self
4454+
4455+
_as = idx.argsort()
44474456
if not ascending:
44484457
_as = _as[::-1]
44494458

@@ -4553,9 +4562,12 @@ def argsort(self, *args, **kwargs):
45534562
>>> idx[order]
45544563
Index(['a', 'b', 'c', 'd'], dtype='object')
45554564
"""
4565+
45564566
result = self.asi8
4567+
45574568
if result is None:
45584569
result = np.array(self)
4570+
45594571
return result.argsort(*args, **kwargs)
45604572

45614573
_index_shared_docs[

pandas/core/series.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -2835,7 +2835,7 @@ def sort_values(
28352835
inplace=False,
28362836
kind="quicksort",
28372837
na_position="last",
2838-
key=None
2838+
key: Callable = None
28392839
):
28402840
"""
28412841
Sort by the values.
@@ -2858,7 +2858,7 @@ def sort_values(
28582858
na_position : {'first' or 'last'}, default 'last'
28592859
Argument 'first' puts NaNs at the beginning, 'last' puts NaNs at
28602860
the end.
2861-
key : function, default None
2861+
key : Callable, default None
28622862
If not None, apply the key function to every value before
28632863
sorting. Identical to key argument in built-in sorted function.
28642864
@@ -3035,7 +3035,7 @@ def sort_index(
30353035
kind="quicksort",
30363036
na_position="last",
30373037
sort_remaining=True,
3038-
key=None
3038+
key : Callable = None
30393039
):
30403040
"""
30413041
Sort Series by index labels.
@@ -3064,7 +3064,7 @@ def sort_index(
30643064
sort_remaining : bool, default True
30653065
If True and sorting by level and index is multilevel, sort by other
30663066
levels too (in order) after sorting by specified level.
3067-
key : function, default None
3067+
key : Callable, default None
30683068
If not None, apply the key function to every index element before
30693069
sorting. Identical to key argument in built-in sorted function.
30703070
@@ -3171,7 +3171,7 @@ def sort_index(
31713171
index = self.index
31723172
true_index = index
31733173
if key is not None:
3174-
index = index.map(key)
3174+
index = index.map(key, na_action="ignore")
31753175

31763176
if level is not None:
31773177
new_index, indexer = index.sortlevel(

pandas/core/sorting.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
""" miscellaneous sorting / groupby utilities """
2+
from typing import Callable, Union
3+
24
import numpy as np
35

46
from pandas._libs import algos, hashtable, lib
@@ -187,7 +189,7 @@ def indexer_from_factorized(labels, shape, compress: bool = True):
187189
return get_group_index_sorter(ids, ngroups)
188190

189191

190-
def lexsort_indexer(keys, orders=None, na_position="last"):
192+
def lexsort_indexer(keys, orders=None, na_position="last", key : Union[Callable, None] = None):
191193
from pandas.core.arrays import Categorical
192194

193195
labels = []
@@ -197,6 +199,10 @@ def lexsort_indexer(keys, orders=None, na_position="last"):
197199
elif orders is None:
198200
orders = [True] * len(keys)
199201

202+
if key:
203+
key_func = np.vectorize(key)
204+
keys = [key_func(entry) if entry.size != 0 else entry for entry in keys]
205+
200206
for key, order in zip(keys, orders):
201207

202208
# we are already a Categorical
@@ -233,7 +239,7 @@ def lexsort_indexer(keys, orders=None, na_position="last"):
233239
return indexer_from_factorized(labels, shape)
234240

235241

236-
def nargsort(items, kind="quicksort", ascending: bool = True, na_position="last"):
242+
def nargsort(items, kind="quicksort", ascending=True, na_position="last", key: Union[Callable, None] = None):
237243
"""
238244
This is intended to be a drop-in replacement for np.argsort which
239245
handles NaNs. It adds ascending and na_position parameters.
@@ -247,6 +253,17 @@ def nargsort(items, kind="quicksort", ascending: bool = True, na_position="last"
247253
else:
248254
items = np.asanyarray(items)
249255

256+
if key is not None:
257+
key_func = np.vectorize(key)
258+
masked = np.ma.MaskedArray(items, mask)
259+
260+
if masked.size == 0:
261+
vals = np.array([]) # vectorize fails on empty object arrays
262+
else:
263+
vals = np.asarray(key_func(masked)) # revert from masked
264+
265+
return nargsort(vals, kind=kind, ascending=ascending, na_position=na_position, key=None)
266+
250267
idx = np.arange(len(items))
251268
non_nans = items[~mask]
252269
non_nan_idx = idx[~mask]

pandas/tests/frame/test_sorting.py

+41-13
Original file line numberDiff line numberDiff line change
@@ -81,30 +81,34 @@ def test_sort_values(self):
8181
with pytest.raises(ValueError, match=msg):
8282
frame.sort_values(by=["A", "B"], axis=0, ascending=[True] * 5)
8383

84-
def test_sort_values_inplace(self):
84+
@pytest.fixture(params=[None, lambda x : x])
85+
def key(self, request):
86+
return request.param
87+
88+
def test_sort_values_inplace(self, key):
8589
frame = DataFrame(
8690
np.random.randn(4, 4), index=[1, 2, 3, 4], columns=["A", "B", "C", "D"]
8791
)
8892

8993
sorted_df = frame.copy()
90-
sorted_df.sort_values(by="A", inplace=True)
91-
expected = frame.sort_values(by="A")
92-
tm.assert_frame_equal(sorted_df, expected)
94+
sorted_df.sort_values(by="A", inplace=True, key=key)
95+
expected = frame.sort_values(by="A", key=key)
96+
assert_frame_equal(sorted_df, expected)
9397

9498
sorted_df = frame.copy()
95-
sorted_df.sort_values(by=1, axis=1, inplace=True)
96-
expected = frame.sort_values(by=1, axis=1)
97-
tm.assert_frame_equal(sorted_df, expected)
99+
sorted_df.sort_values(by=1, axis=1, inplace=True, key=key)
100+
expected = frame.sort_values(by=1, axis=1, key=key)
101+
assert_frame_equal(sorted_df, expected)
98102

99103
sorted_df = frame.copy()
100-
sorted_df.sort_values(by="A", ascending=False, inplace=True)
101-
expected = frame.sort_values(by="A", ascending=False)
102-
tm.assert_frame_equal(sorted_df, expected)
104+
sorted_df.sort_values(by="A", ascending=False, inplace=True, key=key)
105+
expected = frame.sort_values(by="A", ascending=False, key=key)
106+
assert_frame_equal(sorted_df, expected)
103107

104108
sorted_df = frame.copy()
105-
sorted_df.sort_values(by=["A", "B"], ascending=False, inplace=True)
106-
expected = frame.sort_values(by=["A", "B"], ascending=False)
107-
tm.assert_frame_equal(sorted_df, expected)
109+
sorted_df.sort_values(by=["A", "B"], ascending=False, inplace=True, key=key)
110+
expected = frame.sort_values(by=["A", "B"], ascending=False, key=key)
111+
assert_frame_equal(sorted_df, expected)
108112

109113
def test_sort_nan(self):
110114
# GH3917
@@ -247,6 +251,23 @@ def test_sort_multi_index(self):
247251

248252
tm.assert_frame_equal(result, expected)
249253

254+
def test_sort_multi_index_key(self):
255+
# GH 25775, testing that sorting by index works with a multi-index.
256+
df = DataFrame(
257+
{"a": [3, 1, 2], "b": [0, 0, 0], "c": [0, 1, 2], "d": list("abc")}
258+
)
259+
result = df.set_index(list("abc")).sort_index(level=list("ba"), key=lambda x : x[0])
260+
261+
expected = DataFrame(
262+
{"a": [1, 2, 3], "b": [0, 0, 0], "c": [1, 2, 0], "d": list("bca")}
263+
)
264+
expected = expected.set_index(list("abc"))
265+
tm.assert_frame_equal(result, expected)
266+
267+
result = df.set_index(list("abc")).sort_index(level=list("ba"), key=lambda x : x[2])
268+
expected = df.set_index(list("abc"))
269+
tm.assert_frame_equal(result, expected)
270+
250271
def test_stable_categorial(self):
251272
# GH 16793
252273
df = DataFrame({"x": pd.Categorical(np.repeat([1, 2, 3, 4], 5), ordered=True)})
@@ -558,6 +579,13 @@ def test_sort_value_key_nan(self):
558579
expected = df.sort_values(1, key=str.lower, ascending=False)
559580
assert_frame_equal(result, expected)
560581

582+
@pytest.mark.parametrize('key', [None, lambda x : x])
583+
def test_sort_value_key_empty(self, key):
584+
df = DataFrame(np.array([]))
585+
586+
df.sort_values(0, key=key)
587+
df.sort_index(key=key)
588+
561589
def test_sort_index(self):
562590
# GH13496
563591

pandas/tests/indexing/multiindex/test_sorted.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from pandas import DataFrame, MultiIndex, Series
55
import pandas.util.testing as tm
66

7+
import pytest
8+
79

810
class TestMultiIndexSorted:
911
def test_getitem_multilevel_index_tuple_not_sorted(self):
@@ -28,7 +30,8 @@ def test_getitem_slice_not_sorted(self, multiindex_dataframe_random_data):
2830
expected = df.reindex(columns=df.columns[:3])
2931
tm.assert_frame_equal(result, expected)
3032

31-
def test_frame_getitem_not_sorted2(self):
33+
@pytest.mark.parametrize('key', [None, lambda x : x])
34+
def test_frame_getitem_not_sorted2(self, key):
3235
# 13431
3336
df = DataFrame(
3437
{
@@ -47,15 +50,35 @@ def test_frame_getitem_not_sorted2(self):
4750
assert not df2.index.is_monotonic
4851

4952
assert df2_original.index.equals(df2.index)
50-
expected = df2.sort_index()
53+
expected = df2.sort_index(key=key)
5154
assert expected.index.is_lexsorted()
5255
assert expected.index.is_monotonic
5356

54-
result = df2.sort_index(level=0)
57+
result = df2.sort_index(level=0, key=key)
5558
assert result.index.is_lexsorted()
5659
assert result.index.is_monotonic
5760
tm.assert_frame_equal(result, expected)
5861

62+
def test_sort_values_key(self, multiindex_dataframe_random_data):
63+
arrays = [
64+
["bar", "bar", "baz", "baz", "qux", "qux", "foo", "foo"],
65+
["one", "two", "one", "two", "one", "two", "one", "two"],
66+
]
67+
tuples = zip(*arrays)
68+
index = MultiIndex.from_tuples(tuples)
69+
index = index.sort_values(key=lambda x: (x[0][2], x[1][2]))
70+
result = DataFrame(range(8), index=index)
71+
72+
arrays = [
73+
["foo", "foo", "bar", "bar", "qux", "qux", "baz", "baz"],
74+
["one", "two", "one", "two", "one", "two", "one", "two"],
75+
]
76+
tuples = zip(*arrays)
77+
index = MultiIndex.from_tuples(tuples)
78+
expected = DataFrame(range(8), index=index)
79+
80+
tm.assert_frame_equal(result, expected)
81+
5982
def test_frame_getitem_not_sorted(self, multiindex_dataframe_random_data):
6083
frame = multiindex_dataframe_random_data
6184
df = frame.T

pandas/tests/series/test_sorting.py

+12
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,18 @@ def test_sort_index_multiindex(self, level):
155155
res = s.sort_index(level=level, sort_remaining=False)
156156
tm.assert_series_equal(s, res)
157157

158+
def test_sort_index_multiindex_key(self):
159+
160+
mi = MultiIndex.from_tuples([[1, 1, 3], [1, 1, 1]], names=list("ABC"))
161+
s = Series([1, 2], mi)
162+
backwards = s.iloc[[1, 0]]
163+
164+
res = s.sort_index(key=lambda x : x[2])
165+
assert_series_equal(backwards, res)
166+
167+
res = s.sort_index(key=lambda x : x[1]) # nothing happens
168+
assert_series_equal(s, res)
169+
158170
def test_sort_index_kind(self):
159171
# GH #14444 & #13589: Add support for sort algo choosing
160172
series = Series(index=[3, 2, 1, 4, 3])

0 commit comments

Comments
 (0)