Skip to content

Commit a4d5ef6

Browse files
jbrockmendel authored and im-vinicius committed
CLN/TYP: stronger typing in safe_sort (pandas-dev#52973)
* CLN/TYP: stronger typing in safe_sort
* mypy fixup
1 parent b10e387 commit a4d5ef6

File tree

4 files changed: +38 -36 lines changed

pandas/core/algorithms.py

+15 -20
Original file line number | Diff line number | Diff line change
@@ -32,10 +32,7 @@
3232
from pandas.util._decorators import doc
3333
from pandas.util._exceptions import find_stack_level
3434

35-
from pandas.core.dtypes.cast import (
36-
construct_1d_object_array_from_listlike,
37-
infer_dtype_from_array,
38-
)
35+
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
3936
from pandas.core.dtypes.common import (
4037
ensure_float64,
4138
ensure_object,
@@ -1468,8 +1465,8 @@ def diff(arr, n: int, axis: AxisInt = 0):
14681465
# low-dependency, is used in this module, and used private methods from
14691466
# this module.
14701467
def safe_sort(
1471-
values,
1472-
codes=None,
1468+
values: Index | ArrayLike,
1469+
codes: npt.NDArray[np.intp] | None = None,
14731470
use_na_sentinel: bool = True,
14741471
assume_unique: bool = False,
14751472
verify: bool = True,
@@ -1484,7 +1481,7 @@ def safe_sort(
14841481
----------
14851482
values : list-like
14861483
Sequence; must be unique if ``codes`` is not None.
1487-
codes : list_like, optional
1484+
codes : np.ndarray[intp] or None, default None
14881485
Indices to ``values``. All out of bound indices are treated as
14891486
"not found" and will be masked with ``-1``.
14901487
use_na_sentinel : bool, default True
@@ -1515,20 +1512,12 @@ def safe_sort(
15151512
ValueError
15161513
* If ``codes`` is not None and ``values`` contain duplicates.
15171514
"""
1518-
if not is_list_like(values):
1515+
if not isinstance(values, (np.ndarray, ABCExtensionArray, ABCIndex)):
15191516
raise TypeError(
1520-
"Only list-like objects are allowed to be passed to safe_sort as values"
1517+
"Only np.ndarray, ExtensionArray, and Index objects are allowed to "
1518+
"be passed to safe_sort as values"
15211519
)
15221520

1523-
if not is_array_like(values):
1524-
# don't convert to string types
1525-
dtype, _ = infer_dtype_from_array(values)
1526-
# error: Argument "dtype" to "asarray" has incompatible type "Union[dtype[Any],
1527-
# ExtensionDtype]"; expected "Union[dtype[Any], None, type, _SupportsDType, str,
1528-
# Union[Tuple[Any, int], Tuple[Any, Union[int, Sequence[int]]], List[Any],
1529-
# _DTypeDict, Tuple[Any, Any]]]"
1530-
values = np.asarray(values, dtype=dtype) # type: ignore[arg-type]
1531-
15321521
sorter = None
15331522
ordered: AnyArrayLike
15341523

@@ -1546,7 +1535,10 @@ def safe_sort(
15461535
# which would work, but which fails for special case of 1d arrays
15471536
# with tuples.
15481537
if values.size and isinstance(values[0], tuple):
1549-
ordered = _sort_tuples(values)
1538+
# error: Argument 1 to "_sort_tuples" has incompatible type
1539+
# "Union[Index, ExtensionArray, ndarray[Any, Any]]"; expected
1540+
# "ndarray[Any, Any]"
1541+
ordered = _sort_tuples(values) # type: ignore[arg-type]
15501542
else:
15511543
ordered = _sort_mixed(values)
15521544

@@ -1567,7 +1559,10 @@ def safe_sort(
15671559

15681560
if sorter is None:
15691561
# mixed types
1570-
hash_klass, values = _get_hashtable_algo(values)
1562+
# error: Argument 1 to "_get_hashtable_algo" has incompatible type
1563+
# "Union[Index, ExtensionArray, ndarray[Any, Any]]"; expected
1564+
# "ndarray[Any, Any]"
1565+
hash_klass, values = _get_hashtable_algo(values) # type: ignore[arg-type]
15711566
t = hash_klass(len(values))
15721567
t.map_locations(values)
15731568
sorter = ensure_platform_int(t.lookup(ordered))

pandas/core/apply.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,6 @@
5656
ABCSeries,
5757
)
5858

59-
from pandas.core.algorithms import safe_sort
6059
from pandas.core.base import SelectionMixin
6160
import pandas.core.common as com
6261
from pandas.core.construction import ensure_wrapped_if_datetimelike
@@ -580,10 +579,11 @@ def normalize_dictlike_arg(
580579

581580
if obj.ndim != 1:
582581
# Check for missing columns on a frame
583-
cols = set(func.keys()) - set(obj.columns)
582+
from pandas import Index
583+
584+
cols = Index(list(func.keys())).difference(obj.columns, sort=True)
584585
if len(cols) > 0:
585-
cols_sorted = list(safe_sort(list(cols)))
586-
raise KeyError(f"Column(s) {cols_sorted} do not exist")
586+
raise KeyError(f"Column(s) {list(cols)} do not exist")
587587

588588
aggregator_types = (list, tuple, dict)
589589

pandas/core/indexes/base.py

+7 -2
Original file line number | Diff line number | Diff line change
@@ -7390,10 +7390,15 @@ def _unpack_nested_dtype(other: Index) -> Index:
73907390
return other
73917391

73927392

7393-
def _maybe_try_sort(result, sort):
7393+
def _maybe_try_sort(result: Index | ArrayLike, sort: bool | None):
73947394
if sort is not False:
73957395
try:
7396-
result = algos.safe_sort(result)
7396+
# error: Incompatible types in assignment (expression has type
7397+
# "Union[ExtensionArray, ndarray[Any, Any], Index, Series,
7398+
# Tuple[Union[Union[ExtensionArray, ndarray[Any, Any]], Index, Series],
7399+
# ndarray[Any, Any]]]", variable has type "Union[Index,
7400+
# Union[ExtensionArray, ndarray[Any, Any]]]")
7401+
result = algos.safe_sort(result) # type: ignore[assignment]
73977402
except TypeError as err:
73987403
if sort is True:
73997404
raise

pandas/tests/test_sorting.py

+12 -10
Original file line number | Diff line number | Diff line change
@@ -373,12 +373,15 @@ class TestSafeSort:
373373
"arg, exp",
374374
[
375375
[[3, 1, 2, 0, 4], [0, 1, 2, 3, 4]],
376-
[list("baaacb"), np.array(list("aaabbc"), dtype=object)],
376+
[
377+
np.array(list("baaacb"), dtype=object),
378+
np.array(list("aaabbc"), dtype=object),
379+
],
377380
[[], []],
378381
],
379382
)
380383
def test_basic_sort(self, arg, exp):
381-
result = safe_sort(arg)
384+
result = safe_sort(np.array(arg))
382385
expected = np.array(exp)
383386
tm.assert_numpy_array_equal(result, expected)
384387

@@ -391,7 +394,7 @@ def test_basic_sort(self, arg, exp):
391394
],
392395
)
393396
def test_codes(self, verify, codes, exp_codes):
394-
values = [3, 1, 2, 0, 4]
397+
values = np.array([3, 1, 2, 0, 4])
395398
expected = np.array([0, 1, 2, 3, 4])
396399

397400
result, result_codes = safe_sort(
@@ -407,7 +410,7 @@ def test_codes(self, verify, codes, exp_codes):
407410
"Windows fatal exception: access violation",
408411
)
409412
def test_codes_out_of_bound(self):
410-
values = [3, 1, 2, 0, 4]
413+
values = np.array([3, 1, 2, 0, 4])
411414
expected = np.array([0, 1, 2, 3, 4])
412415

413416
# out of bound indices
@@ -417,9 +420,8 @@ def test_codes_out_of_bound(self):
417420
tm.assert_numpy_array_equal(result, expected)
418421
tm.assert_numpy_array_equal(result_codes, expected_codes)
419422

420-
@pytest.mark.parametrize("box", [lambda x: np.array(x, dtype=object), list])
421-
def test_mixed_integer(self, box):
422-
values = box(["b", 1, 0, "a", 0, "b"])
423+
def test_mixed_integer(self):
424+
values = np.array(["b", 1, 0, "a", 0, "b"], dtype=object)
423425
result = safe_sort(values)
424426
expected = np.array([0, 0, 1, "a", "b", "b"], dtype=object)
425427
tm.assert_numpy_array_equal(result, expected)
@@ -443,9 +445,9 @@ def test_unsortable(self):
443445
@pytest.mark.parametrize(
444446
"arg, codes, err, msg",
445447
[
446-
[1, None, TypeError, "Only list-like objects are allowed"],
447-
[[0, 1, 2], 1, TypeError, "Only list-like objects or None"],
448-
[[0, 1, 2, 1], [0, 1], ValueError, "values should be unique"],
448+
[1, None, TypeError, "Only np.ndarray, ExtensionArray, and Index"],
449+
[np.array([0, 1, 2]), 1, TypeError, "Only list-like objects or None"],
450+
[np.array([0, 1, 2, 1]), [0, 1], ValueError, "values should be unique"],
449451
],
450452
)
451453
def test_exceptions(self, arg, codes, err, msg):

0 commit comments

Comments
 (0)