Skip to content

Commit f8a37a7

Browse files
authored
TYP: Use Self instead of class-bound TypeVar II (pandas/core/arrays/) (#51497)
TYP: Use Self for type checking (pandas/core/arrays/)
1 parent c8ea34c commit f8a37a7

File tree

9 files changed

+136
-209
lines changed

9 files changed

+136
-209
lines changed

pandas/core/arrays/_mixins.py

+17-30
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
Any,
77
Literal,
88
Sequence,
9-
TypeVar,
109
cast,
1110
overload,
1211
)
@@ -23,11 +22,11 @@
2322
PositionalIndexer2D,
2423
PositionalIndexerTuple,
2524
ScalarIndexer,
25+
Self,
2626
SequenceIndexer,
2727
Shape,
2828
TakeIndexer,
2929
npt,
30-
type_t,
3130
)
3231
from pandas.errors import AbstractMethodError
3332
from pandas.util._decorators import doc
@@ -61,10 +60,6 @@
6160
from pandas.core.indexers import check_array_indexer
6261
from pandas.core.sorting import nargminmax
6362

64-
NDArrayBackedExtensionArrayT = TypeVar(
65-
"NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
66-
)
67-
6863
if TYPE_CHECKING:
6964
from pandas._typing import (
7065
NumpySorter,
@@ -153,13 +148,13 @@ def view(self, dtype: Dtype | None = None) -> ArrayLike:
153148
return arr.view(dtype=dtype) # type: ignore[arg-type]
154149

155150
def take(
156-
self: NDArrayBackedExtensionArrayT,
151+
self,
157152
indices: TakeIndexer,
158153
*,
159154
allow_fill: bool = False,
160155
fill_value: Any = None,
161156
axis: AxisInt = 0,
162-
) -> NDArrayBackedExtensionArrayT:
157+
) -> Self:
163158
if allow_fill:
164159
fill_value = self._validate_scalar(fill_value)
165160

@@ -218,17 +213,17 @@ def argmax(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[overri
218213
raise NotImplementedError
219214
return nargminmax(self, "argmax", axis=axis)
220215

221-
def unique(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
216+
def unique(self) -> Self:
222217
new_data = unique(self._ndarray)
223218
return self._from_backing_data(new_data)
224219

225220
@classmethod
226221
@doc(ExtensionArray._concat_same_type)
227222
def _concat_same_type(
228-
cls: type[NDArrayBackedExtensionArrayT],
229-
to_concat: Sequence[NDArrayBackedExtensionArrayT],
223+
cls,
224+
to_concat: Sequence[Self],
230225
axis: AxisInt = 0,
231-
) -> NDArrayBackedExtensionArrayT:
226+
) -> Self:
232227
dtypes = {str(x.dtype) for x in to_concat}
233228
if len(dtypes) != 1:
234229
raise ValueError("to_concat must have the same dtype (tz)", dtypes)
@@ -268,15 +263,15 @@ def __getitem__(self, key: ScalarIndexer) -> Any:
268263

269264
@overload
270265
def __getitem__(
271-
self: NDArrayBackedExtensionArrayT,
266+
self,
272267
key: SequenceIndexer | PositionalIndexerTuple,
273-
) -> NDArrayBackedExtensionArrayT:
268+
) -> Self:
274269
...
275270

276271
def __getitem__(
277-
self: NDArrayBackedExtensionArrayT,
272+
self,
278273
key: PositionalIndexer2D,
279-
) -> NDArrayBackedExtensionArrayT | Any:
274+
) -> Self | Any:
280275
if lib.is_integer(key):
281276
# fast-path
282277
result = self._ndarray[key]
@@ -303,9 +298,7 @@ def _fill_mask_inplace(
303298
func(self._ndarray.T, limit=limit, mask=mask.T)
304299

305300
@doc(ExtensionArray.fillna)
306-
def fillna(
307-
self: NDArrayBackedExtensionArrayT, value=None, method=None, limit=None
308-
) -> NDArrayBackedExtensionArrayT:
301+
def fillna(self, value=None, method=None, limit=None) -> Self:
309302
value, method = validate_fillna_kwargs(
310303
value, method, validate_scalar_dict_value=False
311304
)
@@ -369,9 +362,7 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
369362

370363
np.putmask(self._ndarray, mask, value)
371364

372-
def _where(
373-
self: NDArrayBackedExtensionArrayT, mask: npt.NDArray[np.bool_], value
374-
) -> NDArrayBackedExtensionArrayT:
365+
def _where(self: Self, mask: npt.NDArray[np.bool_], value) -> Self:
375366
"""
376367
Analogue to np.where(mask, self, value)
377368
@@ -393,9 +384,7 @@ def _where(
393384
# ------------------------------------------------------------------------
394385
# Index compat methods
395386

396-
def insert(
397-
self: NDArrayBackedExtensionArrayT, loc: int, item
398-
) -> NDArrayBackedExtensionArrayT:
387+
def insert(self, loc: int, item) -> Self:
399388
"""
400389
Make new ExtensionArray inserting new item at location. Follows
401390
Python list.append semantics for negative values.
@@ -461,10 +450,10 @@ def value_counts(self, dropna: bool = True) -> Series:
461450
return Series(result._values, index=index, name=result.name)
462451

463452
def _quantile(
464-
self: NDArrayBackedExtensionArrayT,
453+
self,
465454
qs: npt.NDArray[np.float64],
466455
interpolation: str,
467-
) -> NDArrayBackedExtensionArrayT:
456+
) -> Self:
468457
# TODO: disable for Categorical if not ordered?
469458

470459
mask = np.asarray(self.isna())
@@ -488,9 +477,7 @@ def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
488477
# numpy-like methods
489478

490479
@classmethod
491-
def _empty(
492-
cls: type_t[NDArrayBackedExtensionArrayT], shape: Shape, dtype: ExtensionDtype
493-
) -> NDArrayBackedExtensionArrayT:
480+
def _empty(cls, shape: Shape, dtype: ExtensionDtype) -> Self:
494481
"""
495482
Analogous to np.empty(shape, dtype=dtype)
496483

pandas/core/arrays/arrow/array.py

+14-22
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
Callable,
99
Literal,
1010
Sequence,
11-
TypeVar,
1211
cast,
1312
)
1413

@@ -24,6 +23,7 @@
2423
NpDtype,
2524
PositionalIndexer,
2625
Scalar,
26+
Self,
2727
SortKind,
2828
TakeIndexer,
2929
TimeAmbiguous,
@@ -140,8 +140,6 @@ def floordiv_compat(
140140

141141
from pandas import Series
142142

143-
ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")
144-
145143

146144
def get_unit_from_pa_dtype(pa_dtype):
147145
# https://github.com/pandas-dev/pandas/pull/50998#discussion_r1100344804
@@ -419,16 +417,16 @@ def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
419417
"""Correctly construct numpy arrays when passed to `np.asarray()`."""
420418
return self.to_numpy(dtype=dtype)
421419

422-
def __invert__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
420+
def __invert__(self) -> Self:
423421
return type(self)(pc.invert(self._pa_array))
424422

425-
def __neg__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
423+
def __neg__(self) -> Self:
426424
return type(self)(pc.negate_checked(self._pa_array))
427425

428-
def __pos__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
426+
def __pos__(self) -> Self:
429427
return type(self)(self._pa_array)
430428

431-
def __abs__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
429+
def __abs__(self) -> Self:
432430
return type(self)(pc.abs_checked(self._pa_array))
433431

434432
# GH 42600: __getstate__/__setstate__ not necessary once
@@ -733,7 +731,7 @@ def argmin(self, skipna: bool = True) -> int:
733731
def argmax(self, skipna: bool = True) -> int:
734732
return self._argmin_max(skipna, "max")
735733

736-
def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
734+
def copy(self) -> Self:
737735
"""
738736
Return a shallow copy of the array.
739737
@@ -745,7 +743,7 @@ def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
745743
"""
746744
return type(self)(self._pa_array)
747745

748-
def dropna(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
746+
def dropna(self) -> Self:
749747
"""
750748
Return ArrowExtensionArray without NA values.
751749
@@ -757,11 +755,11 @@ def dropna(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
757755

758756
@doc(ExtensionArray.fillna)
759757
def fillna(
760-
self: ArrowExtensionArrayT,
758+
self,
761759
value: object | ArrayLike | None = None,
762760
method: FillnaOptions | None = None,
763761
limit: int | None = None,
764-
) -> ArrowExtensionArrayT:
762+
) -> Self:
765763
value, method = validate_fillna_kwargs(value, method)
766764

767765
if limit is not None:
@@ -877,9 +875,7 @@ def reshape(self, *args, **kwargs):
877875
f"as backed by a 1D pyarrow.ChunkedArray."
878876
)
879877

880-
def round(
881-
self: ArrowExtensionArrayT, decimals: int = 0, *args, **kwargs
882-
) -> ArrowExtensionArrayT:
878+
def round(self, decimals: int = 0, *args, **kwargs) -> Self:
883879
"""
884880
Round each value in the array a to the given number of decimals.
885881
@@ -1052,7 +1048,7 @@ def to_numpy(
10521048
result[self.isna()] = na_value
10531049
return result
10541050

1055-
def unique(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
1051+
def unique(self) -> Self:
10561052
"""
10571053
Compute the ArrowExtensionArray of unique values.
10581054
@@ -1123,9 +1119,7 @@ def value_counts(self, dropna: bool = True) -> Series:
11231119
return Series(counts, index=index, name="count")
11241120

11251121
@classmethod
1126-
def _concat_same_type(
1127-
cls: type[ArrowExtensionArrayT], to_concat
1128-
) -> ArrowExtensionArrayT:
1122+
def _concat_same_type(cls, to_concat) -> Self:
11291123
"""
11301124
Concatenate multiple ArrowExtensionArrays.
11311125
@@ -1456,9 +1450,7 @@ def _rank(
14561450

14571451
return type(self)(result)
14581452

1459-
def _quantile(
1460-
self: ArrowExtensionArrayT, qs: npt.NDArray[np.float64], interpolation: str
1461-
) -> ArrowExtensionArrayT:
1453+
def _quantile(self, qs: npt.NDArray[np.float64], interpolation: str) -> Self:
14621454
"""
14631455
Compute the quantiles of self for each quantile in `qs`.
14641456
@@ -1495,7 +1487,7 @@ def _quantile(
14951487

14961488
return type(self)(result)
14971489

1498-
def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArrayT:
1490+
def _mode(self, dropna: bool = True) -> Self:
14991491
"""
15001492
Returns the mode(s) of the ExtensionArray.
15011493

0 commit comments

Comments
 (0)