Skip to content

Commit aaf00be

Browse files
Backport PR #48412 on branch 1.5.x (BUG: safe_sort losing MultiIndex dtypes) (#49002)
* Backport PR #48412: BUG: safe_sort losing MultiIndex dtypes * Update algorithms.py * Update algorithms.py Co-authored-by: Patrick Hoefler <[email protected]>
1 parent 5a9db39 commit aaf00be

File tree

2 files changed

+42
-6
lines changed

2 files changed

+42
-6
lines changed

pandas/core/algorithms.py

+29-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Sequence,
1515
cast,
1616
final,
17+
overload,
1718
)
1819
import warnings
1920

@@ -101,6 +102,7 @@
101102
Categorical,
102103
DataFrame,
103104
Index,
105+
MultiIndex,
104106
Series,
105107
)
106108
from pandas.core.arrays import (
@@ -1792,7 +1794,7 @@ def safe_sort(
17921794
na_sentinel: int = -1,
17931795
assume_unique: bool = False,
17941796
verify: bool = True,
1795-
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
1797+
) -> np.ndarray | MultiIndex | tuple[np.ndarray | MultiIndex, np.ndarray]:
17961798
"""
17971799
Sort ``values`` and reorder corresponding ``codes``.
17981800
@@ -1821,7 +1823,7 @@ def safe_sort(
18211823
18221824
Returns
18231825
-------
1824-
ordered : ndarray
1826+
ordered : ndarray or MultiIndex
18251827
Sorted ``values``
18261828
new_codes : ndarray
18271829
Reordered ``codes``; returned when ``codes`` is not None.
@@ -1839,6 +1841,8 @@ def safe_sort(
18391841
raise TypeError(
18401842
"Only list-like objects are allowed to be passed to safe_sort as values"
18411843
)
1844+
original_values = values
1845+
is_mi = isinstance(original_values, ABCMultiIndex)
18421846

18431847
if not isinstance(values, (np.ndarray, ABCExtensionArray)):
18441848
# don't convert to string types
@@ -1850,6 +1854,7 @@ def safe_sort(
18501854
values = np.asarray(values, dtype=dtype) # type: ignore[arg-type]
18511855

18521856
sorter = None
1857+
ordered: np.ndarray | MultiIndex
18531858

18541859
if (
18551860
not is_extension_array_dtype(values)
@@ -1859,13 +1864,17 @@ def safe_sort(
18591864
else:
18601865
try:
18611866
sorter = values.argsort()
1862-
ordered = values.take(sorter)
1867+
if is_mi:
1868+
# Operate on original object instead of casted array (MultiIndex)
1869+
ordered = original_values.take(sorter)
1870+
else:
1871+
ordered = values.take(sorter)
18631872
except TypeError:
18641873
# Previous sorters failed or were not applicable, try `_sort_mixed`
18651874
# which would work, but which fails for special case of 1d arrays
18661875
# with tuples.
18671876
if values.size and isinstance(values[0], tuple):
1868-
ordered = _sort_tuples(values)
1877+
ordered = _sort_tuples(values, original_values)
18691878
else:
18701879
ordered = _sort_mixed(values)
18711880

@@ -1927,19 +1936,33 @@ def _sort_mixed(values) -> np.ndarray:
19271936
)
19281937

19291938

1930-
def _sort_tuples(values: np.ndarray) -> np.ndarray:
1939+
@overload
1940+
def _sort_tuples(values: np.ndarray, original_values: np.ndarray) -> np.ndarray:
1941+
...
1942+
1943+
1944+
@overload
1945+
def _sort_tuples(values: np.ndarray, original_values: MultiIndex) -> MultiIndex:
1946+
...
1947+
1948+
1949+
def _sort_tuples(
1950+
values: np.ndarray, original_values: np.ndarray | MultiIndex
1951+
) -> np.ndarray | MultiIndex:
19311952
"""
19321953
Convert array of tuples (1d) to array of arrays (2d).
19331954
We need to keep the columns separately as they contain different types and
19341955
nans (can't use `np.sort` as it may fail when str and nan are mixed in a
19351956
column as types cannot be compared).
1957+
We have to apply the indexer to the original values to keep the dtypes in
1958+
case of MultiIndexes.
19361959
"""
19371960
from pandas.core.internals.construction import to_arrays
19381961
from pandas.core.sorting import lexsort_indexer
19391962

19401963
arrays, _ = to_arrays(values, None)
19411964
indexer = lexsort_indexer(arrays, orders=True)
1942-
return values[indexer]
1965+
return original_values[indexer]
19431966

19441967

19451968
def union_with_duplicates(lvals: ArrayLike, rvals: ArrayLike) -> ArrayLike:

pandas/tests/test_sorting.py

+13
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212

1313
from pandas import (
14+
NA,
1415
DataFrame,
1516
MultiIndex,
1617
Series,
@@ -510,3 +511,15 @@ def test_mixed_str_nan():
510511
result = safe_sort(values)
511512
expected = np.array([np.nan, "a", "b", "b"], dtype=object)
512513
tm.assert_numpy_array_equal(result, expected)
514+
515+
516+
def test_safe_sort_multiindex():
517+
# GH#48412
518+
arr1 = Series([2, 1, NA, NA], dtype="Int64")
519+
arr2 = [2, 1, 3, 3]
520+
midx = MultiIndex.from_arrays([arr1, arr2])
521+
result = safe_sort(midx)
522+
expected = MultiIndex.from_arrays(
523+
[Series([1, 2, NA, NA], dtype="Int64"), [1, 2, 3, 3]]
524+
)
525+
tm.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)