14
14
Sequence ,
15
15
cast ,
16
16
final ,
17
+ overload ,
17
18
)
18
19
import warnings
19
20
101
102
Categorical ,
102
103
DataFrame ,
103
104
Index ,
105
+ MultiIndex ,
104
106
Series ,
105
107
)
106
108
from pandas .core .arrays import (
@@ -1792,7 +1794,7 @@ def safe_sort(
1792
1794
na_sentinel : int = - 1 ,
1793
1795
assume_unique : bool = False ,
1794
1796
verify : bool = True ,
1795
- ) -> np .ndarray | tuple [np .ndarray , np .ndarray ]:
1797
+ ) -> np .ndarray | MultiIndex | tuple [np .ndarray | MultiIndex , np .ndarray ]:
1796
1798
"""
1797
1799
Sort ``values`` and reorder corresponding ``codes``.
1798
1800
@@ -1821,7 +1823,7 @@ def safe_sort(
1821
1823
1822
1824
Returns
1823
1825
-------
1824
- ordered : ndarray
1826
+ ordered : ndarray or MultiIndex
1825
1827
Sorted ``values``
1826
1828
new_codes : ndarray
1827
1829
Reordered ``codes``; returned when ``codes`` is not None.
@@ -1839,6 +1841,8 @@ def safe_sort(
1839
1841
raise TypeError (
1840
1842
"Only list-like objects are allowed to be passed to safe_sort as values"
1841
1843
)
1844
+ original_values = values
1845
+ is_mi = isinstance (original_values , ABCMultiIndex )
1842
1846
1843
1847
if not isinstance (values , (np .ndarray , ABCExtensionArray )):
1844
1848
# don't convert to string types
@@ -1850,6 +1854,7 @@ def safe_sort(
1850
1854
values = np .asarray (values , dtype = dtype ) # type: ignore[arg-type]
1851
1855
1852
1856
sorter = None
1857
+ ordered : np .ndarray | MultiIndex
1853
1858
1854
1859
if (
1855
1860
not is_extension_array_dtype (values )
@@ -1859,13 +1864,17 @@ def safe_sort(
1859
1864
else :
1860
1865
try :
1861
1866
sorter = values .argsort ()
1862
- ordered = values .take (sorter )
1867
+ if is_mi :
1868
+ # Operate on original object instead of casted array (MultiIndex)
1869
+ ordered = original_values .take (sorter )
1870
+ else :
1871
+ ordered = values .take (sorter )
1863
1872
except TypeError :
1864
1873
# Previous sorters failed or were not applicable, try `_sort_mixed`
1865
1874
# which would work, but which fails for special case of 1d arrays
1866
1875
# with tuples.
1867
1876
if values .size and isinstance (values [0 ], tuple ):
1868
- ordered = _sort_tuples (values )
1877
+ ordered = _sort_tuples (values , original_values )
1869
1878
else :
1870
1879
ordered = _sort_mixed (values )
1871
1880
@@ -1927,19 +1936,33 @@ def _sort_mixed(values) -> np.ndarray:
1927
1936
)
1928
1937
1929
1938
1930
- def _sort_tuples (values : np .ndarray ) -> np .ndarray :
1939
+ @overload
1940
+ def _sort_tuples (values : np .ndarray , original_values : np .ndarray ) -> np .ndarray :
1941
+ ...
1942
+
1943
+
1944
+ @overload
1945
+ def _sort_tuples (values : np .ndarray , original_values : MultiIndex ) -> MultiIndex :
1946
+ ...
1947
+
1948
+
1949
+ def _sort_tuples (
1950
+ values : np .ndarray , original_values : np .ndarray | MultiIndex
1951
+ ) -> np .ndarray | MultiIndex :
1931
1952
"""
1932
1953
Convert array of tuples (1d) to array or array (2d).
1933
1954
We need to keep the columns separately as they contain different types and
1934
1955
nans (can't use `np.sort` as it may fail when str and nan are mixed in a
1935
1956
column as types cannot be compared).
1957
+ We have to apply the indexer to the original values to keep the dtypes in
1958
+ case of MultiIndexes
1936
1959
"""
1937
1960
from pandas .core .internals .construction import to_arrays
1938
1961
from pandas .core .sorting import lexsort_indexer
1939
1962
1940
1963
arrays , _ = to_arrays (values , None )
1941
1964
indexer = lexsort_indexer (arrays , orders = True )
1942
- return values [indexer ]
1965
+ return original_values [indexer ]
1943
1966
1944
1967
1945
1968
def union_with_duplicates (lvals : ArrayLike , rvals : ArrayLike ) -> ArrayLike :
0 commit comments