Skip to content

Commit 080add1

Browse files
TYP: Typing improvements for Index (#59105)
* Typing improvements for Index * better numpy type hints for Index.delete * replace some hints with literals, move slice_type to _typing.py
1 parent 18a3eec commit 080add1

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

pandas/_typing.py

+2
Original file line numberDiff line numberDiff line change
@@ -526,3 +526,5 @@ def closed(self) -> bool:
526526

527527
# maintaine the sub-type of any hashable sequence
528528
SequenceT = TypeVar("SequenceT", bound=Sequence[Hashable])
529+
530+
SliceType = Optional[Hashable]

pandas/core/indexes/base.py

+30-11
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@
4545
ArrayLike,
4646
Axes,
4747
Axis,
48+
AxisInt,
4849
DropKeep,
50+
Dtype,
4951
DtypeObj,
5052
F,
5153
IgnoreRaise,
@@ -57,6 +59,7 @@
5759
ReindexMethod,
5860
Self,
5961
Shape,
62+
SliceType,
6063
npt,
6164
)
6265
from pandas.compat.numpy import function as nv
@@ -1087,7 +1090,7 @@ def view(self, cls=None):
10871090
result._id = self._id
10881091
return result
10891092

1090-
def astype(self, dtype, copy: bool = True):
1093+
def astype(self, dtype: Dtype, copy: bool = True):
10911094
"""
10921095
Create an Index with values cast to dtypes.
10931096
@@ -2957,7 +2960,7 @@ def _dti_setop_align_tzs(self, other: Index, setop: str_t) -> tuple[Index, Index
29572960
return self, other
29582961

29592962
@final
2960-
def union(self, other, sort=None):
2963+
def union(self, other, sort: bool | None = None):
29612964
"""
29622965
Form the union of two Index objects.
29632966
@@ -3334,7 +3337,7 @@ def _intersection_via_get_indexer(
33343337
return result
33353338

33363339
@final
3337-
def difference(self, other, sort=None):
3340+
def difference(self, other, sort: bool | None = None):
33383341
"""
33393342
Return a new Index with elements of index not in `other`.
33403343
@@ -3420,7 +3423,12 @@ def _wrap_difference_result(self, other, result):
34203423
# We will override for MultiIndex to handle empty results
34213424
return self._wrap_setop_result(other, result)
34223425

3423-
def symmetric_difference(self, other, result_name=None, sort=None):
3426+
def symmetric_difference(
3427+
self,
3428+
other,
3429+
result_name: abc.Hashable | None = None,
3430+
sort: bool | None = None,
3431+
):
34243432
"""
34253433
Compute the symmetric difference of two Index objects.
34263434
@@ -6389,7 +6397,7 @@ def _transform_index(self, func, *, level=None) -> Index:
63896397
items = [func(x) for x in self]
63906398
return Index(items, name=self.name, tupleize_cols=False)
63916399

6392-
def isin(self, values, level=None) -> npt.NDArray[np.bool_]:
6400+
def isin(self, values, level: str_t | int | None = None) -> npt.NDArray[np.bool_]:
63936401
"""
63946402
Return a boolean array where the index values are in `values`.
63956403
@@ -6687,7 +6695,12 @@ def get_slice_bound(self, label, side: Literal["left", "right"]) -> int:
66876695
else:
66886696
return slc
66896697

6690-
def slice_locs(self, start=None, end=None, step=None) -> tuple[int, int]:
6698+
def slice_locs(
6699+
self,
6700+
start: SliceType = None,
6701+
end: SliceType = None,
6702+
step: int | None = None,
6703+
) -> tuple[int, int]:
66916704
"""
66926705
Compute slice locations for input labels.
66936706
@@ -6781,7 +6794,9 @@ def slice_locs(self, start=None, end=None, step=None) -> tuple[int, int]:
67816794

67826795
return start_slice, end_slice
67836796

6784-
def delete(self, loc) -> Self:
6797+
def delete(
6798+
self, loc: int | np.integer | list[int] | npt.NDArray[np.integer]
6799+
) -> Self:
67856800
"""
67866801
Make new Index with passed location(-s) deleted.
67876802
@@ -7227,7 +7242,9 @@ def _maybe_disable_logical_methods(self, opname: str_t) -> None:
72277242
raise TypeError(f"cannot perform {opname} with {type(self).__name__}")
72287243

72297244
@Appender(IndexOpsMixin.argmin.__doc__)
7230-
def argmin(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
7245+
def argmin(
7246+
self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs
7247+
) -> int:
72317248
nv.validate_argmin(args, kwargs)
72327249
nv.validate_minmax_axis(axis)
72337250

@@ -7240,7 +7257,9 @@ def argmin(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
72407257
return super().argmin(skipna=skipna)
72417258

72427259
@Appender(IndexOpsMixin.argmax.__doc__)
7243-
def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
7260+
def argmax(
7261+
self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs
7262+
) -> int:
72447263
nv.validate_argmax(args, kwargs)
72457264
nv.validate_minmax_axis(axis)
72467265

@@ -7251,7 +7270,7 @@ def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
72517270
raise ValueError("Encountered all NA values")
72527271
return super().argmax(skipna=skipna)
72537272

7254-
def min(self, axis=None, skipna: bool = True, *args, **kwargs):
7273+
def min(self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs):
72557274
"""
72567275
Return the minimum value of the Index.
72577276
@@ -7314,7 +7333,7 @@ def min(self, axis=None, skipna: bool = True, *args, **kwargs):
73147333

73157334
return nanops.nanmin(self._values, skipna=skipna)
73167335

7317-
def max(self, axis=None, skipna: bool = True, *args, **kwargs):
7336+
def max(self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs):
73187337
"""
73197338
Return the maximum value of the Index.
73207339

0 commit comments

Comments
 (0)