Skip to content

Commit 161f762

Browse files
authored
ENH: Support na_position for sort_index and sortlevel (#51672)
* ENH: Support na_position for sort_index and sortlevel
* Add additional whatsnew
* Fix mypy
1 parent aaa1b90 commit 161f762

File tree

7 files changed

+81
-41
lines changed

7 files changed

+81
-41
lines changed

doc/source/whatsnew/v2.1.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ enhancement2
2828

2929
Other enhancements
3030
^^^^^^^^^^^^^^^^^^
31+
- :meth:`MultiIndex.sortlevel` and :meth:`Index.sortlevel` gained a new keyword ``na_position`` (:issue:`51612`)
3132
- Improve error message when setting :class:`DataFrame` with wrong number of columns through :meth:`DataFrame.isetitem` (:issue:`51701`)
3233
- Let :meth:`DataFrame.to_feather` accept a non-default :class:`Index` and non-string column names (:issue:`51787`)
3334

@@ -109,6 +110,7 @@ Performance improvements
109110
- Performance improvement in :meth:`DataFrame.first_valid_index` and :meth:`DataFrame.last_valid_index` for extension array dtypes (:issue:`51549`)
110111
- Performance improvement in :meth:`DataFrame.where` when ``cond`` is backed by an extension dtype (:issue:`51574`)
111112
- Performance improvement in :meth:`read_orc` when reading a remote URI file path (:issue:`51609`)
113+
- Performance improvement in :meth:`MultiIndex.sortlevel` when ``ascending`` is a list (:issue:`51612`)
112114
- Performance improvement in :meth:`~arrays.ArrowExtensionArray.isna` when array has zero nulls or is all nulls (:issue:`51630`)
113115
- Performance improvement when parsing strings to ``boolean[pyarrow]`` dtype (:issue:`51730`)
114116

pandas/core/indexes/base.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -1893,7 +1893,11 @@ def _get_level_number(self, level) -> int:
18931893
return 0
18941894

18951895
def sortlevel(
1896-
self, level=None, ascending: bool | list[bool] = True, sort_remaining=None
1896+
self,
1897+
level=None,
1898+
ascending: bool | list[bool] = True,
1899+
sort_remaining=None,
1900+
na_position: str_t = "first",
18971901
):
18981902
"""
18991903
For internal compatibility with the Index API.
@@ -1904,6 +1908,11 @@ def sortlevel(
19041908
----------
19051909
ascending : bool, default True
19061910
False to sort in descending order
1911+
na_position : {'first', 'last'}, default 'first'
1912+
Argument 'first' puts NaNs at the beginning, 'last' puts NaNs at
1913+
the end.
1914+
1915+
.. versionadded:: 2.1.0
19071916
19081917
level, sort_remaining are compat parameters
19091918
@@ -1925,7 +1934,9 @@ def sortlevel(
19251934
if not isinstance(ascending, bool):
19261935
raise TypeError("ascending must be a bool value")
19271936

1928-
return self.sort_values(return_indexer=True, ascending=ascending)
1937+
return self.sort_values(
1938+
return_indexer=True, ascending=ascending, na_position=na_position
1939+
)
19291940

19301941
def _get_level_values(self, level) -> Index:
19311942
"""

pandas/core/indexes/multi.py

+14-28
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@
9494
from pandas.core.ops.invalid import make_invalid_op
9595
from pandas.core.sorting import (
9696
get_group_index,
97-
indexer_from_factorized,
9897
lexsort_indexer,
9998
)
10099

@@ -2367,6 +2366,7 @@ def sortlevel(
23672366
level: IndexLabel = 0,
23682367
ascending: bool | list[bool] = True,
23692368
sort_remaining: bool = True,
2369+
na_position: str = "first",
23702370
) -> tuple[MultiIndex, npt.NDArray[np.intp]]:
23712371
"""
23722372
Sort MultiIndex at the requested level.
@@ -2383,6 +2383,11 @@ def sortlevel(
23832383
False to sort in descending order.
23842384
Can also be a list to specify a directed ordering.
23852385
sort_remaining : sort by the remaining levels after level
2386+
na_position : {'first', 'last'}, default 'first'
2387+
Argument 'first' puts NaNs at the beginning, 'last' puts NaNs at
2388+
the end.
2389+
2390+
.. versionadded:: 2.1.0
23862391
23872392
Returns
23882393
-------
@@ -2428,40 +2433,21 @@ def sortlevel(
24282433
]
24292434
sortorder = None
24302435

2436+
codes = [self.codes[lev] for lev in level]
24312437
# we have a directed ordering via ascending
24322438
if isinstance(ascending, list):
24332439
if not len(level) == len(ascending):
24342440
raise ValueError("level must have same length as ascending")
2435-
2436-
indexer = lexsort_indexer(
2437-
[self.codes[lev] for lev in level], orders=ascending
2441+
elif sort_remaining:
2442+
codes.extend(
2443+
[self.codes[lev] for lev in range(len(self.levels)) if lev not in level]
24382444
)
2439-
2440-
# level ordering
24412445
else:
2442-
codes = list(self.codes)
2443-
shape = list(self.levshape)
2444-
2445-
# partition codes and shape
2446-
primary = tuple(codes[lev] for lev in level)
2447-
primshp = tuple(shape[lev] for lev in level)
2448-
2449-
# Reverse sorted to retain the order of
2450-
# smaller indices that needs to be removed
2451-
for lev in sorted(level, reverse=True):
2452-
codes.pop(lev)
2453-
shape.pop(lev)
2454-
2455-
if sort_remaining:
2456-
primary += primary + tuple(codes)
2457-
primshp += primshp + tuple(shape)
2458-
else:
2459-
sortorder = level[0]
2460-
2461-
indexer = indexer_from_factorized(primary, primshp, compress=False)
2446+
sortorder = level[0]
24622447

2463-
if not ascending:
2464-
indexer = indexer[::-1]
2448+
indexer = lexsort_indexer(
2449+
codes, orders=ascending, na_position=na_position, codes_given=True
2450+
)
24652451

24662452
indexer = ensure_platform_int(indexer)
24672453
new_codes = [level_codes.take(indexer) for level_codes in self.codes]

pandas/core/sorting.py

+30-11
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,14 @@ def get_indexer_indexer(
8383

8484
if level is not None:
8585
_, indexer = target.sortlevel(
86-
level, ascending=ascending, sort_remaining=sort_remaining
86+
level,
87+
ascending=ascending,
88+
sort_remaining=sort_remaining,
89+
na_position=na_position,
8790
)
8891
elif isinstance(target, ABCMultiIndex):
8992
indexer = lexsort_indexer(
90-
target._get_codes_for_sorting(), orders=ascending, na_position=na_position
93+
target.codes, orders=ascending, na_position=na_position, codes_given=True
9194
)
9295
else:
9396
# Check monotonicity before sorting an index (GH 11080)
@@ -302,7 +305,11 @@ def indexer_from_factorized(
302305

303306

304307
def lexsort_indexer(
305-
keys, orders=None, na_position: str = "last", key: Callable | None = None
308+
keys,
309+
orders=None,
310+
na_position: str = "last",
311+
key: Callable | None = None,
312+
codes_given: bool = False,
306313
) -> npt.NDArray[np.intp]:
307314
"""
308315
Performs lexical sorting on a set of keys
@@ -321,6 +328,8 @@ def lexsort_indexer(
321328
Determines placement of NA elements in the sorted list ("last" or "first")
322329
key : Callable, optional
323330
Callable key function applied to every element in keys before sorting
331+
codes_given : bool, default False
332+
Avoid categorical materialization if codes are already provided.
324333
325334
Returns
326335
-------
@@ -338,15 +347,27 @@ def lexsort_indexer(
338347
keys = [ensure_key_mapped(k, key) for k in keys]
339348

340349
for k, order in zip(keys, orders):
341-
cat = Categorical(k, ordered=True)
342-
343350
if na_position not in ["last", "first"]:
344351
raise ValueError(f"invalid na_position: {na_position}")
345352

346-
n = len(cat.categories)
347-
codes = cat.codes.copy()
353+
if codes_given:
354+
mask = k == -1
355+
codes = k.copy()
356+
n = len(codes)
357+
mask_n = n
358+
if mask.any():
359+
n -= 1
360+
361+
else:
362+
cat = Categorical(k, ordered=True)
363+
n = len(cat.categories)
364+
codes = cat.codes.copy()
365+
mask = cat.codes == -1
366+
if mask.any():
367+
mask_n = n + 1
368+
else:
369+
mask_n = n
348370

349-
mask = cat.codes == -1
350371
if order: # ascending
351372
if na_position == "last":
352373
codes = np.where(mask, n, codes)
@@ -357,10 +378,8 @@ def lexsort_indexer(
357378
codes = np.where(mask, n, n - codes - 1)
358379
elif na_position == "first":
359380
codes = np.where(mask, 0, n - codes)
360-
if mask.any():
361-
n += 1
362381

363-
shape.append(n)
382+
shape.append(mask_n)
364383
labels.append(codes)
365384

366385
return indexer_from_factorized(labels, tuple(shape))

pandas/tests/frame/methods/test_sort_index.py

+7
Original file line numberDiff line numberDiff line change
@@ -907,3 +907,10 @@ def test_sort_index_multiindex_sparse_column(self):
907907
result = expected.sort_index(level=0)
908908

909909
tm.assert_frame_equal(result, expected)
910+
911+
def test_sort_index_na_position(self):
912+
# GH#51612
913+
df = DataFrame([1, 2], index=MultiIndex.from_tuples([(1, 1), (1, pd.NA)]))
914+
expected = df.copy()
915+
result = df.sort_index(level=[0, 1], na_position="last")
916+
tm.assert_frame_equal(result, expected)

pandas/tests/indexes/multi/test_sorting.py

+8
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,14 @@ def test_sortlevel_deterministic():
7676
assert sorted_idx.equals(expected[::-1])
7777

7878

79+
def test_sortlevel_na_position():
80+
# GH#51612
81+
midx = MultiIndex.from_tuples([(1, np.nan), (1, 1)])
82+
result = midx.sortlevel(level=[0, 1], na_position="last")[0]
83+
expected = MultiIndex.from_tuples([(1, 1), (1, np.nan)])
84+
tm.assert_index_equal(result, expected)
85+
86+
7987
def test_numpy_argsort(idx):
8088
result = np.argsort(idx)
8189
expected = idx.argsort()

pandas/tests/indexes/test_base.py

+7
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,13 @@ def test_sortlevel(self):
12621262
result = index.sortlevel(ascending=False)
12631263
tm.assert_index_equal(result[0], expected)
12641264

1265+
def test_sortlevel_na_position(self):
1266+
# GH#51612
1267+
idx = Index([1, np.nan])
1268+
result = idx.sortlevel(na_position="first")[0]
1269+
expected = Index([np.nan, 1])
1270+
tm.assert_index_equal(result, expected)
1271+
12651272

12661273
class TestMixedIntIndex(Base):
12671274
# Mostly the tests from common.py for which the results differ

0 commit comments

Comments
 (0)