diff --git a/pandas/_typing.py b/pandas/_typing.py index 7c74fc54b8d67..d6561176deb71 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -187,3 +187,14 @@ # internals Manager = Union["ArrayManager", "BlockManager"] SingleManager = Union["SingleArrayManager", "SingleBlockManager"] + +# indexing +# PositionalIndexer -> valid 1D positional indexer, e.g. can pass +# to ndarray.__getitem__ +# TODO: add Ellipsis, see +# https://github.com/python/typing/issues/684#issuecomment-548203158 +# https://bugs.python.org/issue41810 +PositionalIndexer = Union[int, np.integer, slice, Sequence[int], np.ndarray] +PositionalIndexer2D = Union[ + PositionalIndexer, Tuple[PositionalIndexer, PositionalIndexer] +] diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 0b208d4f2ab72..d4e5ca00b06dd 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -13,6 +13,7 @@ from pandas._libs import lib from pandas._typing import ( F, + PositionalIndexer2D, Shape, ) from pandas.compat.numpy import function as nv @@ -274,7 +275,8 @@ def _validate_setitem_value(self, value): return value def __getitem__( - self: NDArrayBackedExtensionArrayT, key: int | slice | np.ndarray + self: NDArrayBackedExtensionArrayT, + key: PositionalIndexer2D, ) -> NDArrayBackedExtensionArrayT | Any: if lib.is_integer(key): # fast-path diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index ac3ae6407305c..933b829e0b29f 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -24,6 +24,7 @@ from pandas._typing import ( ArrayLike, Dtype, + PositionalIndexer, Shape, ) from pandas.compat import set_function_name @@ -288,7 +289,7 @@ def _from_factorized(cls, values, original): # Must be a Sequence # ------------------------------------------------------------------------ - def __getitem__(self, item: int | slice | np.ndarray) -> ExtensionArray | Any: + def __getitem__(self, item: PositionalIndexer) -> ExtensionArray | Any: """ Select a subset of self. diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 2e7f18965d2b2..b039df12ee3df 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -47,6 +47,7 @@ Dtype, DtypeObj, NpDtype, + PositionalIndexer2D, ) from pandas.compat.numpy import function as nv from pandas.errors import ( @@ -338,7 +339,7 @@ def __array__(self, dtype: NpDtype | None = None) -> np.ndarray: return self._ndarray def __getitem__( - self, key: int | slice | np.ndarray + self, key: PositionalIndexer2D ) -> DatetimeLikeArrayMixin | DTScalarOrNaT: """ This getitem defers to the underlying array, which by-definition can diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index 92bd031aec3a7..93de1cd91d625 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -17,6 +17,7 @@ ArrayLike, Dtype, NpDtype, + PositionalIndexer, Scalar, type_t, ) @@ -134,7 +135,7 @@ def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False): def dtype(self) -> BaseMaskedDtype: raise AbstractMethodError(self) - def __getitem__(self, item: int | slice | np.ndarray) -> BaseMaskedArray | Any: + def __getitem__(self, item: PositionalIndexer) -> BaseMaskedArray | Any: if is_integer(item): if self._mask[item]: return self.dtype.na_value diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 77f01f895a667..faca868873efa 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -17,6 +17,7 @@ from pandas._typing import ( Dtype, NpDtype, + PositionalIndexer, type_t, ) from pandas.util._decorators import doc @@ -310,7 +311,7 @@ def _concat_same_type(cls, to_concat) -> ArrowStringArray: ) ) - def __getitem__(self, item: Any) -> Any: + def __getitem__(self, item: PositionalIndexer) -> Any: """Select a subset of self. Parameters diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 5529f9f430375..90b5366b1a23e 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -1381,9 +1381,7 @@ def iget(self, col): elif isinstance(col, slice): if col != slice(None): raise NotImplementedError(col) - # error: Invalid index type "List[Any]" for "ExtensionArray"; expected - # type "Union[int, slice, ndarray]" - return self.values[[loc]] # type: ignore[index] + return self.values[[loc]] return self.values[loc] else: if col != 0: