Skip to content

Commit 11a8c17

Browse files
authored
TYP: ExtensionArray delete() and searchsorted() (#41513)
1 parent 703b1ef commit 11a8c17

File tree

14 files changed

+177 -32 lines changed

pandas/_typing.py

+6
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,11 @@
6969

7070
from pandas.io.formats.format import EngFormatter
7171
from pandas.tseries.offsets import DateOffset
72+
73+
# numpy compatible types
74+
NumpyValueArrayLike = Union[npt._ScalarLike_co, npt.ArrayLike]
75+
NumpySorter = Optional[npt._ArrayLikeInt_co]
76+
7277
else:
7378
npt: Any = None
7479

@@ -85,6 +90,7 @@
8590
PandasScalar = Union["Period", "Timestamp", "Timedelta", "Interval"]
8691
Scalar = Union[PythonScalar, PandasScalar]
8792

93+
8894
# timestamp and timedelta convertible types
8995

9096
TimestampConvertibleTypes = Union[

pandas/core/algorithms.py

+17 -5
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,11 @@
8282

8383
if TYPE_CHECKING:
8484

85+
from pandas._typing import (
86+
NumpySorter,
87+
NumpyValueArrayLike,
88+
)
89+
8590
from pandas import (
8691
Categorical,
8792
DataFrame,
@@ -1517,7 +1522,12 @@ def take(
15171522
# ------------ #
15181523

15191524

1520-
def searchsorted(arr, value, side="left", sorter=None) -> np.ndarray:
1525+
def searchsorted(
1526+
arr: ArrayLike,
1527+
value: NumpyValueArrayLike,
1528+
side: Literal["left", "right"] = "left",
1529+
sorter: NumpySorter = None,
1530+
) -> npt.NDArray[np.intp] | np.intp:
15211531
"""
15221532
Find indices where elements should be inserted to maintain order.
15231533
@@ -1554,8 +1564,9 @@ def searchsorted(arr, value, side="left", sorter=None) -> np.ndarray:
15541564
15551565
Returns
15561566
-------
1557-
array of ints
1558-
Array of insertion points with the same shape as `value`.
1567+
array of ints or int
1568+
If value is array-like, array of insertion points.
1569+
If value is scalar, a single integer.
15591570
15601571
See Also
15611572
--------
@@ -1583,9 +1594,10 @@ def searchsorted(arr, value, side="left", sorter=None) -> np.ndarray:
15831594
dtype = value_arr.dtype
15841595

15851596
if is_scalar(value):
1586-
value = dtype.type(value)
1597+
# We know that value is int
1598+
value = cast(int, dtype.type(value))
15871599
else:
1588-
value = pd_array(value, dtype=dtype)
1600+
value = pd_array(cast(ArrayLike, value), dtype=dtype)
15891601
elif not (
15901602
is_object_dtype(arr) or is_numeric_dtype(arr) or is_categorical_dtype(arr)
15911603
):

pandas/core/arrays/_mixins.py

+26 -6
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
from functools import wraps
44
from typing import (
5+
TYPE_CHECKING,
56
Any,
67
Sequence,
78
TypeVar,
@@ -16,6 +17,7 @@
1617
F,
1718
PositionalIndexer2D,
1819
Shape,
20+
npt,
1921
type_t,
2022
)
2123
from pandas.errors import AbstractMethodError
@@ -45,6 +47,14 @@
4547
"NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
4648
)
4749

50+
if TYPE_CHECKING:
51+
from typing import Literal
52+
53+
from pandas._typing import (
54+
NumpySorter,
55+
NumpyValueArrayLike,
56+
)
57+
4858

4959
def ravel_compat(meth: F) -> F:
5060
"""
@@ -157,12 +167,22 @@ def _concat_same_type(
157167
return to_concat[0]._from_backing_data(new_values) # type: ignore[arg-type]
158168

159169
@doc(ExtensionArray.searchsorted)
160-
def searchsorted(self, value, side="left", sorter=None):
161-
value = self._validate_searchsorted_value(value)
162-
return self._ndarray.searchsorted(value, side=side, sorter=sorter)
163-
164-
def _validate_searchsorted_value(self, value):
165-
return value
170+
def searchsorted(
171+
self,
172+
value: NumpyValueArrayLike | ExtensionArray,
173+
side: Literal["left", "right"] = "left",
174+
sorter: NumpySorter = None,
175+
) -> npt.NDArray[np.intp] | np.intp:
176+
npvalue = self._validate_searchsorted_value(value)
177+
return self._ndarray.searchsorted(npvalue, side=side, sorter=sorter)
178+
179+
def _validate_searchsorted_value(
180+
self, value: NumpyValueArrayLike | ExtensionArray
181+
) -> NumpyValueArrayLike:
182+
if isinstance(value, ExtensionArray):
183+
return value.to_numpy()
184+
else:
185+
return value
166186

167187
@doc(ExtensionArray.shift)
168188
def shift(self, periods=1, fill_value=None, axis=0):

pandas/core/arrays/base.py

+16 -4
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
FillnaOptions,
3030
PositionalIndexer,
3131
Shape,
32+
npt,
3233
)
3334
from pandas.compat import set_function_name
3435
from pandas.compat.numpy import function as nv
@@ -81,6 +82,11 @@ def any(self, *, skipna: bool = True) -> bool:
8182
def all(self, *, skipna: bool = True) -> bool:
8283
pass
8384

85+
from pandas._typing import (
86+
NumpySorter,
87+
NumpyValueArrayLike,
88+
)
89+
8490

8591
_extension_array_shared_docs: dict[str, str] = {}
8692

@@ -807,7 +813,12 @@ def unique(self: ExtensionArrayT) -> ExtensionArrayT:
807813
uniques = unique(self.astype(object))
808814
return self._from_sequence(uniques, dtype=self.dtype)
809815

810-
def searchsorted(self, value, side="left", sorter=None):
816+
def searchsorted(
817+
self,
818+
value: NumpyValueArrayLike | ExtensionArray,
819+
side: Literal["left", "right"] = "left",
820+
sorter: NumpySorter = None,
821+
) -> npt.NDArray[np.intp] | np.intp:
811822
"""
812823
Find indices where elements should be inserted to maintain order.
813824
@@ -838,8 +849,9 @@ def searchsorted(self, value, side="left", sorter=None):
838849
839850
Returns
840851
-------
841-
array of ints
842-
Array of insertion points with the same shape as `value`.
852+
array of ints or int
853+
If value is array-like, array of insertion points.
854+
If value is scalar, a single integer.
843855
844856
See Also
845857
--------
@@ -1304,7 +1316,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
13041316
# ------------------------------------------------------------------------
13051317
# Non-Optimized Default Methods
13061318

1307-
def delete(self: ExtensionArrayT, loc) -> ExtensionArrayT:
1319+
def delete(self: ExtensionArrayT, loc: PositionalIndexer) -> ExtensionArrayT:
13081320
indexer = np.delete(np.arange(len(self)), loc)
13091321
return self.take(indexer)
13101322

pandas/core/arrays/period.py

+18 -3
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
AnyArrayLike,
4242
Dtype,
4343
NpDtype,
44+
npt,
4445
)
4546
from pandas.util._decorators import (
4647
cache_readonly,
@@ -71,11 +72,20 @@
7172

7273
import pandas.core.algorithms as algos
7374
from pandas.core.arrays import datetimelike as dtl
75+
from pandas.core.arrays.base import ExtensionArray
7476
import pandas.core.common as com
7577

7678
if TYPE_CHECKING:
79+
from typing import Literal
80+
81+
from pandas._typing import (
82+
NumpySorter,
83+
NumpyValueArrayLike,
84+
)
85+
7786
from pandas.core.arrays import DatetimeArray
7887

88+
7989
_shared_doc_kwargs = {
8090
"klass": "PeriodArray",
8191
}
@@ -644,12 +654,17 @@ def astype(self, dtype, copy: bool = True):
644654
return self.asfreq(dtype.freq)
645655
return super().astype(dtype, copy=copy)
646656

647-
def searchsorted(self, value, side="left", sorter=None) -> np.ndarray:
648-
value = self._validate_searchsorted_value(value).view("M8[ns]")
657+
def searchsorted(
658+
self,
659+
value: NumpyValueArrayLike | ExtensionArray,
660+
side: Literal["left", "right"] = "left",
661+
sorter: NumpySorter = None,
662+
) -> npt.NDArray[np.intp] | np.intp:
663+
npvalue = self._validate_searchsorted_value(value).view("M8[ns]")
649664

650665
# Cast to M8 to get datetime-like NaT placement
651666
m8arr = self._ndarray.view("M8[ns]")
652-
return m8arr.searchsorted(value, side=side, sorter=sorter)
667+
return m8arr.searchsorted(npvalue, side=side, sorter=sorter)
653668

654669
def fillna(self, value=None, method=None, limit=None) -> PeriodArray:
655670
if method is not None:

pandas/core/arrays/sparse/array.py

+15 -1
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import numbers
88
import operator
99
from typing import (
10+
TYPE_CHECKING,
1011
Any,
1112
Callable,
1213
Sequence,
@@ -25,9 +26,11 @@
2526
)
2627
from pandas._libs.tslibs import NaT
2728
from pandas._typing import (
29+
ArrayLike,
2830
Dtype,
2931
NpDtype,
3032
Scalar,
33+
npt,
3134
)
3235
from pandas.compat.numpy import function as nv
3336
from pandas.errors import PerformanceWarning
@@ -77,6 +80,11 @@
7780

7881
import pandas.io.formats.printing as printing
7982

83+
if TYPE_CHECKING:
84+
from typing import Literal
85+
86+
from pandas._typing import NumpySorter
87+
8088
# ----------------------------------------------------------------------------
8189
# Array
8290

@@ -992,7 +1000,13 @@ def _take_without_fill(self, indices) -> np.ndarray | SparseArray:
9921000

9931001
return taken
9941002

995-
def searchsorted(self, v, side="left", sorter=None):
1003+
def searchsorted(
1004+
self,
1005+
v: ArrayLike | object,
1006+
side: Literal["left", "right"] = "left",
1007+
sorter: NumpySorter = None,
1008+
) -> npt.NDArray[np.intp] | np.intp:
1009+
9961010
msg = "searchsorted requires high memory usage."
9971011
warnings.warn(msg, PerformanceWarning, stacklevel=2)
9981012
if not is_scalar(v):

pandas/core/base.py

+12 -1
Original file line number | Diff line number | Diff line change
@@ -66,8 +66,14 @@
6666

6767
if TYPE_CHECKING:
6868

69+
from pandas._typing import (
70+
NumpySorter,
71+
NumpyValueArrayLike,
72+
)
73+
6974
from pandas import Categorical
7075

76+
7177
_shared_docs: dict[str, str] = {}
7278
_indexops_doc_kwargs = {
7379
"klass": "IndexOpsMixin",
@@ -1222,7 +1228,12 @@ def factorize(self, sort: bool = False, na_sentinel: int | None = -1):
12221228
"""
12231229

12241230
@doc(_shared_docs["searchsorted"], klass="Index")
1225-
def searchsorted(self, value, side="left", sorter=None) -> npt.NDArray[np.intp]:
1231+
def searchsorted(
1232+
self,
1233+
value: NumpyValueArrayLike,
1234+
side: Literal["left", "right"] = "left",
1235+
sorter: NumpySorter = None,
1236+
) -> npt.NDArray[np.intp] | np.intp:
12261237
return algorithms.searchsorted(self._values, value, side=side, sorter=sorter)
12271238

12281239
def drop_duplicates(self, keep="first"):

pandas/core/computation/pytables.py

+2
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
Timedelta,
1212
Timestamp,
1313
)
14+
from pandas._typing import npt
1415
from pandas.compat.chainmap import DeepChainMap
1516

1617
from pandas.core.dtypes.common import is_list_like
@@ -223,6 +224,7 @@ def stringify(value):
223224
return TermValue(int(v), v, kind)
224225
elif meta == "category":
225226
metadata = extract_array(self.metadata, extract_numpy=True)
227+
result: npt.NDArray[np.intp] | np.intp | int
226228
if v not in metadata:
227229
result = -1
228230
else:

pandas/core/generic.py

+1 -2
Original file line number | Diff line number | Diff line change
@@ -8251,8 +8251,7 @@ def last(self: FrameOrSeries, offset) -> FrameOrSeries:
82518251

82528252
start_date = self.index[-1] - offset
82538253
start = self.index.searchsorted(start_date, side="right")
8254-
# error: Slice index must be an integer or None
8255-
return self.iloc[start:] # type: ignore[misc]
8254+
return self.iloc[start:]
82568255

82578256
@final
82588257
def rank(

pandas/core/indexes/base.py

+5 -3
Original file line number | Diff line number | Diff line change
@@ -3730,7 +3730,7 @@ def _get_fill_indexer_searchsorted(
37303730
"if index and target are monotonic"
37313731
)
37323732

3733-
side = "left" if method == "pad" else "right"
3733+
side: Literal["left", "right"] = "left" if method == "pad" else "right"
37343734

37353735
# find exact matches first (this simplifies the algorithm)
37363736
indexer = self.get_indexer(target)
@@ -6063,7 +6063,7 @@ def _maybe_cast_slice_bound(self, label, side: str_t, kind=no_default):
60636063

60646064
return label
60656065

6066-
def _searchsorted_monotonic(self, label, side: str_t = "left"):
6066+
def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"):
60676067
if self.is_monotonic_increasing:
60686068
return self.searchsorted(label, side=side)
60696069
elif self.is_monotonic_decreasing:
@@ -6077,7 +6077,9 @@ def _searchsorted_monotonic(self, label, side: str_t = "left"):
60776077

60786078
raise ValueError("index must be monotonic increasing or decreasing")
60796079

6080-
def get_slice_bound(self, label, side: str_t, kind=no_default) -> int:
6080+
def get_slice_bound(
6081+
self, label, side: Literal["left", "right"], kind=no_default
6082+
) -> int:
60816083
"""
60826084
Calculate slice bound that corresponds to given label.
60836085

pandas/core/indexes/datetimelike.py

+2 -4
Original file line number | Diff line number | Diff line change
@@ -674,8 +674,7 @@ def _fast_union(self: _T, other: _T, sort=None) -> _T:
674674
left, right = self, other
675675
left_start = left[0]
676676
loc = right.searchsorted(left_start, side="left")
677-
# error: Slice index must be an integer or None
678-
right_chunk = right._values[:loc] # type: ignore[misc]
677+
right_chunk = right._values[:loc]
679678
dates = concat_compat((left._values, right_chunk))
680679
# With sort being False, we can't infer that result.freq == self.freq
681680
# TODO: no tests rely on the _with_freq("infer"); needed?
@@ -691,8 +690,7 @@ def _fast_union(self: _T, other: _T, sort=None) -> _T:
691690
# concatenate
692691
if left_end < right_end:
693692
loc = right.searchsorted(left_end, side="right")
694-
# error: Slice index must be an integer or None
695-
right_chunk = right._values[loc:] # type: ignore[misc]
693+
right_chunk = right._values[loc:]
696694
dates = concat_compat([left._values, right_chunk])
697695
# The can_fast_union check ensures that the result.freq
698696
# should match self.freq

0 commit comments

Comments
 (0)