diff --git a/pandas/_typing.py b/pandas/_typing.py index ef9f38bbf5168..433f8645d35a8 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -69,6 +69,11 @@ from pandas.io.formats.format import EngFormatter from pandas.tseries.offsets import DateOffset + + # numpy compatible types + NumpyValueArrayLike = Union[npt._ScalarLike_co, npt.ArrayLike] + NumpySorter = Optional[npt._ArrayLikeInt_co] + else: npt: Any = None @@ -85,6 +90,7 @@ PandasScalar = Union["Period", "Timestamp", "Timedelta", "Interval"] Scalar = Union[PythonScalar, PandasScalar] + # timestamp and timedelta convertible types TimestampConvertibleTypes = Union[ diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 3ba18b525a1e8..029daa2a0893e 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -82,6 +82,11 @@ if TYPE_CHECKING: + from pandas._typing import ( + NumpySorter, + NumpyValueArrayLike, + ) + from pandas import ( Categorical, DataFrame, @@ -1517,7 +1522,12 @@ def take( # ------------ # -def searchsorted(arr, value, side="left", sorter=None) -> np.ndarray: +def searchsorted( + arr: ArrayLike, + value: NumpyValueArrayLike, + side: Literal["left", "right"] = "left", + sorter: NumpySorter = None, +) -> npt.NDArray[np.intp] | np.intp: """ Find indices where elements should be inserted to maintain order. @@ -1554,8 +1564,9 @@ def searchsorted(arr, value, side="left", sorter=None) -> np.ndarray: Returns ------- - array of ints - Array of insertion points with the same shape as `value`. + array of ints or int + If value is array-like, array of insertion points. + If value is scalar, a single integer. See Also -------- @@ -1583,9 +1594,10 @@ def searchsorted(arr, value, side="left", sorter=None) -> np.ndarray: dtype = value_arr.dtype if is_scalar(value): - value = dtype.type(value) + # We know that value is int + value = cast(int, dtype.type(value)) else: - value = pd_array(value, dtype=dtype) + value = pd_array(cast(ArrayLike, value), dtype=dtype) elif not ( is_object_dtype(arr) or is_numeric_dtype(arr) or is_categorical_dtype(arr) ): diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 0e8097cf1fc78..f13f1a418c2e9 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -2,6 +2,7 @@ from functools import wraps from typing import ( + TYPE_CHECKING, Any, Sequence, TypeVar, @@ -16,6 +17,7 @@ F, PositionalIndexer2D, Shape, + npt, type_t, ) from pandas.errors import AbstractMethodError @@ -45,6 +47,14 @@ "NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray" ) +if TYPE_CHECKING: + from typing import Literal + + from pandas._typing import ( + NumpySorter, + NumpyValueArrayLike, + ) + def ravel_compat(meth: F) -> F: """ @@ -157,12 +167,22 @@ def _concat_same_type( return to_concat[0]._from_backing_data(new_values) # type: ignore[arg-type] @doc(ExtensionArray.searchsorted) - def searchsorted(self, value, side="left", sorter=None): - value = self._validate_searchsorted_value(value) - return self._ndarray.searchsorted(value, side=side, sorter=sorter) - - def _validate_searchsorted_value(self, value): - return value + def searchsorted( + self, + value: NumpyValueArrayLike | ExtensionArray, + side: Literal["left", "right"] = "left", + sorter: NumpySorter = None, + ) -> npt.NDArray[np.intp] | np.intp: + npvalue = self._validate_searchsorted_value(value) + return self._ndarray.searchsorted(npvalue, side=side, sorter=sorter) + + def _validate_searchsorted_value( + self, value: NumpyValueArrayLike | ExtensionArray + ) -> NumpyValueArrayLike: + if isinstance(value, ExtensionArray): + return value.to_numpy() + else: + return value @doc(ExtensionArray.shift) def shift(self, periods=1, fill_value=None, axis=0): diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index b362769f50fa8..4cc0d4185b22c 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -29,6 +29,7 @@ FillnaOptions, PositionalIndexer, Shape, + npt, ) from pandas.compat import set_function_name from pandas.compat.numpy import function as nv @@ -81,6 +82,11 @@ def any(self, *, skipna: bool = True) -> bool: def all(self, *, skipna: bool = True) -> bool: pass + from pandas._typing import ( + NumpySorter, + NumpyValueArrayLike, + ) + _extension_array_shared_docs: dict[str, str] = {} @@ -807,7 +813,12 @@ def unique(self: ExtensionArrayT) -> ExtensionArrayT: uniques = unique(self.astype(object)) return self._from_sequence(uniques, dtype=self.dtype) - def searchsorted(self, value, side="left", sorter=None): + def searchsorted( + self, + value: NumpyValueArrayLike | ExtensionArray, + side: Literal["left", "right"] = "left", + sorter: NumpySorter = None, + ) -> npt.NDArray[np.intp] | np.intp: """ Find indices where elements should be inserted to maintain order. @@ -838,8 +849,9 @@ def searchsorted(self, value, side="left", sorter=None): Returns ------- - array of ints - Array of insertion points with the same shape as `value`. + array of ints or int + If value is array-like, array of insertion points. + If value is scalar, a single integer. See Also -------- @@ -1304,7 +1316,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs): # ------------------------------------------------------------------------ # Non-Optimized Default Methods - def delete(self: ExtensionArrayT, loc) -> ExtensionArrayT: + def delete(self: ExtensionArrayT, loc: PositionalIndexer) -> ExtensionArrayT: indexer = np.delete(np.arange(len(self)), loc) return self.take(indexer) diff --git a/pandas/core/arrays/period.py b/pandas/core/arrays/period.py index 471ee295ebd2f..488981bcc9687 100644 --- a/pandas/core/arrays/period.py +++ b/pandas/core/arrays/period.py @@ -41,6 +41,7 @@ AnyArrayLike, Dtype, NpDtype, + npt, ) from pandas.util._decorators import ( cache_readonly, @@ -71,11 +72,20 @@ import pandas.core.algorithms as algos from pandas.core.arrays import datetimelike as dtl +from pandas.core.arrays.base import ExtensionArray import pandas.core.common as com if TYPE_CHECKING: + from typing import Literal + + from pandas._typing import ( + NumpySorter, + NumpyValueArrayLike, + ) + from pandas.core.arrays import DatetimeArray + _shared_doc_kwargs = { "klass": "PeriodArray", } @@ -644,12 +654,17 @@ def astype(self, dtype, copy: bool = True): return self.asfreq(dtype.freq) return super().astype(dtype, copy=copy) - def searchsorted(self, value, side="left", sorter=None) -> np.ndarray: - value = self._validate_searchsorted_value(value).view("M8[ns]") + def searchsorted( + self, + value: NumpyValueArrayLike | ExtensionArray, + side: Literal["left", "right"] = "left", + sorter: NumpySorter = None, + ) -> npt.NDArray[np.intp] | np.intp: + npvalue = self._validate_searchsorted_value(value).view("M8[ns]") # Cast to M8 to get datetime-like NaT placement m8arr = self._ndarray.view("M8[ns]") - return m8arr.searchsorted(value, side=side, sorter=sorter) + return m8arr.searchsorted(npvalue, side=side, sorter=sorter) def fillna(self, value=None, method=None, limit=None) -> PeriodArray: if method is not None: diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index 68c9e42ef8e08..b1c794ac03b31 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -7,6 +7,7 @@ import numbers import operator from typing import ( + TYPE_CHECKING, Any, Callable, Sequence, @@ -25,9 +26,11 @@ ) from pandas._libs.tslibs import NaT from pandas._typing import ( + ArrayLike, Dtype, NpDtype, Scalar, + npt, ) from pandas.compat.numpy import function as nv from pandas.errors import PerformanceWarning @@ -77,6 +80,11 @@ import pandas.io.formats.printing as printing +if TYPE_CHECKING: + from typing import Literal + + from pandas._typing import NumpySorter + # ---------------------------------------------------------------------------- # Array @@ -992,7 +1000,13 @@ def _take_without_fill(self, indices) -> np.ndarray | SparseArray: return taken - def searchsorted(self, v, side="left", sorter=None): + def searchsorted( + self, + v: ArrayLike | object, + side: Literal["left", "right"] = "left", + sorter: NumpySorter = None, + ) -> npt.NDArray[np.intp] | np.intp: + msg = "searchsorted requires high memory usage." warnings.warn(msg, PerformanceWarning, stacklevel=2) if not is_scalar(v): diff --git a/pandas/core/base.py b/pandas/core/base.py index 57e015dc378c8..c7a707fd5cd6e 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -66,8 +66,14 @@ if TYPE_CHECKING: + from pandas._typing import ( + NumpySorter, + NumpyValueArrayLike, + ) + from pandas import Categorical + _shared_docs: dict[str, str] = {} _indexops_doc_kwargs = { "klass": "IndexOpsMixin", @@ -1222,7 +1228,12 @@ def factorize(self, sort: bool = False, na_sentinel: int | None = -1): """ @doc(_shared_docs["searchsorted"], klass="Index") - def searchsorted(self, value, side="left", sorter=None) -> npt.NDArray[np.intp]: + def searchsorted( + self, + value: NumpyValueArrayLike, + side: Literal["left", "right"] = "left", + sorter: NumpySorter = None, + ) -> npt.NDArray[np.intp] | np.intp: return algorithms.searchsorted(self._values, value, side=side, sorter=sorter) def drop_duplicates(self, keep="first"): diff --git a/pandas/core/computation/pytables.py b/pandas/core/computation/pytables.py index ad76a76a954b1..3e041c088f566 100644 --- a/pandas/core/computation/pytables.py +++ b/pandas/core/computation/pytables.py @@ -11,6 +11,7 @@ Timedelta, Timestamp, ) +from pandas._typing import npt from pandas.compat.chainmap import DeepChainMap from pandas.core.dtypes.common import is_list_like @@ -223,6 +224,7 @@ def stringify(value): return TermValue(int(v), v, kind) elif meta == "category": metadata = extract_array(self.metadata, extract_numpy=True) + result: npt.NDArray[np.intp] | np.intp | int if v not in metadata: result = -1 else: diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 1f51576cc6e90..48daf7c89fe64 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -8251,8 +8251,7 @@ def last(self: FrameOrSeries, offset) -> FrameOrSeries: start_date = self.index[-1] - offset start = self.index.searchsorted(start_date, side="right") - # error: Slice index must be an integer or None - return self.iloc[start:] # type: ignore[misc] + return self.iloc[start:] @final def rank( diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 87e19ce6ef670..645fab0d76a73 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -3730,7 +3730,7 @@ def _get_fill_indexer_searchsorted( "if index and target are monotonic" ) - side = "left" if method == "pad" else "right" + side: Literal["left", "right"] = "left" if method == "pad" else "right" # find exact matches first (this simplifies the algorithm) indexer = self.get_indexer(target) @@ -6063,7 +6063,7 @@ def _maybe_cast_slice_bound(self, label, side: str_t, kind=no_default): return label - def _searchsorted_monotonic(self, label, side: str_t = "left"): + def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"): if self.is_monotonic_increasing: return self.searchsorted(label, side=side) elif self.is_monotonic_decreasing: @@ -6077,7 +6077,9 @@ def _searchsorted_monotonic(self, label, side: str_t = "left"): raise ValueError("index must be monotonic increasing or decreasing") - def get_slice_bound(self, label, side: str_t, kind=no_default) -> int: + def get_slice_bound( + self, label, side: Literal["left", "right"], kind=no_default + ) -> int: """ Calculate slice bound that corresponds to given label. diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 5d778af954eef..d2f598261a776 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -674,8 +674,7 @@ def _fast_union(self: _T, other: _T, sort=None) -> _T: left, right = self, other left_start = left[0] loc = right.searchsorted(left_start, side="left") - # error: Slice index must be an integer or None - right_chunk = right._values[:loc] # type: ignore[misc] + right_chunk = right._values[:loc] dates = concat_compat((left._values, right_chunk)) # With sort being False, we can't infer that result.freq == self.freq # TODO: no tests rely on the _with_freq("infer"); needed? @@ -691,8 +690,7 @@ def _fast_union(self: _T, other: _T, sort=None) -> _T: # concatenate if left_end < right_end: loc = right.searchsorted(left_end, side="right") - # error: Slice index must be an integer or None - right_chunk = right._values[loc:] # type: ignore[misc] + right_chunk = right._values[loc:] dates = concat_compat([left._values, right_chunk]) # The can_fast_union check ensures that the result.freq # should match self.freq diff --git a/pandas/core/indexes/datetimes.py b/pandas/core/indexes/datetimes.py index 8a5811da4dd5a..fbbe6606ba522 100644 --- a/pandas/core/indexes/datetimes.py +++ b/pandas/core/indexes/datetimes.py @@ -11,6 +11,7 @@ from typing import ( TYPE_CHECKING, Hashable, + Literal, ) import warnings @@ -765,7 +766,9 @@ def check_str_or_none(point): return indexer @doc(Index.get_slice_bound) - def get_slice_bound(self, label, side: str, kind=lib.no_default) -> int: + def get_slice_bound( + self, label, side: Literal["left", "right"], kind=lib.no_default + ) -> int: # GH#42855 handle date here instead of _maybe_cast_slice_bound if isinstance(label, date) and not isinstance(label, datetime): label = Timestamp(label).to_pydatetime() diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index 25f2378511dc4..920af5a13baba 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -4,8 +4,10 @@ from __future__ import annotations from typing import ( + TYPE_CHECKING, Hashable, TypeVar, + overload, ) import numpy as np @@ -39,10 +41,19 @@ TimedeltaArray, ) from pandas.core.arrays._mixins import NDArrayBackedExtensionArray +from pandas.core.arrays.base import ExtensionArray from pandas.core.indexers import deprecate_ndim_indexing from pandas.core.indexes.base import Index from pandas.core.ops import get_op_result_name +if TYPE_CHECKING: + from typing import Literal + + from pandas._typing import ( + NumpySorter, + NumpyValueArrayLike, + ) + _T = TypeVar("_T", bound="NDArrayBackedExtensionIndex") @@ -318,7 +329,37 @@ def __getitem__(self, key): deprecate_ndim_indexing(result) return result - def searchsorted(self, value, side="left", sorter=None) -> np.ndarray: + # This overload is needed so that the call to searchsorted in + # pandas.core.resample.TimeGrouper._get_period_bins picks the correct result + + @overload + # The following ignore is also present in numpy/__init__.pyi + # Possibly a mypy bug?? + # error: Overloaded function signatures 1 and 2 overlap with incompatible + # return types [misc] + def searchsorted( # type: ignore[misc] + self, + value: npt._ScalarLike_co, + side: Literal["left", "right"] = "left", + sorter: NumpySorter = None, + ) -> np.intp: + ... + + @overload + def searchsorted( + self, + value: npt.ArrayLike | ExtensionArray, + side: Literal["left", "right"] = "left", + sorter: NumpySorter = None, + ) -> npt.NDArray[np.intp]: + ... + + def searchsorted( + self, + value: NumpyValueArrayLike | ExtensionArray, + side: Literal["left", "right"] = "left", + sorter: NumpySorter = None, + ) -> npt.NDArray[np.intp] | np.intp: # overriding IndexOpsMixin improves performance GH#38083 return self._data.searchsorted(value, side=side, sorter=sorter) diff --git a/pandas/core/series.py b/pandas/core/series.py index a5ec4125f54a4..6f964ab09e978 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -144,6 +144,11 @@ if TYPE_CHECKING: + from pandas._typing import ( + NumpySorter, + NumpyValueArrayLike, + ) + from pandas.core.frame import DataFrame from pandas.core.groupby.generic import SeriesGroupBy from pandas.core.resample import Resampler @@ -2778,7 +2783,12 @@ def __rmatmul__(self, other): return self.dot(np.transpose(other)) @doc(base.IndexOpsMixin.searchsorted, klass="Series") - def searchsorted(self, value, side="left", sorter=None) -> np.ndarray: + def searchsorted( + self, + value: NumpyValueArrayLike, + side: Literal["left", "right"] = "left", + sorter: NumpySorter = None, + ) -> npt.NDArray[np.intp] | np.intp: return algorithms.searchsorted(self._values, value, side=side, sorter=sorter) # -------------------------------------------------------------------