diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index ddcf225d3585f..ffb892ed7e505 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -20,7 +20,9 @@ from pandas.core.construction import extract_array from pandas.core.indexers import check_array_indexer -_T = TypeVar("_T", bound="NDArrayBackedExtensionArray") +NDArrayBackedExtensionArrayT = TypeVar( + "NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray" +) class NDArrayBackedExtensionArray(ExtensionArray): @@ -30,7 +32,9 @@ class NDArrayBackedExtensionArray(ExtensionArray): _ndarray: np.ndarray - def _from_backing_data(self: _T, arr: np.ndarray) -> _T: + def _from_backing_data( + self: NDArrayBackedExtensionArrayT, arr: np.ndarray + ) -> NDArrayBackedExtensionArrayT: """ Construct a new ExtensionArray `new_array` with `arr` as its _ndarray. @@ -52,13 +56,13 @@ def _validate_scalar(self, value): # ------------------------------------------------------------------------ def take( - self: _T, + self: NDArrayBackedExtensionArrayT, indices: Sequence[int], *, allow_fill: bool = False, fill_value: Any = None, axis: int = 0, - ) -> _T: + ) -> NDArrayBackedExtensionArrayT: if allow_fill: fill_value = self._validate_fill_value(fill_value) @@ -113,16 +117,20 @@ def size(self) -> int: def nbytes(self) -> int: return self._ndarray.nbytes - def reshape(self: _T, *args, **kwargs) -> _T: + def reshape( + self: NDArrayBackedExtensionArrayT, *args, **kwargs + ) -> NDArrayBackedExtensionArrayT: new_data = self._ndarray.reshape(*args, **kwargs) return self._from_backing_data(new_data) - def ravel(self: _T, *args, **kwargs) -> _T: + def ravel( + self: NDArrayBackedExtensionArrayT, *args, **kwargs + ) -> NDArrayBackedExtensionArrayT: new_data = self._ndarray.ravel(*args, **kwargs) return self._from_backing_data(new_data) @property - def T(self: _T) -> _T: + def T(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT: new_data = self._ndarray.T return self._from_backing_data(new_data) @@ -138,11 +146,13 @@ def equals(self, other) -> bool: def _values_for_argsort(self): return self._ndarray - def copy(self: _T) -> _T: + def copy(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT: new_data = self._ndarray.copy() return self._from_backing_data(new_data) - def repeat(self: _T, repeats, axis=None) -> _T: + def repeat( + self: NDArrayBackedExtensionArrayT, repeats, axis=None + ) -> NDArrayBackedExtensionArrayT: """ Repeat elements of an array. @@ -154,7 +164,7 @@ def repeat(self: _T, repeats, axis=None) -> _T: new_data = self._ndarray.repeat(repeats, axis=axis) return self._from_backing_data(new_data) - def unique(self: _T) -> _T: + def unique(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT: new_data = unique(self._ndarray) return self._from_backing_data(new_data) @@ -216,7 +226,9 @@ def __getitem__(self, key): return result @doc(ExtensionArray.fillna) - def fillna(self: _T, value=None, method=None, limit=None) -> _T: + def fillna( + self: NDArrayBackedExtensionArrayT, value=None, method=None, limit=None + ) -> NDArrayBackedExtensionArrayT: value, method = validate_fillna_kwargs(value, method) mask = self.isna() diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index afbddc53804ac..0b01ea305b73c 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -12,7 +12,7 @@ import numpy as np from pandas._libs import lib -from pandas._typing import ArrayLike, Shape +from pandas._typing import ArrayLike, Shape, TypeVar from pandas.compat import set_function_name from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError @@ -37,6 +37,8 @@ _extension_array_shared_docs: Dict[str, str] = dict() +ExtensionArrayT = TypeVar("ExtensionArrayT", bound="ExtensionArray") + class ExtensionArray: """ @@ -1016,7 +1018,7 @@ def take(self, indices, allow_fill=False, fill_value=None): # pandas.api.extensions.take raise AbstractMethodError(self) - def copy(self) -> "ExtensionArray": + def copy(self: ExtensionArrayT) -> ExtensionArrayT: """ Return a copy of the array. diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 977e4abff4287..5c1b4f1d781cd 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -1,7 +1,7 @@ import operator from operator import le, lt import textwrap -from typing import TYPE_CHECKING, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union, cast import numpy as np @@ -56,6 +56,8 @@ from pandas import Index from pandas.core.arrays import DatetimeArray, TimedeltaArray +IntervalArrayT = TypeVar("IntervalArrayT", bound="IntervalArray") + _interval_shared_docs = {} _shared_docs_kwargs = dict( @@ -745,7 +747,7 @@ def _concat_same_type(cls, to_concat): combined = _get_combined_data(left, right) # TODO: 1-stage concat return cls._simple_new(combined, closed=closed) - def copy(self): + def copy(self: IntervalArrayT) -> IntervalArrayT: """ Return a copy of the array. diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index d976526955ac2..e3b28e2f47af2 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -4,7 +4,7 @@ from collections import abc import numbers import operator -from typing import Any, Callable, Union +from typing import Any, Callable, Type, TypeVar, Union import warnings import numpy as np @@ -56,6 +56,7 @@ # ---------------------------------------------------------------------------- # Array +SparseArrayT = TypeVar("SparseArrayT", bound="SparseArray") _sparray_doc_kwargs = dict(klass="SparseArray") @@ -397,8 +398,11 @@ def __init__( @classmethod def _simple_new( - cls, sparse_array: np.ndarray, sparse_index: SparseIndex, dtype: SparseDtype - ) -> "SparseArray": + cls: Type[SparseArrayT], + sparse_array: np.ndarray, + sparse_index: SparseIndex, + dtype: SparseDtype, + ) -> SparseArrayT: new = object.__new__(cls) new._sparse_index = sparse_index new._sparse_values = sparse_array @@ -937,7 +941,7 @@ def searchsorted(self, v, side="left", sorter=None): v = np.asarray(v) return np.asarray(self, dtype=self.dtype.subtype).searchsorted(v, side, sorter) - def copy(self): + def copy(self: SparseArrayT) -> SparseArrayT: values = self.sp_values.copy() return self._simple_new(values, self.sp_index, self.dtype)