From e0b9e203646c2d735c5408ced1ba1c09b0665822 Mon Sep 17 00:00:00 2001 From: phofl Date: Sun, 3 Oct 2021 23:00:26 +0200 Subject: [PATCH 1/2] TYP: type insert and nsmallest/largest --- pandas/core/algorithms.py | 6 +++++- pandas/core/frame.py | 19 ++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 77cd73fdfe91b..f0a289927b861 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -4,6 +4,7 @@ """ from __future__ import annotations +from collections import abc import operator from textwrap import dedent from typing import ( @@ -27,6 +28,7 @@ AnyArrayLike, ArrayLike, DtypeObj, + IndexLabel, Scalar, TakeIndexer, npt, @@ -1338,10 +1340,12 @@ class SelectNFrame(SelectN): nordered : DataFrame """ - def __init__(self, obj, n: int, keep: str, columns): + def __init__(self, obj: DataFrame, n: int, keep: str, columns: IndexLabel): super().__init__(obj, n, keep) if not is_list_like(columns) or isinstance(columns, tuple): columns = [columns] + + assert isinstance(columns, abc.Iterable) columns = list(columns) self.columns = columns diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 5f4207d0985ef..36e8177477559 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -1374,7 +1374,7 @@ def dot(self, other: Series) -> Series: def dot(self, other: DataFrame | Index | ArrayLike) -> DataFrame: ... - def dot(self, other: AnyArrayLike | DataFrame | Series) -> DataFrame | Series: + def dot(self, other: AnyArrayLike | DataFrame) -> DataFrame | Series: """ Compute the matrix multiplication between the DataFrame and other. @@ -2155,7 +2155,6 @@ def maybe_reorder( to_remove = [arr_columns.get_loc(col) for col in arr_exclude] arrays = [v for i, v in enumerate(arrays) if i not in to_remove] - arr_columns = arr_columns.drop(arr_exclude) columns = columns.drop(exclude) manager = get_option("mode.data_manager") @@ -4383,7 +4382,13 @@ def predicate(arr: ArrayLike) -> bool: mgr = self._mgr._get_data_subset(predicate) return type(self)(mgr).__finalize__(self) - def insert(self, loc, column, value, allow_duplicates: bool = False) -> None: + def insert( + self, + loc: int, + column: Hashable, + value: Scalar | AnyArrayLike, + allow_duplicates: bool = False, + ) -> None: """ Insert column into DataFrame at specified location. @@ -4396,8 +4401,8 @@ def insert(self, loc, column, value, allow_duplicates: bool = False) -> None: Insertion index. Must verify 0 <= loc <= len(columns). column : str, number, or hashable object Label of the inserted column. - value : int, Series, or array-like - allow_duplicates : bool, optional + value : Scalar, Series, or array-like + allow_duplicates : bool, optional default False See Also -------- @@ -6566,7 +6571,7 @@ def value_counts( return counts - def nlargest(self, n, columns, keep: str = "first") -> DataFrame: + def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame: """ Return the first `n` rows ordered by `columns` in descending order. @@ -6673,7 +6678,7 @@ def nlargest(self, n, columns, keep: str = "first") -> DataFrame: """ return algorithms.SelectNFrame(self, n=n, keep=keep, columns=columns).nlargest() - def nsmallest(self, n, columns, keep: str = "first") -> DataFrame: + def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame: """ Return the first `n` rows ordered by `columns` in ascending order. From adec87974d075532585b70edd95f05fbd146ee03 Mon Sep 17 00:00:00 2001 From: phofl Date: Sun, 10 Oct 2021 21:35:42 +0200 Subject: [PATCH 2/2] Add cast instead of assert --- pandas/core/algorithms.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index f0a289927b861..e7ec7fde762f4 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -4,12 +4,13 @@ """ from __future__ import annotations -from collections import abc import operator from textwrap import dedent from typing import ( TYPE_CHECKING, + Hashable, Literal, + Sequence, Union, cast, final, @@ -1345,7 +1346,7 @@ def __init__(self, obj: DataFrame, n: int, keep: str, columns: IndexLabel): if not is_list_like(columns) or isinstance(columns, tuple): columns = [columns] - assert isinstance(columns, abc.Iterable) + columns = cast(Sequence[Hashable], columns) columns = list(columns) self.columns = columns