From 25de5ee2bed337ad1cbdc01447639c852bbb75e9 Mon Sep 17 00:00:00 2001 From: Andrey Kolomiets Date: Wed, 26 Jun 2024 08:27:16 +0000 Subject: [PATCH 1/3] Typing improvements for Index --- pandas/core/indexes/base.py | 55 +++++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 11 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 71dfff520113c..064e1e3373f73 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -11,6 +11,8 @@ ClassVar, Literal, NoReturn, + Optional, + Union, cast, final, overload, @@ -46,6 +48,7 @@ Axes, Axis, DropKeep, + Dtype, DtypeObj, F, IgnoreRaise, @@ -155,6 +158,8 @@ ExtensionArray, TimedeltaArray, ) +from pandas.core.arrays.floating import FloatingDtype +from pandas.core.arrays.integer import IntegerDtype from pandas.core.arrays.string_ import ( StringArray, StringDtype, @@ -312,6 +317,20 @@ def _new_Index(cls, d): return cls.__new__(cls, **d) +slice_type = Optional[ + Union[ + str, + IntegerDtype, + FloatingDtype, + DatetimeTZDtype, + CategoricalDtype, + PeriodDtype, + IntervalDtype, + abc.Hashable, + ] +] + + class Index(IndexOpsMixin, PandasObject): """ Immutable sequence used for indexing and alignment. @@ -1087,7 +1106,7 @@ def view(self, cls=None): result._id = self._id return result - def astype(self, dtype, copy: bool = True): + def astype(self, dtype: Dtype, copy: bool = True): """ Create an Index with values cast to dtypes. @@ -2957,7 +2976,7 @@ def _dti_setop_align_tzs(self, other: Index, setop: str_t) -> tuple[Index, Index return self, other @final - def union(self, other, sort=None): + def union(self, other, sort: bool | None = None): """ Form the union of two Index objects. @@ -3334,7 +3353,7 @@ def _intersection_via_get_indexer( return result @final - def difference(self, other, sort=None): + def difference(self, other, sort: bool | None = None): """ Return a new Index with elements of index not in `other`. @@ -3420,7 +3439,12 @@ def _wrap_difference_result(self, other, result): # We will override for MultiIndex to handle empty results return self._wrap_setop_result(other, result) - def symmetric_difference(self, other, result_name=None, sort=None): + def symmetric_difference( + self, + other, + result_name: abc.Hashable | None = None, + sort: bool | None = None, + ): """ Compute the symmetric difference of two Index objects. @@ -6389,7 +6413,7 @@ def _transform_index(self, func, *, level=None) -> Index: items = [func(x) for x in self] return Index(items, name=self.name, tupleize_cols=False) - def isin(self, values, level=None) -> npt.NDArray[np.bool_]: + def isin(self, values, level: np.str_ | int | None = None) -> npt.NDArray[np.bool_]: """ Return a boolean array where the index values are in `values`. @@ -6687,7 +6711,12 @@ def get_slice_bound(self, label, side: Literal["left", "right"]) -> int: else: return slc - def slice_locs(self, start=None, end=None, step=None) -> tuple[int, int]: + def slice_locs( + self, + start: slice_type = None, + end: slice_type = None, + step: int | None = None, + ) -> tuple[int, int]: """ Compute slice locations for input labels. @@ -6781,7 +6810,7 @@ def slice_locs(self, start=None, end=None, step=None) -> tuple[int, int]: return start_slice, end_slice - def delete(self, loc) -> Self: + def delete(self, loc: int | list[int] | np.ndarray) -> Self: """ Make new Index with passed location(-s) deleted. @@ -7227,7 +7256,9 @@ def _maybe_disable_logical_methods(self, opname: str_t) -> None: raise TypeError(f"cannot perform {opname} with {type(self).__name__}") @Appender(IndexOpsMixin.argmin.__doc__) - def argmin(self, axis=None, skipna: bool = True, *args, **kwargs) -> int: + def argmin( + self, axis: int | None = None, skipna: bool = True, *args, **kwargs + ) -> int: nv.validate_argmin(args, kwargs) nv.validate_minmax_axis(axis) @@ -7240,7 +7271,9 @@ def argmin(self, axis=None, skipna: bool = True, *args, **kwargs) -> int: return super().argmin(skipna=skipna) @Appender(IndexOpsMixin.argmax.__doc__) - def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int: + def argmax( + self, axis: int | None = None, skipna: bool = True, *args, **kwargs + ) -> int: nv.validate_argmax(args, kwargs) nv.validate_minmax_axis(axis) @@ -7251,7 +7284,7 @@ def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int: raise ValueError("Encountered all NA values") return super().argmax(skipna=skipna) - def min(self, axis=None, skipna: bool = True, *args, **kwargs): + def min(self, axis: int | None = None, skipna: bool = True, *args, **kwargs): """ Return the minimum value of the Index. @@ -7314,7 +7347,7 @@ def min(self, axis=None, skipna: bool = True, *args, **kwargs): return nanops.nanmin(self._values, skipna=skipna) - def max(self, axis=None, skipna: bool = True, *args, **kwargs): + def max(self, axis: int | None = None, skipna: bool = True, *args, **kwargs): """ Return the maximum value of the Index. From 561120b862503e0da5a145059d1ffe2352b832d7 Mon Sep 17 00:00:00 2001 From: Andrey Kolomiets Date: Thu, 27 Jun 2024 12:07:51 +0000 Subject: [PATCH 2/3] better numpy type hints for Index.delete --- pandas/core/indexes/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 064e1e3373f73..80b44d3f1aac9 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -6810,7 +6810,7 @@ def slice_locs( return start_slice, end_slice - def delete(self, loc: int | list[int] | np.ndarray) -> Self: + def delete(self, loc: int | np.integer | list[int] | npt.NDArray[np.int_]) -> Self: """ Make new Index with passed location(-s) deleted. From 36c3495d9f81cf95b1e6cafd6468747367d01fd5 Mon Sep 17 00:00:00 2001 From: Andrey Kolomiets Date: Fri, 19 Jul 2024 20:48:46 +0000 Subject: [PATCH 3/3] replace some hints with literals, move slice_type to _typing.py --- pandas/_typing.py | 2 ++ pandas/core/indexes/base.py | 38 ++++++++++++------------------------- 2 files changed, 14 insertions(+), 26 deletions(-) diff --git a/pandas/_typing.py b/pandas/_typing.py index 09a3f58d6ab7f..d43e6e900546d 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -526,3 +526,5 @@ def closed(self) -> bool: # maintaine the sub-type of any hashable sequence SequenceT = TypeVar("SequenceT", bound=Sequence[Hashable]) + +SliceType = Optional[Hashable] diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 66395c74d90f0..7a25ec7f6135c 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -11,8 +11,6 @@ ClassVar, Literal, NoReturn, - Optional, - Union, cast, final, overload, @@ -47,6 +45,7 @@ ArrayLike, Axes, Axis, + AxisInt, DropKeep, Dtype, DtypeObj, @@ -60,6 +59,7 @@ ReindexMethod, Self, Shape, + SliceType, npt, ) from pandas.compat.numpy import function as nv @@ -158,8 +158,6 @@ ExtensionArray, TimedeltaArray, ) -from pandas.core.arrays.floating import FloatingDtype -from pandas.core.arrays.integer import IntegerDtype from pandas.core.arrays.string_ import ( StringArray, StringDtype, @@ -317,20 +315,6 @@ def _new_Index(cls, d): return cls.__new__(cls, **d) -slice_type = Optional[ - Union[ - str, - IntegerDtype, - FloatingDtype, - DatetimeTZDtype, - CategoricalDtype, - PeriodDtype, - IntervalDtype, - abc.Hashable, - ] -] - - class Index(IndexOpsMixin, PandasObject): """ Immutable sequence used for indexing and alignment. @@ -6413,7 +6397,7 @@ def _transform_index(self, func, *, level=None) -> Index: items = [func(x) for x in self] return Index(items, name=self.name, tupleize_cols=False) - def isin(self, values, level: np.str_ | int | None = None) -> npt.NDArray[np.bool_]: + def isin(self, values, level: str_t | int | None = None) -> npt.NDArray[np.bool_]: """ Return a boolean array where the index values are in `values`. @@ -6713,8 +6697,8 @@ def get_slice_bound(self, label, side: Literal["left", "right"]) -> int: def slice_locs( self, - start: slice_type = None, - end: slice_type = None, + start: SliceType = None, + end: SliceType = None, step: int | None = None, ) -> tuple[int, int]: """ @@ -6810,7 +6794,9 @@ def slice_locs( return start_slice, end_slice - def delete(self, loc: int | np.integer | list[int] | npt.NDArray[np.int_]) -> Self: + def delete( + self, loc: int | np.integer | list[int] | npt.NDArray[np.integer] + ) -> Self: """ Make new Index with passed location(-s) deleted. @@ -7257,7 +7243,7 @@ def _maybe_disable_logical_methods(self, opname: str_t) -> None: @Appender(IndexOpsMixin.argmin.__doc__) def argmin( - self, axis: int | None = None, skipna: bool = True, *args, **kwargs + self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs ) -> int: nv.validate_argmin(args, kwargs) nv.validate_minmax_axis(axis) @@ -7272,7 +7258,7 @@ def argmin( @Appender(IndexOpsMixin.argmax.__doc__) def argmax( - self, axis: int | None = None, skipna: bool = True, *args, **kwargs + self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs ) -> int: nv.validate_argmax(args, kwargs) nv.validate_minmax_axis(axis) @@ -7284,7 +7270,7 @@ def argmax( raise ValueError("Encountered all NA values") return super().argmax(skipna=skipna) - def min(self, axis: int | None = None, skipna: bool = True, *args, **kwargs): + def min(self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs): """ Return the minimum value of the Index. @@ -7347,7 +7333,7 @@ def min(self, axis: int | None = None, skipna: bool = True, *args, **kwargs): return nanops.nanmin(self._values, skipna=skipna) - def max(self, axis: int | None = None, skipna: bool = True, *args, **kwargs): + def max(self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs): """ Return the maximum value of the Index.