diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py
index 80fe74ccee583..fc2050dcc7f4d 100644
--- a/pandas/core/algorithms.py
+++ b/pandas/core/algorithms.py
@@ -14,6 +14,7 @@
     Sequence,
     cast,
     final,
+    overload,
 )
 import warnings
 
@@ -101,6 +102,7 @@
         Categorical,
         DataFrame,
         Index,
+        MultiIndex,
         Series,
     )
     from pandas.core.arrays import (
@@ -1792,7 +1794,7 @@ def safe_sort(
     na_sentinel: int = -1,
     assume_unique: bool = False,
     verify: bool = True,
-) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
+) -> np.ndarray | MultiIndex | tuple[np.ndarray | MultiIndex, np.ndarray]:
     """
     Sort ``values`` and reorder corresponding ``codes``.
 
@@ -1821,7 +1823,7 @@ def safe_sort(
 
     Returns
     -------
-    ordered : ndarray
+    ordered : ndarray or MultiIndex
         Sorted ``values``
     new_codes : ndarray
         Reordered ``codes``; returned when ``codes`` is not None.
@@ -1839,6 +1841,8 @@ def safe_sort(
         raise TypeError(
             "Only list-like objects are allowed to be passed to safe_sort as values"
         )
+    original_values = values
+    is_mi = isinstance(original_values, ABCMultiIndex)
 
     if not isinstance(values, (np.ndarray, ABCExtensionArray)):
         # don't convert to string types
@@ -1850,6 +1854,7 @@ def safe_sort(
         values = np.asarray(values, dtype=dtype)  # type: ignore[arg-type]
 
     sorter = None
+    ordered: np.ndarray | MultiIndex
 
     if (
         not is_extension_array_dtype(values)
@@ -1859,13 +1864,17 @@ def safe_sort(
     else:
         try:
             sorter = values.argsort()
-            ordered = values.take(sorter)
+            if is_mi:
+                # Operate on original object instead of casted array (MultiIndex)
+                ordered = original_values.take(sorter)
+            else:
+                ordered = values.take(sorter)
         except TypeError:
             # Previous sorters failed or were not applicable, try `_sort_mixed`
             # which would work, but which fails for special case of 1d arrays
             # with tuples.
             if values.size and isinstance(values[0], tuple):
-                ordered = _sort_tuples(values)
+                ordered = _sort_tuples(values, original_values)
             else:
                 ordered = _sort_mixed(values)
 
@@ -1927,19 +1936,33 @@ def _sort_mixed(values) -> np.ndarray:
     )
 
 
-def _sort_tuples(values: np.ndarray) -> np.ndarray:
+@overload
+def _sort_tuples(values: np.ndarray, original_values: np.ndarray) -> np.ndarray:
+    ...
+
+
+@overload
+def _sort_tuples(values: np.ndarray, original_values: MultiIndex) -> MultiIndex:
+    ...
+
+
+def _sort_tuples(
+    values: np.ndarray, original_values: np.ndarray | MultiIndex
+) -> np.ndarray | MultiIndex:
     """
     Convert array of tuples (1d) to array or array (2d).
     We need to keep the columns separately as they contain different types and
     nans (can't use `np.sort` as it may fail when str and nan are mixed in a
     column as types cannot be compared).
+    We have to apply the indexer to the original values to keep the dtypes in
+    case of MultiIndexes
     """
     from pandas.core.internals.construction import to_arrays
     from pandas.core.sorting import lexsort_indexer
 
     arrays, _ = to_arrays(values, None)
     indexer = lexsort_indexer(arrays, orders=True)
-    return values[indexer]
+    return original_values[indexer]
 
 
 def union_with_duplicates(lvals: ArrayLike, rvals: ArrayLike) -> ArrayLike:
diff --git a/pandas/tests/test_sorting.py b/pandas/tests/test_sorting.py
index 396c4d82d01fc..537792ea8263c 100644
--- a/pandas/tests/test_sorting.py
+++ b/pandas/tests/test_sorting.py
@@ -11,6 +11,7 @@
 )
 
 from pandas import (
+    NA,
     DataFrame,
     MultiIndex,
     Series,
@@ -510,3 +511,15 @@ def test_mixed_str_nan():
     result = safe_sort(values)
     expected = np.array([np.nan, "a", "b", "b"], dtype=object)
     tm.assert_numpy_array_equal(result, expected)
+
+
+def test_safe_sort_multiindex():
+    # GH#48412
+    arr1 = Series([2, 1, NA, NA], dtype="Int64")
+    arr2 = [2, 1, 3, 3]
+    midx = MultiIndex.from_arrays([arr1, arr2])
+    result = safe_sort(midx)
+    expected = MultiIndex.from_arrays(
+        [Series([1, 2, NA, NA], dtype="Int64"), [1, 2, 3, 3]]
+    )
+    tm.assert_index_equal(result, expected)