diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 2e6737b2e61aa..77de1636a92eb 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -14,6 +14,7 @@ Sequence, cast, final, + overload, ) import warnings @@ -101,6 +102,7 @@ Categorical, DataFrame, Index, + MultiIndex, Series, ) from pandas.core.arrays import ( @@ -1780,7 +1782,7 @@ def safe_sort( na_sentinel: int = -1, assume_unique: bool = False, verify: bool = True, -) -> np.ndarray | tuple[np.ndarray, np.ndarray]: +) -> np.ndarray | MultiIndex | tuple[np.ndarray | MultiIndex, np.ndarray]: """ Sort ``values`` and reorder corresponding ``codes``. @@ -1809,7 +1811,7 @@ def safe_sort( Returns ------- - ordered : ndarray + ordered : ndarray or MultiIndex Sorted ``values`` new_codes : ndarray Reordered ``codes``; returned when ``codes`` is not None. @@ -1827,6 +1829,7 @@ def safe_sort( raise TypeError( "Only list-like objects are allowed to be passed to safe_sort as values" ) + original_values = values if not isinstance(values, (np.ndarray, ABCExtensionArray)): # don't convert to string types @@ -1838,6 +1841,7 @@ def safe_sort( values = np.asarray(values, dtype=dtype) # type: ignore[arg-type] sorter = None + ordered: np.ndarray | MultiIndex if ( not is_extension_array_dtype(values) @@ -1853,7 +1857,7 @@ def safe_sort( # which would work, but which fails for special case of 1d arrays # with tuples. if values.size and isinstance(values[0], tuple): - ordered = _sort_tuples(values) + ordered = _sort_tuples(values, original_values) else: ordered = _sort_mixed(values) @@ -1915,19 +1919,33 @@ def _sort_mixed(values) -> np.ndarray: ) -def _sort_tuples(values: np.ndarray) -> np.ndarray: +@overload +def _sort_tuples(values: np.ndarray, original_values: np.ndarray) -> np.ndarray: + ... + + +@overload +def _sort_tuples(values: np.ndarray, original_values: MultiIndex) -> MultiIndex: + ... + + +def _sort_tuples( + values: np.ndarray, original_values: np.ndarray | MultiIndex +) -> np.ndarray | MultiIndex: """ Convert array of tuples (1d) to array or array (2d). We need to keep the columns separately as they contain different types and nans (can't use `np.sort` as it may fail when str and nan are mixed in a column as types cannot be compared). + We have to apply the indexer to the original values to keep the dtypes in + case of MultiIndexes """ from pandas.core.internals.construction import to_arrays from pandas.core.sorting import lexsort_indexer arrays, _ = to_arrays(values, None) indexer = lexsort_indexer(arrays, orders=True) - return values[indexer] + return original_values[indexer] def union_with_duplicates(lvals: ArrayLike, rvals: ArrayLike) -> ArrayLike: diff --git a/pandas/tests/test_sorting.py b/pandas/tests/test_sorting.py index 396c4d82d01fc..537792ea8263c 100644 --- a/pandas/tests/test_sorting.py +++ b/pandas/tests/test_sorting.py @@ -11,6 +11,7 @@ ) from pandas import ( + NA, DataFrame, MultiIndex, Series, @@ -510,3 +511,15 @@ def test_mixed_str_nan(): result = safe_sort(values) expected = np.array([np.nan, "a", "b", "b"], dtype=object) tm.assert_numpy_array_equal(result, expected) + + +def test_safe_sort_multiindex(): + # GH#48412 + arr1 = Series([2, 1, NA, NA], dtype="Int64") + arr2 = [2, 1, 3, 3] + midx = MultiIndex.from_arrays([arr1, arr2]) + result = safe_sort(midx) + expected = MultiIndex.from_arrays( + [Series([1, 2, NA, NA], dtype="Int64"), [1, 2, 3, 3]] + ) + tm.assert_index_equal(result, expected)