Skip to content

Commit b1a6421

Browse files
TYP: copy method of EAs (#37816)
1 parent 1f42d45 commit b1a6421

File tree

4 files changed

+39
-19
lines changed

4 files changed

+39
-19
lines changed

pandas/core/arrays/_mixins.py

+23-11
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
from pandas.core.construction import extract_array
2121
from pandas.core.indexers import check_array_indexer
2222

23-
_T = TypeVar("_T", bound="NDArrayBackedExtensionArray")
23+
NDArrayBackedExtensionArrayT = TypeVar(
24+
"NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
25+
)
2426

2527

2628
class NDArrayBackedExtensionArray(ExtensionArray):
@@ -30,7 +32,9 @@ class NDArrayBackedExtensionArray(ExtensionArray):
3032

3133
_ndarray: np.ndarray
3234

33-
def _from_backing_data(self: _T, arr: np.ndarray) -> _T:
35+
def _from_backing_data(
36+
self: NDArrayBackedExtensionArrayT, arr: np.ndarray
37+
) -> NDArrayBackedExtensionArrayT:
3438
"""
3539
Construct a new ExtensionArray `new_array` with `arr` as its _ndarray.
3640
@@ -52,13 +56,13 @@ def _validate_scalar(self, value):
5256
# ------------------------------------------------------------------------
5357

5458
def take(
55-
self: _T,
59+
self: NDArrayBackedExtensionArrayT,
5660
indices: Sequence[int],
5761
*,
5862
allow_fill: bool = False,
5963
fill_value: Any = None,
6064
axis: int = 0,
61-
) -> _T:
65+
) -> NDArrayBackedExtensionArrayT:
6266
if allow_fill:
6367
fill_value = self._validate_fill_value(fill_value)
6468

@@ -113,16 +117,20 @@ def size(self) -> int:
113117
def nbytes(self) -> int:
114118
return self._ndarray.nbytes
115119

116-
def reshape(self: _T, *args, **kwargs) -> _T:
120+
def reshape(
121+
self: NDArrayBackedExtensionArrayT, *args, **kwargs
122+
) -> NDArrayBackedExtensionArrayT:
117123
new_data = self._ndarray.reshape(*args, **kwargs)
118124
return self._from_backing_data(new_data)
119125

120-
def ravel(self: _T, *args, **kwargs) -> _T:
126+
def ravel(
127+
self: NDArrayBackedExtensionArrayT, *args, **kwargs
128+
) -> NDArrayBackedExtensionArrayT:
121129
new_data = self._ndarray.ravel(*args, **kwargs)
122130
return self._from_backing_data(new_data)
123131

124132
@property
125-
def T(self: _T) -> _T:
133+
def T(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
126134
new_data = self._ndarray.T
127135
return self._from_backing_data(new_data)
128136

@@ -138,11 +146,13 @@ def equals(self, other) -> bool:
138146
def _values_for_argsort(self):
139147
return self._ndarray
140148

141-
def copy(self: _T) -> _T:
149+
def copy(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
142150
new_data = self._ndarray.copy()
143151
return self._from_backing_data(new_data)
144152

145-
def repeat(self: _T, repeats, axis=None) -> _T:
153+
def repeat(
154+
self: NDArrayBackedExtensionArrayT, repeats, axis=None
155+
) -> NDArrayBackedExtensionArrayT:
146156
"""
147157
Repeat elements of an array.
148158
@@ -154,7 +164,7 @@ def repeat(self: _T, repeats, axis=None) -> _T:
154164
new_data = self._ndarray.repeat(repeats, axis=axis)
155165
return self._from_backing_data(new_data)
156166

157-
def unique(self: _T) -> _T:
167+
def unique(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
158168
new_data = unique(self._ndarray)
159169
return self._from_backing_data(new_data)
160170

@@ -216,7 +226,9 @@ def __getitem__(self, key):
216226
return result
217227

218228
@doc(ExtensionArray.fillna)
219-
def fillna(self: _T, value=None, method=None, limit=None) -> _T:
229+
def fillna(
230+
self: NDArrayBackedExtensionArrayT, value=None, method=None, limit=None
231+
) -> NDArrayBackedExtensionArrayT:
220232
value, method = validate_fillna_kwargs(value, method)
221233

222234
mask = self.isna()

pandas/core/arrays/base.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import numpy as np
1313

1414
from pandas._libs import lib
15-
from pandas._typing import ArrayLike, Shape
15+
from pandas._typing import ArrayLike, Shape, TypeVar
1616
from pandas.compat import set_function_name
1717
from pandas.compat.numpy import function as nv
1818
from pandas.errors import AbstractMethodError
@@ -37,6 +37,8 @@
3737

3838
_extension_array_shared_docs: Dict[str, str] = dict()
3939

40+
ExtensionArrayT = TypeVar("ExtensionArrayT", bound="ExtensionArray")
41+
4042

4143
class ExtensionArray:
4244
"""
@@ -1016,7 +1018,7 @@ def take(self, indices, allow_fill=False, fill_value=None):
10161018
# pandas.api.extensions.take
10171019
raise AbstractMethodError(self)
10181020

1019-
def copy(self) -> "ExtensionArray":
1021+
def copy(self: ExtensionArrayT) -> ExtensionArrayT:
10201022
"""
10211023
Return a copy of the array.
10221024

pandas/core/arrays/interval.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import operator
22
from operator import le, lt
33
import textwrap
4-
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
4+
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union, cast
55

66
import numpy as np
77

@@ -56,6 +56,8 @@
5656
from pandas import Index
5757
from pandas.core.arrays import DatetimeArray, TimedeltaArray
5858

59+
IntervalArrayT = TypeVar("IntervalArrayT", bound="IntervalArray")
60+
5961
_interval_shared_docs = {}
6062

6163
_shared_docs_kwargs = dict(
@@ -745,7 +747,7 @@ def _concat_same_type(cls, to_concat):
745747
combined = _get_combined_data(left, right) # TODO: 1-stage concat
746748
return cls._simple_new(combined, closed=closed)
747749

748-
def copy(self):
750+
def copy(self: IntervalArrayT) -> IntervalArrayT:
749751
"""
750752
Return a copy of the array.
751753

pandas/core/arrays/sparse/array.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections import abc
55
import numbers
66
import operator
7-
from typing import Any, Callable, Union
7+
from typing import Any, Callable, Type, TypeVar, Union
88
import warnings
99

1010
import numpy as np
@@ -56,6 +56,7 @@
5656
# ----------------------------------------------------------------------------
5757
# Array
5858

59+
SparseArrayT = TypeVar("SparseArrayT", bound="SparseArray")
5960

6061
_sparray_doc_kwargs = dict(klass="SparseArray")
6162

@@ -397,8 +398,11 @@ def __init__(
397398

398399
@classmethod
399400
def _simple_new(
400-
cls, sparse_array: np.ndarray, sparse_index: SparseIndex, dtype: SparseDtype
401-
) -> "SparseArray":
401+
cls: Type[SparseArrayT],
402+
sparse_array: np.ndarray,
403+
sparse_index: SparseIndex,
404+
dtype: SparseDtype,
405+
) -> SparseArrayT:
402406
new = object.__new__(cls)
403407
new._sparse_index = sparse_index
404408
new._sparse_values = sparse_array
@@ -937,7 +941,7 @@ def searchsorted(self, v, side="left", sorter=None):
937941
v = np.asarray(v)
938942
return np.asarray(self, dtype=self.dtype.subtype).searchsorted(v, side, sorter)
939943

940-
def copy(self):
944+
def copy(self: SparseArrayT) -> SparseArrayT:
941945
values = self.sp_values.copy()
942946
return self._simple_new(values, self.sp_index, self.dtype)
943947

0 commit comments

Comments
 (0)