diff --git a/pandas/_typing.py b/pandas/_typing.py index 7f01bcaa1c50e..a4501199fd3a3 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -30,6 +30,7 @@ from typing import final from pandas._libs import Period, Timedelta, Timestamp + from pandas._libs.missing import NAType from pandas.core.dtypes.dtypes import ExtensionDtype @@ -44,9 +45,11 @@ from pandas.core.window.rolling import BaseWindow from pandas.io.formats.format import EngFormatter + else: # typing.final does not exist until py38 final = lambda x: x + NAType = Any # array-like diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 07862e0b9bb48..ae5b6b164fb8d 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional, Sequence, Type, TypeVar, Union +from typing import Any, Optional, Sequence, Type, TypeVar, Union, overload import numpy as np @@ -25,6 +25,7 @@ NDArrayBackedExtensionArrayT = TypeVar( "NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray" ) +EAScalarOrMissing = object # both scalar value and na_value can be any type class NDArrayBackedExtensionArray(ExtensionArray): @@ -214,9 +215,21 @@ def __setitem__(self, key, value): def _validate_setitem_value(self, value): return value + @overload + # error: Overloaded function signatures 1 and 2 overlap with incompatible + # return types [misc] + def __getitem__(self, key: int) -> EAScalarOrMissing: # type: ignore[misc] + ... + + @overload + def __getitem__( + self: NDArrayBackedExtensionArrayT, key: Union[slice, np.ndarray] + ) -> NDArrayBackedExtensionArrayT: + ... + def __getitem__( self: NDArrayBackedExtensionArrayT, key: Union[int, slice, np.ndarray] - ) -> Union[NDArrayBackedExtensionArrayT, Any]: + ) -> Union[NDArrayBackedExtensionArrayT, EAScalarOrMissing]: if lib.is_integer(key): # fast-path result = self._ndarray[key] diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index a0a51495791d1..8cfafd8b52da8 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -20,6 +20,7 @@ TypeVar, Union, cast, + overload, ) import numpy as np @@ -51,6 +52,7 @@ _extension_array_shared_docs: Dict[str, str] = dict() ExtensionArrayT = TypeVar("ExtensionArrayT", bound="ExtensionArray") +EAScalarOrMissing = object # both scalar value and na_value can be any type class ExtensionArray: @@ -256,9 +258,19 @@ def _from_factorized(cls, values, original): # Must be a Sequence # ------------------------------------------------------------------------ + @overload + # error: Overloaded function signatures 1 and 2 overlap with incompatible + # return types [misc] + def __getitem__(self, item: int) -> EAScalarOrMissing: # type: ignore[misc] + ... + + @overload + def __getitem__(self, item: Union[slice, np.ndarray]) -> ExtensionArray: + ... + def __getitem__( self, item: Union[int, slice, np.ndarray] - ) -> Union[ExtensionArray, Any]: + ) -> Union[ExtensionArray, EAScalarOrMissing]: """ Select a subset of self. diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 30674212239bf..06cf8450f9a21 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -13,6 +13,7 @@ TypeVar, Union, cast, + overload, ) import warnings @@ -266,9 +267,19 @@ def __array__(self, dtype=None) -> np.ndarray: return np.array(list(self), dtype=object) return self._ndarray + @overload + def __getitem__(self, key: int) -> DTScalarOrNaT: + ... + + @overload + def __getitem__( + self: DatetimeLikeArrayT, key: Union[slice, np.ndarray] + ) -> DatetimeLikeArrayT: + ... + def __getitem__( - self, key: Union[int, slice, np.ndarray] - ) -> Union[DatetimeLikeArrayMixin, DTScalarOrNaT]: + self: DatetimeLikeArrayT, key: Union[int, slice, np.ndarray] + ) -> Union[DatetimeLikeArrayT, DTScalarOrNaT]: """ This getitem defers to the underlying array, which by-definition can only handle list-likes, slices, and integer scalars diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py index 057162fedaa98..a05dc717f83c1 100644 --- a/pandas/core/arrays/datetimes.py +++ b/pandas/core/arrays/datetimes.py @@ -1,5 +1,5 @@ from datetime import datetime, time, timedelta, tzinfo -from typing import Optional, Union, cast +from typing import Optional, Union import warnings import numpy as np @@ -444,11 +444,9 @@ def _generate_range( ) if not left_closed and len(index) and index[0] == start: - # TODO: overload DatetimeLikeArrayMixin.__getitem__ - index = cast(DatetimeArray, index[1:]) + index = index[1:] if not right_closed and len(index) and index[-1] == end: - # TODO: overload DatetimeLikeArrayMixin.__getitem__ - index = cast(DatetimeArray, index[:-1]) + index = index[:-1] dtype = tz_to_dtype(tz) return cls._simple_new(index.asi8, freq=freq, dtype=dtype) diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index caed932cd7857..ca4cd0b63722b 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -1,11 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) import numpy as np from pandas._libs import lib, missing as libmissing -from pandas._typing import Scalar +from pandas._typing import NAType, Scalar from pandas.errors import AbstractMethodError from pandas.util._decorators import cache_readonly, doc @@ -30,6 +39,8 @@ BaseMaskedArrayT = TypeVar("BaseMaskedArrayT", bound="BaseMaskedArray") +# scalar value is a Python scalar, missing value is pd.NA +ScalarOrNAType = Union[Scalar, NAType] class BaseMaskedDtype(ExtensionDtype): @@ -102,9 +113,19 @@ def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False): def dtype(self) -> BaseMaskedDtype: raise AbstractMethodError(self) + @overload + # error: Overloaded function signatures 1 and 2 overlap with incompatible return + # types [misc] + def __getitem__(self, item: int) -> ScalarOrNAType: # type: ignore[misc] + ... + + @overload + def __getitem__(self, item: Union[slice, np.ndarray]) -> BaseMaskedArray: + ... + def __getitem__( self, item: Union[int, slice, np.ndarray] - ) -> Union[BaseMaskedArray, Any]: + ) -> Union[BaseMaskedArray, ScalarOrNAType]: if is_integer(item): if self._mask[item]: return self.dtype.na_value diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 9ea0ff323a33d..e8a2565bed9b6 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -3222,14 +3222,8 @@ def _get_nearest_indexer(self, target: "Index", limit, tolerance) -> np.ndarray: right_indexer = self.get_indexer(target, "backfill", limit=limit) target_values = target._values - # error: Unsupported left operand type for - ("ExtensionArray") - left_distances = np.abs( - self._values[left_indexer] - target_values # type: ignore[operator] - ) - # error: Unsupported left operand type for - ("ExtensionArray") - right_distances = np.abs( - self._values[right_indexer] - target_values # type: ignore[operator] - ) + left_distances = np.abs(self._values[left_indexer] - target_values) + right_distances = np.abs(self._values[right_indexer] - target_values) op = operator.lt if self.is_monotonic_increasing else operator.le indexer = np.where( @@ -3248,8 +3242,7 @@ def _filter_indexer_tolerance( indexer: np.ndarray, tolerance, ) -> np.ndarray: - # error: Unsupported left operand type for - ("ExtensionArray") - distance = abs(self._values[indexer] - target) # type: ignore[operator] + distance = abs(self._values[indexer] - target) indexer = np.where(distance <= tolerance, indexer, -1) return indexer @@ -4546,9 +4539,8 @@ def asof_locs(self, where: "Index", mask) -> np.ndarray: result = np.arange(len(self))[mask].take(locs) - # TODO: overload return type of ExtensionArray.__getitem__ - first_value = cast(Any, self._values[mask.argmax()]) - result[(locs == 0) & (where._values < first_value)] = -1 + first = mask.argmax() + result[(locs == 0) & (where._values < self._values[first])] = -1 return result diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index 67dea6d1df8ed..c0a3b95499b3d 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Any, cast +from typing import Any import numpy as np @@ -673,10 +673,7 @@ def difference(self, other, sort=None): if self.equals(other): # pass an empty PeriodArray with the appropriate dtype - - # TODO: overload DatetimeLikeArrayMixin.__getitem__ - values = cast(PeriodArray, self._data[:0]) - return type(self)._simple_new(values, name=self.name) + return type(self)._simple_new(self._data[:0], name=self.name) if is_object_dtype(other): return self.astype(object).difference(other).astype(self.dtype) diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 08fe67a2979da..c170ebb844c5a 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -511,6 +511,7 @@ def make_a_block(nv, ref_loc): return [block] # ndim > 1 + assert isinstance(new_values, np.ndarray), type(new_values) new_blocks = [] for i, ref_loc in enumerate(self.mgr_locs): m = mask[i]