Skip to content

Commit 8cb7cfe

Browse files
authored
BUG: safe_sort losing MultiIndex dtypes (#48412)
1 parent 0daa6bb commit 8cb7cfe

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

pandas/core/algorithms.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Sequence,
1515
cast,
1616
final,
17+
overload,
1718
)
1819
import warnings
1920

@@ -101,6 +102,7 @@
101102
Categorical,
102103
DataFrame,
103104
Index,
105+
MultiIndex,
104106
Series,
105107
)
106108
from pandas.core.arrays import (
@@ -1780,7 +1782,7 @@ def safe_sort(
17801782
na_sentinel: int = -1,
17811783
assume_unique: bool = False,
17821784
verify: bool = True,
1783-
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
1785+
) -> np.ndarray | MultiIndex | tuple[np.ndarray | MultiIndex, np.ndarray]:
17841786
"""
17851787
Sort ``values`` and reorder corresponding ``codes``.
17861788
@@ -1809,7 +1811,7 @@ def safe_sort(
18091811
18101812
Returns
18111813
-------
1812-
ordered : ndarray
1814+
ordered : ndarray or MultiIndex
18131815
Sorted ``values``
18141816
new_codes : ndarray
18151817
Reordered ``codes``; returned when ``codes`` is not None.
@@ -1827,6 +1829,7 @@ def safe_sort(
18271829
raise TypeError(
18281830
"Only list-like objects are allowed to be passed to safe_sort as values"
18291831
)
1832+
original_values = values
18301833

18311834
if not isinstance(values, (np.ndarray, ABCExtensionArray)):
18321835
# don't convert to string types
@@ -1838,6 +1841,7 @@ def safe_sort(
18381841
values = np.asarray(values, dtype=dtype) # type: ignore[arg-type]
18391842

18401843
sorter = None
1844+
ordered: np.ndarray | MultiIndex
18411845

18421846
if (
18431847
not is_extension_array_dtype(values)
@@ -1853,7 +1857,7 @@ def safe_sort(
18531857
# which would work, but which fails for special case of 1d arrays
18541858
# with tuples.
18551859
if values.size and isinstance(values[0], tuple):
1856-
ordered = _sort_tuples(values)
1860+
ordered = _sort_tuples(values, original_values)
18571861
else:
18581862
ordered = _sort_mixed(values)
18591863

@@ -1915,19 +1919,33 @@ def _sort_mixed(values) -> np.ndarray:
19151919
)
19161920

19171921

1918-
def _sort_tuples(values: np.ndarray) -> np.ndarray:
1922+
@overload
1923+
def _sort_tuples(values: np.ndarray, original_values: np.ndarray) -> np.ndarray:
1924+
...
1925+
1926+
1927+
@overload
1928+
def _sort_tuples(values: np.ndarray, original_values: MultiIndex) -> MultiIndex:
1929+
...
1930+
1931+
1932+
def _sort_tuples(
1933+
values: np.ndarray, original_values: np.ndarray | MultiIndex
1934+
) -> np.ndarray | MultiIndex:
19191935
"""
19201936
Convert array of tuples (1d) to array or array (2d).
19211937
We need to keep the columns separately as they contain different types and
19221938
nans (can't use `np.sort` as it may fail when str and nan are mixed in a
19231939
column as types cannot be compared).
1940+
We have to apply the indexer to the original values to keep the dtypes in
1941+
case of MultiIndexes
19241942
"""
19251943
from pandas.core.internals.construction import to_arrays
19261944
from pandas.core.sorting import lexsort_indexer
19271945

19281946
arrays, _ = to_arrays(values, None)
19291947
indexer = lexsort_indexer(arrays, orders=True)
1930-
return values[indexer]
1948+
return original_values[indexer]
19311949

19321950

19331951
def union_with_duplicates(lvals: ArrayLike, rvals: ArrayLike) -> ArrayLike:

pandas/tests/test_sorting.py

+13
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212

1313
from pandas import (
14+
NA,
1415
DataFrame,
1516
MultiIndex,
1617
Series,
@@ -510,3 +511,15 @@ def test_mixed_str_nan():
510511
result = safe_sort(values)
511512
expected = np.array([np.nan, "a", "b", "b"], dtype=object)
512513
tm.assert_numpy_array_equal(result, expected)
514+
515+
516+
def test_safe_sort_multiindex():
517+
# GH#48412
518+
arr1 = Series([2, 1, NA, NA], dtype="Int64")
519+
arr2 = [2, 1, 3, 3]
520+
midx = MultiIndex.from_arrays([arr1, arr2])
521+
result = safe_sort(midx)
522+
expected = MultiIndex.from_arrays(
523+
[Series([1, 2, NA, NA], dtype="Int64"), [1, 2, 3, 3]]
524+
)
525+
tm.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)