Skip to content

Commit c302b04

Browse files
TomAugspurger authored and javadnoorb committed
ENH: Sorting of ExtensionArrays (pandas-dev#19957)
1 parent 189dd8e commit c302b04

File tree

8 files changed

+235
-18
lines changed

8 files changed

+235
-18
lines changed

pandas/core/arrays/base.py

+52
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33

44
from pandas.errors import AbstractMethodError
5+
from pandas.compat.numpy import function as nv
56

67
_not_implemented_message = "{} does not implement {}."
78

@@ -236,6 +237,57 @@ def isna(self):
236237
"""
237238
raise AbstractMethodError(self)
238239

240+
def _values_for_argsort(self):
241+
# type: () -> ndarray
242+
"""Return values for sorting.
243+
244+
Returns
245+
-------
246+
ndarray
247+
The transformed values should maintain the ordering between values
248+
within the array.
249+
250+
See Also
251+
--------
252+
ExtensionArray.argsort
253+
"""
254+
# Note: this is used in `ExtensionArray.argsort`.
255+
return np.array(self)
256+
257+
def argsort(self, ascending=True, kind='quicksort', *args, **kwargs):
258+
"""
259+
Return the indices that would sort this array.
260+
261+
Parameters
262+
----------
263+
ascending : bool, default True
264+
Whether the indices should result in an ascending
265+
or descending sort.
266+
kind : {'quicksort', 'mergesort', 'heapsort'}, optional
267+
Sorting algorithm.
268+
*args, **kwargs:
269+
passed through to :func:`numpy.argsort`.
270+
271+
Returns
272+
-------
273+
index_array : ndarray
274+
Array of indices that sort ``self``.
275+
276+
See Also
277+
--------
278+
numpy.argsort : Sorting implementation used internally.
279+
"""
280+
# Implementor note: You have two places to override the behavior of
281+
# argsort.
282+
# 1. _values_for_argsort : construct the values passed to np.argsort
283+
# 2. argsort : total control over sorting.
284+
ascending = nv.validate_argsort_with_ascending(ascending, args, kwargs)
285+
values = self._values_for_argsort()
286+
result = np.argsort(values, kind=kind, **kwargs)
287+
if not ascending:
288+
result = result[::-1]
289+
return result
290+
239291
def fillna(self, value=None, method=None, limit=None):
240292
""" Fill NA/NaN values using the specified method.
241293

pandas/core/arrays/categorical.py

+38-15
Original file line numberDiff line numberDiff line change
@@ -1431,17 +1431,24 @@ def check_for_ordered(self, op):
14311431
"you can use .as_ordered() to change the "
14321432
"Categorical to an ordered one\n".format(op=op))
14331433

1434-
def argsort(self, ascending=True, kind='quicksort', *args, **kwargs):
1435-
"""
1436-
Returns the indices that would sort the Categorical instance if
1437-
'sort_values' was called. This function is implemented to provide
1438-
compatibility with numpy ndarray objects.
1434+
def _values_for_argsort(self):
1435+
return self._codes.copy()
14391436

1440-
While an ordering is applied to the category values, arg-sorting
1441-
in this context refers more to organizing and grouping together
1442-
based on matching category values. Thus, this function can be
1443-
called on an unordered Categorical instance unlike the functions
1444-
'Categorical.min' and 'Categorical.max'.
1437+
def argsort(self, *args, **kwargs):
1438+
# TODO(PY2): use correct signature
1439+
# We have to do *args, **kwargs to avoid a py2-only signature
1440+
# issue since np.argsort differs from argsort.
1441+
"""Return the indices that would sort the Categorical.
1442+
1443+
Parameters
1444+
----------
1445+
ascending : bool, default True
1446+
Whether the indices should result in an ascending
1447+
or descending sort.
1448+
kind : {'quicksort', 'mergesort', 'heapsort'}, optional
1449+
Sorting algorithm.
1450+
*args, **kwargs:
1451+
passed through to :func:`numpy.argsort`.
14451452
14461453
Returns
14471454
-------
@@ -1450,12 +1457,28 @@ def argsort(self, ascending=True, kind='quicksort', *args, **kwargs):
14501457
See also
14511458
--------
14521459
numpy.ndarray.argsort
1460+
1461+
Notes
1462+
-----
1463+
While an ordering is applied to the category values, arg-sorting
1464+
in this context refers more to organizing and grouping together
1465+
based on matching category values. Thus, this function can be
1466+
called on an unordered Categorical instance unlike the functions
1467+
'Categorical.min' and 'Categorical.max'.
1468+
1469+
Examples
1470+
--------
1471+
>>> pd.Categorical(['b', 'b', 'a', 'c']).argsort()
1472+
array([2, 0, 1, 3])
1473+
1474+
>>> cat = pd.Categorical(['b', 'b', 'a', 'c'],
1475+
... categories=['c', 'b', 'a'],
1476+
... ordered=True)
1477+
>>> cat.argsort()
1478+
array([3, 0, 1, 2])
14531479
"""
1454-
ascending = nv.validate_argsort_with_ascending(ascending, args, kwargs)
1455-
result = np.argsort(self._codes.copy(), kind=kind, **kwargs)
1456-
if not ascending:
1457-
result = result[::-1]
1458-
return result
1480+
# Keep the implementation here just for the docstring.
1481+
return super(Categorical, self).argsort(*args, **kwargs)
14591482

14601483
def sort_values(self, inplace=False, ascending=True, na_position='last'):
14611484
""" Sorts the Categorical by category value returning a new

pandas/tests/extension/base/methods.py

+40
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,46 @@ def test_apply_simple_series(self, data):
3232
result = pd.Series(data).apply(id)
3333
assert isinstance(result, pd.Series)
3434

35+
def test_argsort(self, data_for_sorting):
36+
result = pd.Series(data_for_sorting).argsort()
37+
expected = pd.Series(np.array([2, 0, 1], dtype=np.int64))
38+
self.assert_series_equal(result, expected)
39+
40+
def test_argsort_missing(self, data_missing_for_sorting):
41+
result = pd.Series(data_missing_for_sorting).argsort()
42+
expected = pd.Series(np.array([1, -1, 0], dtype=np.int64))
43+
self.assert_series_equal(result, expected)
44+
45+
@pytest.mark.parametrize('ascending', [True, False])
46+
def test_sort_values(self, data_for_sorting, ascending):
47+
ser = pd.Series(data_for_sorting)
48+
result = ser.sort_values(ascending=ascending)
49+
expected = ser.iloc[[2, 0, 1]]
50+
if not ascending:
51+
expected = expected[::-1]
52+
53+
self.assert_series_equal(result, expected)
54+
55+
@pytest.mark.parametrize('ascending', [True, False])
56+
def test_sort_values_missing(self, data_missing_for_sorting, ascending):
57+
ser = pd.Series(data_missing_for_sorting)
58+
result = ser.sort_values(ascending=ascending)
59+
if ascending:
60+
expected = ser.iloc[[2, 0, 1]]
61+
else:
62+
expected = ser.iloc[[0, 2, 1]]
63+
self.assert_series_equal(result, expected)
64+
65+
@pytest.mark.parametrize('ascending', [True, False])
66+
def test_sort_values_frame(self, data_for_sorting, ascending):
67+
df = pd.DataFrame({"A": [1, 2, 1],
68+
"B": data_for_sorting})
69+
result = df.sort_values(['A', 'B'])
70+
expected = pd.DataFrame({"A": [1, 1, 2],
71+
'B': data_for_sorting.take([2, 0, 1])},
72+
index=[2, 0, 1])
73+
self.assert_frame_equal(result, expected)
74+
3575
@pytest.mark.parametrize('box', [pd.Series, lambda x: x])
3676
@pytest.mark.parametrize('method', [lambda x: x.unique(), pd.unique])
3777
def test_unique(self, data, box, method):

pandas/tests/extension/category/test_categorical.py

+12
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@ def data_missing():
2929
return Categorical([np.nan, 'A'])
3030

3131

32+
@pytest.fixture
33+
def data_for_sorting():
34+
return Categorical(['A', 'B', 'C'], categories=['C', 'A', 'B'],
35+
ordered=True)
36+
37+
38+
@pytest.fixture
39+
def data_missing_for_sorting():
40+
return Categorical(['A', None, 'B'], categories=['B', 'A'],
41+
ordered=True)
42+
43+
3244
@pytest.fixture
3345
def na_value():
3446
return np.nan

pandas/tests/extension/conftest.py

+20
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,26 @@ def all_data(request, data, data_missing):
3030
return data_missing
3131

3232

33+
@pytest.fixture
34+
def data_for_sorting():
35+
"""Length-3 array with a known sort order.
36+
37+
This should be three items [B, C, A] with
38+
A < B < C
39+
"""
40+
raise NotImplementedError
41+
42+
43+
@pytest.fixture
44+
def data_missing_for_sorting():
45+
"""Length-3 array with a known sort order.
46+
47+
This should be three items [B, NA, A] with
48+
A < B and NA missing.
49+
"""
50+
raise NotImplementedError
51+
52+
3353
@pytest.fixture
3454
def na_cmp():
3555
"""Binary operator for comparing NA values.

pandas/tests/extension/decimal/test_decimal.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,20 @@ def data_missing():
2525
return DecimalArray([decimal.Decimal('NaN'), decimal.Decimal(1)])
2626

2727

28+
@pytest.fixture
29+
def data_for_sorting():
30+
return DecimalArray([decimal.Decimal('1'),
31+
decimal.Decimal('2'),
32+
decimal.Decimal('0')])
33+
34+
35+
@pytest.fixture
36+
def data_missing_for_sorting():
37+
return DecimalArray([decimal.Decimal('1'),
38+
decimal.Decimal('NaN'),
39+
decimal.Decimal('0')])
40+
41+
2842
@pytest.fixture
2943
def na_cmp():
3044
return lambda x, y: x.is_nan() and y.is_nan()
@@ -48,11 +62,17 @@ def assert_series_equal(self, left, right, *args, **kwargs):
4862
*args, **kwargs)
4963

5064
def assert_frame_equal(self, left, right, *args, **kwargs):
51-
self.assert_series_equal(left.dtypes, right.dtypes)
52-
for col in left.columns:
65+
# TODO(EA): select_dtypes
66+
decimals = (left.dtypes == 'decimal').index
67+
68+
for col in decimals:
5369
self.assert_series_equal(left[col], right[col],
5470
*args, **kwargs)
5571

72+
left = left.drop(columns=decimals)
73+
right = right.drop(columns=decimals)
74+
tm.assert_frame_equal(left, right, *args, **kwargs)
75+
5676

5777
class TestDtype(BaseDecimal, base.BaseDtypeTests):
5878
pass

pandas/tests/extension/json/array.py

+11
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ def __getitem__(self, item):
4444
return self._constructor_from_sequence([
4545
x for x, m in zip(self, item) if m
4646
])
47+
elif isinstance(item, collections.Iterable):
48+
# fancy indexing
49+
return type(self)([self.data[i] for i in item])
4750
else:
51+
# slice
4852
return type(self)(self.data[item])
4953

5054
def __setitem__(self, key, value):
@@ -104,6 +108,13 @@ def _concat_same_type(cls, to_concat):
104108
data = list(itertools.chain.from_iterable([x.data for x in to_concat]))
105109
return cls(data)
106110

111+
def _values_for_argsort(self):
112+
# Disable NumPy's shape inference by including an empty tuple...
113+
# If all the elements of self are the same size P, NumPy will
114+
# cast them to an (N, P) array, instead of an (N,) array of tuples.
115+
frozen = [()] + list(tuple(x.items()) for x in self)
116+
return np.array(frozen, dtype=object)[1:]
117+
107118

108119
def make_data():
109120
# TODO: Use a regular dict. See _NDFrameIndexer._setitem_with_indexer

pandas/tests/extension/json/test_json.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,16 @@ def data_missing():
2929
return JSONArray([{}, {'a': 10}])
3030

3131

32+
@pytest.fixture
33+
def data_for_sorting():
34+
return JSONArray([{'b': 1}, {'c': 4}, {'a': 2, 'c': 3}])
35+
36+
37+
@pytest.fixture
38+
def data_missing_for_sorting():
39+
return JSONArray([{'b': 1}, {}, {'a': 4}])
40+
41+
3242
@pytest.fixture
3343
def na_value():
3444
return {}
@@ -70,10 +80,39 @@ def test_fillna_frame(self):
7080

7181

7282
class TestMethods(base.BaseMethodsTests):
73-
@pytest.mark.skip(reason="Unhashable")
83+
unhashable = pytest.mark.skip(reason="Unhashable")
84+
unstable = pytest.mark.skipif(sys.version_info <= (3, 5),
85+
reason="Dictionary order unstable")
86+
87+
@unhashable
7488
def test_value_counts(self, all_data, dropna):
7589
pass
7690

91+
@unhashable
92+
def test_sort_values_frame(self):
93+
# TODO (EA.factorize): see if _values_for_factorize allows this.
94+
pass
95+
96+
@unstable
97+
def test_argsort(self, data_for_sorting):
98+
super(TestMethods, self).test_argsort(data_for_sorting)
99+
100+
@unstable
101+
def test_argsort_missing(self, data_missing_for_sorting):
102+
super(TestMethods, self).test_argsort_missing(
103+
data_missing_for_sorting)
104+
105+
@unstable
106+
@pytest.mark.parametrize('ascending', [True, False])
107+
def test_sort_values(self, data_for_sorting, ascending):
108+
super(TestMethods, self).test_sort_values(
109+
data_for_sorting, ascending)
110+
111+
@pytest.mark.parametrize('ascending', [True, False])
112+
def test_sort_values_missing(self, data_missing_for_sorting, ascending):
113+
super(TestMethods, self).test_sort_values_missing(
114+
data_missing_for_sorting, ascending)
115+
77116

78117
class TestCasting(base.BaseCastingTests):
79118
pass

0 commit comments

Comments
 (0)