From 01fd50aefb97a2e2cecb039fc62d4f50338ee7b2 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Tue, 14 Feb 2023 23:57:36 -0500 Subject: [PATCH 1/2] Fix __getitem__() for various index types --- pandas-stubs/core/algorithms.pyi | 2 +- pandas-stubs/core/indexes/base.pyi | 28 ++++++- pandas-stubs/core/indexes/datetimelike.pyi | 24 +----- pandas-stubs/core/indexes/datetimes.pyi | 12 ++- pandas-stubs/core/indexes/extension.pyi | 13 +-- pandas-stubs/core/indexes/interval.pyi | 39 +++++---- pandas-stubs/core/indexes/multi.pyi | 20 ++++- pandas-stubs/core/indexes/period.pyi | 30 +++++-- pandas-stubs/core/indexes/range.pyi | 20 ++++- pandas-stubs/core/indexes/timedeltas.pyi | 14 ++-- tests/test_indexes.py | 93 ++++++++++++++++++++++ 11 files changed, 227 insertions(+), 68 deletions(-) diff --git a/pandas-stubs/core/algorithms.pyi b/pandas-stubs/core/algorithms.pyi index c81b8438a..59cd438fa 100644 --- a/pandas-stubs/core/algorithms.pyi +++ b/pandas-stubs/core/algorithms.pyi @@ -25,7 +25,7 @@ def unique(values: PeriodIndex) -> PeriodIndex: ... # type: ignore[misc] # pyri @overload def unique(values: CategoricalIndex) -> CategoricalIndex: ... # type: ignore[misc] @overload -def unique(values: IntervalIndex[IntervalT]) -> IntervalIndex[IntervalT]: ... # type: ignore[misc] +def unique(values: IntervalIndex[IntervalT]) -> IntervalIndex[IntervalT]: ... @overload def unique(values: Index) -> np.ndarray: ... @overload diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index d24c9f161..137cf82c7 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -7,6 +7,7 @@ from collections.abc import ( ) from typing import ( ClassVar, + Generic, Literal, overload, ) @@ -27,6 +28,7 @@ from pandas.core.strings import StringMethods from typing_extensions import Never from pandas._typing import ( + S1, T1, Dtype, DtypeArg, @@ -48,6 +50,22 @@ class InvalidIndexError(Exception): ... _str = str +class _IndexGetitemMixin(Generic[S1]): + # type ignore needed because it doesn't like the type of self + @overload + def __getitem__( # type: ignore[misc] + self: IndexT, + idx: slice + | np_ndarray_anyint + | Sequence[int] + | Index + | Series[bool] + | Sequence[bool] + | np_ndarray_bool, + ) -> IndexT: ... + @overload + def __getitem__(self, idx: int) -> S1: ... + class Index(IndexOpsMixin, PandasObject): __hash__: ClassVar[None] # type: ignore[assignment] @overload @@ -158,7 +176,7 @@ class Index(IndexOpsMixin, PandasObject): __bool__ = ... def union(self, other: list[HashableT] | Index, sort=...) -> Index: ... def intersection(self, other: list[T1] | Index, sort: bool = ...) -> Index: ... - def difference(self, other: list | Index) -> Index: ... + def difference(self, other: list | Index, sort: bool | None = None) -> Index: ... def symmetric_difference( self, other: list[T1] | Index, result_name=..., sort=... ) -> Index: ... @@ -191,7 +209,13 @@ class Index(IndexOpsMixin, PandasObject): @overload def __getitem__( self: IndexT, - idx: slice | np_ndarray_anyint | Index | Series[bool] | np_ndarray_bool, + idx: slice + | np_ndarray_anyint + | Sequence[int] + | Index + | Series[bool] + | Sequence[bool] + | np_ndarray_bool, ) -> IndexT: ... @overload def __getitem__(self, idx: int | tuple[np_ndarray_anyint, ...]) -> Scalar: ... diff --git a/pandas-stubs/core/indexes/datetimelike.pyi b/pandas-stubs/core/indexes/datetimelike.pyi index 073af7855..e0425bd95 100644 --- a/pandas-stubs/core/indexes/datetimelike.pyi +++ b/pandas-stubs/core/indexes/datetimelike.pyi @@ -1,5 +1,5 @@ from pandas.core.indexes.extension import ExtensionIndex -from pandas.core.indexes.numeric import Int64Index +from pandas.core.indexes.timedeltas import TimedeltaIndex from pandas._libs.tslibs import BaseOffset @@ -10,28 +10,10 @@ class DatetimeIndexOpsMixin(ExtensionIndex): def freqstr(self) -> str | None: ... @property def is_all_dates(self) -> bool: ... - @property - def values(self): ... - def __array_wrap__(self, result, context=...): ... - def equals(self, other) -> bool: ... - def __contains__(self, key): ... - def sort_values(self, return_indexer: bool = ..., ascending: bool = ...): ... - def take( - self, indices, axis: int = ..., allow_fill: bool = ..., fill_value=..., **kwargs - ): ... - def tolist(self) -> list: ... def min(self, axis=..., skipna: bool = ..., *args, **kwargs): ... def argmin(self, axis=..., skipna: bool = ..., *args, **kwargs): ... def max(self, axis=..., skipna: bool = ..., *args, **kwargs): ... def argmax(self, axis=..., skipna: bool = ..., *args, **kwargs): ... - def isin(self, values, level=...): ... - def where(self, cond, other=...): ... - def shift(self, periods: int = ..., freq=...): ... - def delete(self, loc): ... + def __rsub__(self, other: DatetimeIndexOpsMixin) -> TimedeltaIndex: ... -class DatetimeTimedeltaMixin(DatetimeIndexOpsMixin, Int64Index): - def difference(self, other, sort=...): ... - def intersection(self, other, sort: bool = ...): ... - def join( - self, other, *, how: str = ..., level=..., return_indexers=..., sort=... - ): ... +class DatetimeTimedeltaMixin(DatetimeIndexOpsMixin): ... diff --git a/pandas-stubs/core/indexes/datetimes.pyi b/pandas-stubs/core/indexes/datetimes.pyi index 12e2c06c1..ed8d66fd4 100644 --- a/pandas-stubs/core/indexes/datetimes.pyi +++ b/pandas-stubs/core/indexes/datetimes.pyi @@ -17,6 +17,7 @@ from pandas import ( ) from pandas.core.indexes.accessors import DatetimeIndexProperties from pandas.core.indexes.api import Float64Index +from pandas.core.indexes.base import _IndexGetitemMixin from pandas.core.indexes.datetimelike import DatetimeTimedeltaMixin from pandas.core.series import ( TimedeltaSeries, @@ -34,7 +35,12 @@ from pandas.core.dtypes.dtypes import DatetimeTZDtype from pandas.tseries.offsets import BaseOffset -class DatetimeIndex(DatetimeTimedeltaMixin, DatetimeIndexProperties): +# type ignore needed because of __getitem__() +class DatetimeIndex( # type: ignore[misc] + _IndexGetitemMixin[Timestamp], + DatetimeTimedeltaMixin, + DatetimeIndexProperties, +): def __init__( self, data: ArrayLike | AnyArrayLike | list | tuple, @@ -53,11 +59,11 @@ class DatetimeIndex(DatetimeTimedeltaMixin, DatetimeIndexProperties): def __reduce__(self): ... # various ignores needed for mypy, as we do want to restrict what can be used in # arithmetic for these types - @overload # type: ignore[override] + @overload def __add__(self, other: TimedeltaSeries) -> TimestampSeries: ... @overload def __add__(self, other: Timedelta | TimedeltaIndex) -> DatetimeIndex: ... - @overload # type: ignore[override] + @overload def __sub__(self, other: TimedeltaSeries) -> TimestampSeries: ... @overload def __sub__(self, other: Timedelta | TimedeltaIndex) -> DatetimeIndex: ... diff --git a/pandas-stubs/core/indexes/extension.pyi b/pandas-stubs/core/indexes/extension.pyi index 2f71ebdde..a5595f622 100644 --- a/pandas-stubs/core/indexes/extension.pyi +++ b/pandas-stubs/core/indexes/extension.pyi @@ -1,14 +1,3 @@ -from typing import Literal - from pandas.core.indexes.base import Index -class ExtensionIndex(Index): - def __iter__(self): ... - def dropna(self, how: Literal["any", "all"] = ...): ... - def repeat(self, repeats, axis=...): ... - def take( - self, indices, axis: int = ..., allow_fill: bool = ..., fill_value=..., **kwargs - ): ... - def unique(self, level=...): ... - def map(self, mapper, na_action=...): ... - def astype(self, dtype, copy: bool = ...): ... +class ExtensionIndex(Index): ... diff --git a/pandas-stubs/core/indexes/interval.pyi b/pandas-stubs/core/indexes/interval.pyi index 19ee6efe8..0d5341e21 100644 --- a/pandas-stubs/core/indexes/interval.pyi +++ b/pandas-stubs/core/indexes/interval.pyi @@ -13,8 +13,8 @@ from typing import ( import numpy as np import pandas as pd from pandas import Index -from pandas.core.indexes.extension import ExtensionIndex from pandas.core.series import ( + Series, TimedeltaSeries, TimestampSeries, ) @@ -32,6 +32,7 @@ from pandas._typing import ( IntervalClosedType, IntervalT, Label, + np_ndarray_anyint, np_ndarray_bool, npt, ) @@ -68,7 +69,7 @@ _EdgesTimedelta: TypeAlias = Union[ _TimestampLike: TypeAlias = Union[pd.Timestamp, np.datetime64, dt.datetime] _TimedeltaLike: TypeAlias = Union[pd.Timedelta, np.timedelta64, dt.timedelta] -class IntervalIndex(IntervalMixin, ExtensionIndex, Generic[IntervalT]): +class IntervalIndex(IntervalMixin, Generic[IntervalT]): def __new__( cls, data: Sequence[IntervalT], @@ -78,10 +79,9 @@ class IntervalIndex(IntervalMixin, ExtensionIndex, Generic[IntervalT]): name: Hashable = ..., verify_integrity: bool = ..., ) -> IntervalIndex[IntervalT]: ... - # ignore[misc] here due to overlap, e.g., Sequence[int] and Sequence[float] @overload @classmethod - def from_breaks( # type:ignore[misc] + def from_breaks( cls, breaks: _EdgesInt, closed: IntervalClosedType = ..., @@ -119,10 +119,9 @@ class IntervalIndex(IntervalMixin, ExtensionIndex, Generic[IntervalT]): copy: bool = ..., dtype: IntervalDtype | None = ..., ) -> IntervalIndex[Interval[pd.Timedelta]]: ... - # ignore[misc] here due to overlap, e.g., Sequence[int] and Sequence[float] @overload @classmethod - def from_arrays( # type:ignore[misc] + def from_arrays( cls, left: _EdgesInt, right: _EdgesInt, @@ -250,41 +249,51 @@ class IntervalIndex(IntervalMixin, ExtensionIndex, Generic[IntervalT]): @property def length(self) -> Index: ... def get_value(self, series: ABCSeries, key): ... + @overload + def __getitem__( + self, + idx: slice + | np_ndarray_anyint + | Sequence[int] + | Index + | Series[bool] + | np_ndarray_bool, + ) -> IntervalIndex[IntervalT]: ... + @overload + def __getitem__(self, idx: int) -> IntervalT: ... @property def is_all_dates(self) -> bool: ... - # override is due to additional types for comparison - # misc is due to overlap with object below - @overload # type: ignore[override] + @overload def __gt__( self, other: IntervalT | IntervalIndex[IntervalT] ) -> np_ndarray_bool: ... @overload def __gt__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... - @overload # type: ignore[override] + @overload def __ge__( self, other: IntervalT | IntervalIndex[IntervalT] ) -> np_ndarray_bool: ... @overload def __ge__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... - @overload # type: ignore[override] + @overload def __le__( self, other: IntervalT | IntervalIndex[IntervalT] ) -> np_ndarray_bool: ... @overload def __le__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... - @overload # type: ignore[override] + @overload def __lt__( self, other: IntervalT | IntervalIndex[IntervalT] ) -> np_ndarray_bool: ... @overload - def __lt__(self, other: pd.Series[IntervalT]) -> bool: ... # type: ignore[misc] - @overload # type: ignore[override] + def __lt__(self, other: pd.Series[IntervalT]) -> bool: ... + @overload def __eq__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[misc] @overload def __eq__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... # type: ignore[misc] @overload def __eq__(self, other: object) -> Literal[False]: ... - @overload # type: ignore[override] + @overload def __ne__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[misc] @overload def __ne__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... # type: ignore[misc] diff --git a/pandas-stubs/core/indexes/multi.pyi b/pandas-stubs/core/indexes/multi.pyi index 688917f54..e4b75a560 100644 --- a/pandas-stubs/core/indexes/multi.pyi +++ b/pandas-stubs/core/indexes/multi.pyi @@ -3,7 +3,10 @@ from collections.abc import ( Hashable, Sequence, ) -from typing import Literal +from typing import ( + Literal, + overload, +) import numpy as np import pandas as pd @@ -13,6 +16,7 @@ from pandas._typing import ( T1, DtypeArg, HashableT, + np_ndarray_anyint, np_ndarray_bool, ) @@ -105,7 +109,19 @@ class MultiIndex(Index): @property def levshape(self): ... def __reduce__(self): ... - def __getitem__(self, key): ... + @overload # type: ignore[override] + def __getitem__( + self, + idx: slice + | np_ndarray_anyint + | Sequence[int] + | Index + | pd.Series[bool] + | Sequence[bool] + | np_ndarray_bool, + ) -> MultiIndex: ... + @overload + def __getitem__(self, key: int) -> tuple: ... def take( self, indices, axis: int = ..., allow_fill: bool = ..., fill_value=..., **kwargs ): ... diff --git a/pandas-stubs/core/indexes/period.pyi b/pandas-stubs/core/indexes/period.pyi index dc1ab02a6..8edd2cae7 100644 --- a/pandas-stubs/core/indexes/period.pyi +++ b/pandas-stubs/core/indexes/period.pyi @@ -5,18 +5,25 @@ import numpy as np import pandas as pd from pandas import Index from pandas.core.indexes.accessors import PeriodIndexFieldOps +from pandas.core.indexes.base import _IndexGetitemMixin from pandas.core.indexes.datetimelike import ( DatetimeIndexOpsMixin as DatetimeIndexOpsMixin, ) -from pandas.core.indexes.numeric import Int64Index -from pandas.core.series import OffsetSeries +from pandas.core.indexes.timedeltas import TimedeltaIndex from pandas._libs.tslibs import ( BaseOffset, + NaTType, Period, ) +from pandas._libs.tslibs.period import _PeriodAddSub -class PeriodIndex(DatetimeIndexOpsMixin, Int64Index, PeriodIndexFieldOps): +# type ignore needed because of __getitem__() +class PeriodIndex( # type: ignore[misc] + _IndexGetitemMixin[Period], + DatetimeIndexOpsMixin, + PeriodIndexFieldOps, +): def __new__( cls, data=..., @@ -31,11 +38,22 @@ class PeriodIndex(DatetimeIndexOpsMixin, Int64Index, PeriodIndexFieldOps): @property def values(self): ... def __contains__(self, key) -> bool: ... - # Override due to supertype incompatibility which has it for NumericIndex or complex. - @overload # type: ignore[override] + @overload def __sub__(self, other: Period) -> Index: ... @overload - def __sub__(self, other: PeriodIndex) -> OffsetSeries: ... + def __sub__(self, other: PeriodIndex) -> Index: ... + @overload + def __sub__(self, other: _PeriodAddSub) -> PeriodIndex: ... + @overload + def __sub__(self, other: NaTType) -> NaTType: ... + @overload + def __sub__(self, other: TimedeltaIndex | pd.Timedelta) -> PeriodIndex: ... + @overload # type: ignore[override] + def __rsub__(self, other: Period) -> Index: ... + @overload + def __rsub__(self, other: PeriodIndex) -> Index: ... + @overload + def __rsub__(self, other: NaTType) -> NaTType: ... def __array__(self, dtype=...) -> np.ndarray: ... def __array_wrap__(self, result, context=...): ... def asof_locs(self, where, mask): ... diff --git a/pandas-stubs/core/indexes/range.pyi b/pandas-stubs/core/indexes/range.pyi index fcddbf527..03ef66786 100644 --- a/pandas-stubs/core/indexes/range.pyi +++ b/pandas-stubs/core/indexes/range.pyi @@ -1,9 +1,15 @@ +from collections.abc import Sequence +from typing import overload + import numpy as np +from pandas import Series from pandas.core.indexes.base import Index from pandas.core.indexes.numeric import Int64Index from pandas._typing import ( HashableT, + np_ndarray_anyint, + np_ndarray_bool, npt, ) @@ -70,10 +76,22 @@ class RangeIndex(Int64Index): def __len__(self) -> int: ... @property def size(self) -> int: ... - def __getitem__(self, key): ... def __floordiv__(self, other): ... def all(self) -> bool: ... def any(self) -> bool: ... def union( self, other: list[HashableT] | Index, sort=... ) -> Index | Int64Index | RangeIndex: ... + @overload # type: ignore[override] + def __getitem__( + self, + idx: slice + | np_ndarray_anyint + | Sequence[int] + | Index + | Series[bool] + | Sequence[bool] + | np_ndarray_bool, + ) -> Index: ... + @overload + def __getitem__(self, idx: int) -> int: ... diff --git a/pandas-stubs/core/indexes/timedeltas.pyi b/pandas-stubs/core/indexes/timedeltas.pyi index 9afc618d0..545898190 100644 --- a/pandas-stubs/core/indexes/timedeltas.pyi +++ b/pandas-stubs/core/indexes/timedeltas.pyi @@ -14,6 +14,7 @@ from pandas import ( Period, ) from pandas.core.indexes.accessors import TimedeltaIndexProperties +from pandas.core.indexes.base import _IndexGetitemMixin from pandas.core.indexes.datetimelike import DatetimeTimedeltaMixin from pandas.core.indexes.datetimes import DatetimeIndex from pandas.core.indexes.period import PeriodIndex @@ -30,7 +31,10 @@ from pandas._typing import ( num, ) -class TimedeltaIndex(DatetimeTimedeltaMixin, TimedeltaIndexProperties): +# type ignore needed because of __getitem__() +class TimedeltaIndex( # type: ignore[misc] + _IndexGetitemMixin[Timedelta], DatetimeTimedeltaMixin, TimedeltaIndexProperties +): def __init__( self, data: AnyArrayLike @@ -45,16 +49,16 @@ class TimedeltaIndex(DatetimeTimedeltaMixin, TimedeltaIndexProperties): ): ... # various ignores needed for mypy, as we do want to restrict what can be used in # arithmetic for these types - @overload # type: ignore[override] + @overload def __add__(self, other: Period) -> PeriodIndex: ... @overload def __add__(self, other: DatetimeIndex) -> DatetimeIndex: ... @overload def __add__(self, other: Timedelta | TimedeltaIndex) -> TimedeltaIndex: ... def __radd__(self, other: Timestamp | DatetimeIndex) -> DatetimeIndex: ... # type: ignore[override] - def __sub__(self, other: Timedelta | TimedeltaIndex) -> TimedeltaIndex: ... # type: ignore[override] - def __mul__(self, other: num) -> TimedeltaIndex: ... # type: ignore[override] - def __truediv__(self, other: num) -> TimedeltaIndex: ... # type: ignore[override] + def __sub__(self, other: Timedelta | TimedeltaIndex) -> TimedeltaIndex: ... + def __mul__(self, other: num) -> TimedeltaIndex: ... + def __truediv__(self, other: num) -> TimedeltaIndex: ... def astype(self, dtype, copy: bool = ...): ... def get_value(self, series, key): ... def get_loc(self, key, tolerance=...): ... diff --git a/tests/test_indexes.py b/tests/test_indexes.py index c4c7c48e6..5c2bb46f3 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -761,3 +761,96 @@ def test_index_operators() -> None: 10 ^ i1, # type:ignore[operator] # pyright: ignore[reportGeneralTypeIssues] Never, ) + + +def test_getitem() -> None: + # GH 536 + ip = pd.period_range(start="2022-06-01", periods=10) + check(assert_type(ip, pd.PeriodIndex), pd.PeriodIndex, pd.Period) + check(assert_type(ip[0], pd.Period), pd.Period) + check(assert_type(ip[[0, 2, 4]], pd.PeriodIndex), pd.PeriodIndex, pd.Period) + + idt = pd.DatetimeIndex(["2022-08-14", "2022-08-20", "2022-08-24"]) + check(assert_type(idt, pd.DatetimeIndex), pd.DatetimeIndex, pd.Timestamp) + check(assert_type(idt[0], pd.Timestamp), pd.Timestamp) + check(assert_type(idt[[0, 2]], pd.DatetimeIndex), pd.DatetimeIndex, pd.Timestamp) + + itd = pd.date_range("1/1/2021", "1/5/2021") - pd.Timestamp("1/3/2019") + check(assert_type(itd, pd.TimedeltaIndex), pd.TimedeltaIndex, pd.Timedelta) + check(assert_type(itd[0], pd.Timedelta), pd.Timedelta) + check( + assert_type(itd[[0, 2, 4]], pd.TimedeltaIndex), pd.TimedeltaIndex, pd.Timedelta + ) + + iini = pd.interval_range(0, 10) + check( + assert_type(iini, "pd.IntervalIndex[pd.Interval[int]]"), + pd.IntervalIndex, + pd.Interval, + ) + check(assert_type(iini[0], "pd.Interval[int]"), pd.Interval) + check( + assert_type(iini[[0, 2, 4]], "pd.IntervalIndex[pd.Interval[int]]"), + pd.IntervalIndex, + pd.Interval, + ) + + iinf = pd.interval_range(0.0, 10) + check( + assert_type(iinf, "pd.IntervalIndex[pd.Interval[float]]"), + pd.IntervalIndex, + pd.Interval, + ) + check(assert_type(iinf[0], "pd.Interval[float]"), pd.Interval) + check( + assert_type(iinf[[0, 2, 4]], "pd.IntervalIndex[pd.Interval[float]]"), + pd.IntervalIndex, + pd.Interval, + ) + + iints = pd.interval_range(dt.datetime(2000, 1, 1), dt.datetime(2010, 1, 1), 5) + check( + assert_type( + iints, + "pd.IntervalIndex[pd.Interval[pd.Timestamp]]", + ), + pd.IntervalIndex, + pd.Interval, + ) + check(assert_type(iints[0], "pd.Interval[pd.Timestamp]"), pd.Interval) + check( + assert_type(iints[[0, 2, 4]], "pd.IntervalIndex[pd.Interval[pd.Timestamp]]"), + pd.IntervalIndex, + pd.Interval, + ) + + iintd = pd.interval_range(pd.Timedelta("1D"), pd.Timedelta("10D")) + check( + assert_type( + iintd, + "pd.IntervalIndex[pd.Interval[pd.Timedelta]]", + ), + pd.IntervalIndex, + pd.Interval, + ) + check(assert_type(iintd[0], "pd.Interval[pd.Timedelta]"), pd.Interval) + check( + assert_type(iintd[[0, 2, 4]], "pd.IntervalIndex[pd.Interval[pd.Timedelta]]"), + pd.IntervalIndex, + pd.Interval, + ) + + iri = pd.RangeIndex(0, 10) + check(assert_type(iri, pd.RangeIndex), pd.RangeIndex, int) + check(assert_type(iri[0], int), int) + check(assert_type(iri[[0, 2, 4]], pd.Index), pd.Index, int) + + mi = pd.MultiIndex.from_product([["a", "b"], ["c", "d"]], names=["ab", "cd"]) + check(assert_type(mi, pd.MultiIndex), pd.MultiIndex) + check(assert_type(mi[0], tuple), tuple) + check(assert_type(mi[[0, 2]], pd.MultiIndex), pd.MultiIndex, tuple) + + i0 = pd.Index(["a", "b", "c"]) + check(assert_type(i0, pd.Index), pd.Index) + check(assert_type(i0[0], Scalar), str) + check(assert_type(i0[[0, 2]], pd.Index), pd.Index, str) From 0db8a00a7bd5c7f023991459d79f372df99b1ab3 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Wed, 15 Feb 2023 14:12:04 -0500 Subject: [PATCH 2/2] create TypeVar for the mixin --- pandas-stubs/core/indexes/base.pyi | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 137cf82c7..e34cec605 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -9,6 +9,7 @@ from typing import ( ClassVar, Generic, Literal, + TypeVar, overload, ) @@ -50,11 +51,12 @@ class InvalidIndexError(Exception): ... _str = str +_IndexGetitemMixinT = TypeVar("_IndexGetitemMixinT", bound=_IndexGetitemMixin) + class _IndexGetitemMixin(Generic[S1]): - # type ignore needed because it doesn't like the type of self @overload - def __getitem__( # type: ignore[misc] - self: IndexT, + def __getitem__( + self: _IndexGetitemMixinT, idx: slice | np_ndarray_anyint | Sequence[int] @@ -62,7 +64,7 @@ class _IndexGetitemMixin(Generic[S1]): | Series[bool] | Sequence[bool] | np_ndarray_bool, - ) -> IndexT: ... + ) -> _IndexGetitemMixinT: ... @overload def __getitem__(self, idx: int) -> S1: ...