diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 77cd73fdfe91b..e7ec7fde762f4 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -8,7 +8,9 @@ from textwrap import dedent from typing import ( TYPE_CHECKING, + Hashable, Literal, + Sequence, Union, cast, final, @@ -27,6 +29,7 @@ AnyArrayLike, ArrayLike, DtypeObj, + IndexLabel, Scalar, TakeIndexer, npt, @@ -1338,10 +1341,12 @@ class SelectNFrame(SelectN): nordered : DataFrame """ - def __init__(self, obj, n: int, keep: str, columns): + def __init__(self, obj: DataFrame, n: int, keep: str, columns: IndexLabel): super().__init__(obj, n, keep) if not is_list_like(columns) or isinstance(columns, tuple): columns = [columns] + + columns = cast(Sequence[Hashable], columns) columns = list(columns) self.columns = columns diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 5f4207d0985ef..36e8177477559 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -1374,7 +1374,7 @@ def dot(self, other: Series) -> Series: def dot(self, other: DataFrame | Index | ArrayLike) -> DataFrame: ... - def dot(self, other: AnyArrayLike | DataFrame | Series) -> DataFrame | Series: + def dot(self, other: AnyArrayLike | DataFrame) -> DataFrame | Series: """ Compute the matrix multiplication between the DataFrame and other. @@ -2155,7 +2155,6 @@ def maybe_reorder( to_remove = [arr_columns.get_loc(col) for col in arr_exclude] arrays = [v for i, v in enumerate(arrays) if i not in to_remove] - arr_columns = arr_columns.drop(arr_exclude) columns = columns.drop(exclude) manager = get_option("mode.data_manager") @@ -4383,7 +4382,13 @@ def predicate(arr: ArrayLike) -> bool: mgr = self._mgr._get_data_subset(predicate) return type(self)(mgr).__finalize__(self) - def insert(self, loc, column, value, allow_duplicates: bool = False) -> None: + def insert( + self, + loc: int, + column: Hashable, + value: Scalar | AnyArrayLike, + allow_duplicates: bool = False, + ) -> None: """ Insert column into DataFrame at specified location. @@ -4396,8 +4401,8 @@ def insert(self, loc, column, value, allow_duplicates: bool = False) -> None: Insertion index. Must verify 0 <= loc <= len(columns). column : str, number, or hashable object Label of the inserted column. - value : int, Series, or array-like - allow_duplicates : bool, optional + value : Scalar, Series, or array-like + allow_duplicates : bool, optional default False See Also -------- @@ -6566,7 +6571,7 @@ def value_counts( return counts - def nlargest(self, n, columns, keep: str = "first") -> DataFrame: + def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame: """ Return the first `n` rows ordered by `columns` in descending order. @@ -6673,7 +6678,7 @@ def nlargest(self, n, columns, keep: str = "first") -> DataFrame: """ return algorithms.SelectNFrame(self, n=n, keep=keep, columns=columns).nlargest() - def nsmallest(self, n, columns, keep: str = "first") -> DataFrame: + def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame: """ Return the first `n` rows ordered by `columns` in ascending order.