14
14
Sequence ,
15
15
cast ,
16
16
final ,
17
+ overload ,
17
18
)
18
19
import warnings
19
20
101
102
Categorical ,
102
103
DataFrame ,
103
104
Index ,
105
+ MultiIndex ,
104
106
Series ,
105
107
)
106
108
from pandas .core .arrays import (
@@ -1780,7 +1782,7 @@ def safe_sort(
1780
1782
na_sentinel : int = - 1 ,
1781
1783
assume_unique : bool = False ,
1782
1784
verify : bool = True ,
1783
- ) -> np .ndarray | tuple [np .ndarray , np .ndarray ]:
1785
+ ) -> np .ndarray | MultiIndex | tuple [np .ndarray | MultiIndex , np .ndarray ]:
1784
1786
"""
1785
1787
Sort ``values`` and reorder corresponding ``codes``.
1786
1788
@@ -1809,7 +1811,7 @@ def safe_sort(
1809
1811
1810
1812
Returns
1811
1813
-------
1812
- ordered : ndarray
1814
+ ordered : ndarray or MultiIndex
1813
1815
Sorted ``values``
1814
1816
new_codes : ndarray
1815
1817
Reordered ``codes``; returned when ``codes`` is not None.
@@ -1827,6 +1829,7 @@ def safe_sort(
1827
1829
raise TypeError (
1828
1830
"Only list-like objects are allowed to be passed to safe_sort as values"
1829
1831
)
1832
+ original_values = values
1830
1833
1831
1834
if not isinstance (values , (np .ndarray , ABCExtensionArray )):
1832
1835
# don't convert to string types
@@ -1838,6 +1841,7 @@ def safe_sort(
1838
1841
values = np .asarray (values , dtype = dtype ) # type: ignore[arg-type]
1839
1842
1840
1843
sorter = None
1844
+ ordered : np .ndarray | MultiIndex
1841
1845
1842
1846
if (
1843
1847
not is_extension_array_dtype (values )
@@ -1853,7 +1857,7 @@ def safe_sort(
1853
1857
# which would work, but which fails for special case of 1d arrays
1854
1858
# with tuples.
1855
1859
if values .size and isinstance (values [0 ], tuple ):
1856
- ordered = _sort_tuples (values )
1860
+ ordered = _sort_tuples (values , original_values )
1857
1861
else :
1858
1862
ordered = _sort_mixed (values )
1859
1863
@@ -1915,19 +1919,33 @@ def _sort_mixed(values) -> np.ndarray:
1915
1919
)
1916
1920
1917
1921
1918
- def _sort_tuples (values : np .ndarray ) -> np .ndarray :
1922
+ @overload
1923
+ def _sort_tuples (values : np .ndarray , original_values : np .ndarray ) -> np .ndarray :
1924
+ ...
1925
+
1926
+
1927
+ @overload
1928
+ def _sort_tuples (values : np .ndarray , original_values : MultiIndex ) -> MultiIndex :
1929
+ ...
1930
+
1931
+
1932
+ def _sort_tuples (
1933
+ values : np .ndarray , original_values : np .ndarray | MultiIndex
1934
+ ) -> np .ndarray | MultiIndex :
1919
1935
"""
1920
1936
Convert array of tuples (1d) to array or array (2d).
1921
1937
We need to keep the columns separately as they contain different types and
1922
1938
nans (can't use `np.sort` as it may fail when str and nan are mixed in a
1923
1939
column as types cannot be compared).
1940
+ We have to apply the indexer to the original values to keep the dtypes in
1941
+ case of MultiIndexes
1924
1942
"""
1925
1943
from pandas .core .internals .construction import to_arrays
1926
1944
from pandas .core .sorting import lexsort_indexer
1927
1945
1928
1946
arrays , _ = to_arrays (values , None )
1929
1947
indexer = lexsort_indexer (arrays , orders = True )
1930
- return values [indexer ]
1948
+ return original_values [indexer ]
1931
1949
1932
1950
1933
1951
def union_with_duplicates (lvals : ArrayLike , rvals : ArrayLike ) -> ArrayLike :
0 commit comments