Skip to content

Commit 6bb5a44

Browse files
lukemanleynoatamir
authored andcommitted
BUG/PERF: use lexsort_indexer in MultiIndex.argsort (pandas-dev#48495)
1 parent 4c72370 commit 6bb5a44

File tree

4 files changed

+63
-5
lines changed

4 files changed

+63
-5
lines changed

doc/source/whatsnew/v1.6.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ Missing
220220

221221
MultiIndex
222222
^^^^^^^^^^
223+
- Bug in :meth:`MultiIndex.argsort` raising ``TypeError`` when index contains :attr:`NA` (:issue:`48495`)
223224
- Bug in :meth:`MultiIndex.difference` losing extension array dtype (:issue:`48606`)
224225
- Bug in :class:`MultiIndex.set_levels` raising ``IndexError`` when setting empty level (:issue:`48636`)
225226
- Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`)

pandas/core/indexes/multi.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1952,7 +1952,7 @@ def _lexsort_depth(self) -> int:
19521952
return self.sortorder
19531953
return _lexsort_depth(self.codes, self.nlevels)
19541954

1955-
def _sort_levels_monotonic(self) -> MultiIndex:
1955+
def _sort_levels_monotonic(self, raise_if_incomparable: bool = False) -> MultiIndex:
19561956
"""
19571957
This is an *internal* function.
19581958
@@ -1999,7 +1999,8 @@ def _sort_levels_monotonic(self) -> MultiIndex:
19991999
# indexer to reorder the levels
20002000
indexer = lev.argsort()
20012001
except TypeError:
2002-
pass
2002+
if raise_if_incomparable:
2003+
raise
20032004
else:
20042005
lev = lev.take(indexer)
20052006

@@ -2245,9 +2246,9 @@ def append(self, other):
22452246

22462247
def argsort(self, *args, **kwargs) -> npt.NDArray[np.intp]:
22472248
if len(args) == 0 and len(kwargs) == 0:
2248-
# np.lexsort is significantly faster than self._values.argsort()
2249-
values = [self._get_level_values(i) for i in reversed(range(self.nlevels))]
2250-
return np.lexsort(values)
2249+
# lexsort is significantly faster than self._values.argsort()
2250+
target = self._sort_levels_monotonic(raise_if_incomparable=True)
2251+
return lexsort_indexer(target._get_codes_for_sorting())
22512252
return self._values.argsort(*args, **kwargs)
22522253

22532254
@Appender(_index_shared_docs["repeat"] % _index_doc_kwargs)

pandas/tests/indexes/multi/test_sorting.py

+24
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Index,
1515
MultiIndex,
1616
RangeIndex,
17+
Timestamp,
1718
)
1819
import pandas._testing as tm
1920
from pandas.core.indexes.frozen import FrozenList
@@ -280,3 +281,26 @@ def test_remove_unused_levels_with_nan():
280281
result = idx.levels
281282
expected = FrozenList([["a", np.nan], [4]])
282283
assert str(result) == str(expected)
284+
285+
286+
def test_sort_values_nan():
287+
# GH48495, GH48626
288+
midx = MultiIndex(levels=[["A", "B", "C"], ["D"]], codes=[[1, 0, 2], [-1, -1, 0]])
289+
result = midx.sort_values()
290+
expected = MultiIndex(
291+
levels=[["A", "B", "C"], ["D"]], codes=[[0, 1, 2], [-1, -1, 0]]
292+
)
293+
tm.assert_index_equal(result, expected)
294+
295+
296+
def test_sort_values_incomparable():
297+
# GH48495
298+
mi = MultiIndex.from_arrays(
299+
[
300+
[1, Timestamp("2000-01-01")],
301+
[3, 4],
302+
]
303+
)
304+
match = "'<' not supported between instances of 'Timestamp' and 'int'"
305+
with pytest.raises(TypeError, match=match):
306+
mi.sort_values()

pandas/tests/indexing/multiindex/test_sorted.py

+32
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
import pytest
33

44
from pandas import (
5+
NA,
56
DataFrame,
67
MultiIndex,
78
Series,
9+
array,
810
)
911
import pandas._testing as tm
1012

@@ -86,6 +88,36 @@ def test_sort_values_key(self):
8688

8789
tm.assert_frame_equal(result, expected)
8890

91+
def test_argsort_with_na(self):
92+
# GH48495
93+
arrays = [
94+
array([2, NA, 1], dtype="Int64"),
95+
array([1, 2, 3], dtype="Int64"),
96+
]
97+
index = MultiIndex.from_arrays(arrays)
98+
result = index.argsort()
99+
expected = np.array([2, 0, 1], dtype=np.intp)
100+
tm.assert_numpy_array_equal(result, expected)
101+
102+
def test_sort_values_with_na(self):
103+
# GH48495
104+
arrays = [
105+
array([2, NA, 1], dtype="Int64"),
106+
array([1, 2, 3], dtype="Int64"),
107+
]
108+
index = MultiIndex.from_arrays(arrays)
109+
index = index.sort_values()
110+
result = DataFrame(range(3), index=index)
111+
112+
arrays = [
113+
array([1, 2, NA], dtype="Int64"),
114+
array([3, 1, 2], dtype="Int64"),
115+
]
116+
index = MultiIndex.from_arrays(arrays)
117+
expected = DataFrame(range(3), index=index)
118+
119+
tm.assert_frame_equal(result, expected)
120+
89121
def test_frame_getitem_not_sorted(self, multiindex_dataframe_random_data):
90122
frame = multiindex_dataframe_random_data
91123
df = frame.T

0 commit comments

Comments
 (0)