
TYP: ExtensionArray delete() and searchsorted() #41513

Merged (20 commits, Sep 6, 2021)
6 changes: 6 additions & 0 deletions pandas/_typing.py
@@ -69,6 +69,11 @@

from pandas.io.formats.format import EngFormatter
from pandas.tseries.offsets import DateOffset

# numpy compatible types
NumpyValueArrayLike = Union[npt._ScalarLike_co, npt.ArrayLike]
NumpySorter = Optional[npt._ArrayLikeInt_co]

else:
npt: Any = None

@@ -85,6 +90,7 @@
PandasScalar = Union["Period", "Timestamp", "Timedelta", "Interval"]
Scalar = Union[PythonScalar, PandasScalar]


# timestamp and timedelta convertible types

TimestampConvertibleTypes = Union[
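For orientation, a minimal standalone sketch (not part of the diff) of the pattern this file relies on: the numpy typing aliases exist only for the type checker, so the derived aliases live under TYPE_CHECKING and consumers depend on lazy annotation evaluation. The function name locate is made up for the example.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np

if TYPE_CHECKING:
    import numpy.typing as npt

    # Mirrors the aliases added above: a scalar or array-like search value,
    # and an optional integer array-like for numpy's `sorter` argument.
    NumpyValueArrayLike = Union[npt._ScalarLike_co, npt.ArrayLike]
    NumpySorter = Optional[npt._ArrayLikeInt_co]
else:
    npt: Any = None


def locate(arr: np.ndarray, value: NumpyValueArrayLike, sorter: NumpySorter = None):
    # With PEP 563 the annotations stay strings at runtime, so the
    # TYPE_CHECKING-only aliases are never evaluated here.
    return arr.searchsorted(value, sorter=sorter)


print(locate(np.array([1, 3, 5]), 4))  # 2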
22 changes: 17 additions & 5 deletions pandas/core/algorithms.py
@@ -82,6 +82,11 @@

if TYPE_CHECKING:

from pandas._typing import (
NumpySorter,
NumpyValueArrayLike,
)

from pandas import (
Categorical,
DataFrame,
@@ -1517,7 +1522,12 @@ def take(
# ------------ #


def searchsorted(arr, value, side="left", sorter=None) -> np.ndarray:
def searchsorted(
arr: ArrayLike,
value: NumpyValueArrayLike,
side: Literal["left", "right"] = "left",
Review comment (Member): Is this npt._SortSide? We probably don't want to use private aliases in the codebase, but maybe it is OK to do so in pandas._typing. Thoughts?

Reply (Contributor, author): Can't use _SortSide from numpy, as it is only defined in their .pyi files.

sorter: NumpySorter = None,
) -> npt.NDArray[np.intp] | np.intp:
"""
Find indices where elements should be inserted to maintain order.

@@ -1554,8 +1564,9 @@ def searchsorted(arr, value, side="left", sorter=None) -> np.ndarray:

Returns
-------
array of ints
Array of insertion points with the same shape as `value`.
array of ints or int
If value is array-like, array of insertion points.
If value is scalar, a single integer.

See Also
--------
@@ -1583,9 +1594,10 @@ def searchsorted(arr, value, side="left", sorter=None) -> np.ndarray:
dtype = value_arr.dtype

if is_scalar(value):
value = dtype.type(value)
# We know that value is int
value = cast(int, dtype.type(value))
else:
value = pd_array(value, dtype=dtype)
value = pd_array(cast(ArrayLike, value), dtype=dtype)
elif not (
is_object_dtype(arr) or is_numeric_dtype(arr) or is_categorical_dtype(arr)
):
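A quick illustration, via the public Series API that funnels into this helper, of why the return annotation is now a union: a scalar lookup yields a single np.intp, while an array-like yields an array of insertion points (output shown for a 64-bit build).

import pandas as pd

ser = pd.Series([1, 2, 3, 5, 8])

scalar_result = ser.searchsorted(4)          # insertion point for one value
array_result = ser.searchsorted([0, 4, 9])   # one insertion point per value

print(scalar_result, type(scalar_result))    # 3 <class 'numpy.int64'> (np.intp)
print(array_result)                          # [0 3 5]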
32 changes: 26 additions & 6 deletions pandas/core/arrays/_mixins.py
@@ -2,6 +2,7 @@

from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Sequence,
TypeVar,
@@ -16,6 +17,7 @@
F,
PositionalIndexer2D,
Shape,
npt,
type_t,
)
from pandas.errors import AbstractMethodError
@@ -45,6 +47,14 @@
"NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
)

if TYPE_CHECKING:
from typing import Literal

from pandas._typing import (
NumpySorter,
NumpyValueArrayLike,
)


def ravel_compat(meth: F) -> F:
"""
@@ -157,12 +167,22 @@ def _concat_same_type(
return to_concat[0]._from_backing_data(new_values) # type: ignore[arg-type]

@doc(ExtensionArray.searchsorted)
def searchsorted(self, value, side="left", sorter=None):
value = self._validate_searchsorted_value(value)
return self._ndarray.searchsorted(value, side=side, sorter=sorter)

def _validate_searchsorted_value(self, value):
return value
def searchsorted(
self,
value: NumpyValueArrayLike | ExtensionArray,
side: Literal["left", "right"] = "left",
sorter: NumpySorter = None,
) -> npt.NDArray[np.intp] | np.intp:
npvalue = self._validate_searchsorted_value(value)
return self._ndarray.searchsorted(npvalue, side=side, sorter=sorter)

def _validate_searchsorted_value(
self, value: NumpyValueArrayLike | ExtensionArray
) -> NumpyValueArrayLike:
if isinstance(value, ExtensionArray):
return value.to_numpy()
else:
return value

@doc(ExtensionArray.shift)
def shift(self, periods=1, fill_value=None, axis=0):
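Binding the validated value to a new name (npvalue) keeps the narrowed type separate from the parameter, which is what lets mypy accept the ndarray call. A toy stand-in for the same pattern; the class below is hypothetical, not the pandas implementation.

import numpy as np


class BackedByNdarray:
    # Toy stand-in for an NDArrayBackedExtensionArray: data lives in _ndarray.

    def __init__(self, values: np.ndarray) -> None:
        self._ndarray = values

    def _validate_searchsorted_value(self, value):
        # Anything with to_numpy() (e.g. another extension array) is converted
        # up front, so the ndarray method below only ever sees numpy inputs.
        if hasattr(value, "to_numpy"):
            return value.to_numpy()
        return value

    def searchsorted(self, value, side="left", sorter=None):
        npvalue = self._validate_searchsorted_value(value)
        return self._ndarray.searchsorted(npvalue, side=side, sorter=sorter)


arr = BackedByNdarray(np.array([10, 20, 30, 40]))
print(arr.searchsorted(25))                 # 2
print(arr.searchsorted(np.array([5, 35])))  # [0 3]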
20 changes: 16 additions & 4 deletions pandas/core/arrays/base.py
@@ -29,6 +29,7 @@
FillnaOptions,
PositionalIndexer,
Shape,
npt,
)
from pandas.compat import set_function_name
from pandas.compat.numpy import function as nv
@@ -81,6 +82,11 @@ def any(self, *, skipna: bool = True) -> bool:
def all(self, *, skipna: bool = True) -> bool:
pass

from pandas._typing import (
NumpySorter,
NumpyValueArrayLike,
)


_extension_array_shared_docs: dict[str, str] = {}

@@ -807,7 +813,12 @@ def unique(self: ExtensionArrayT) -> ExtensionArrayT:
uniques = unique(self.astype(object))
return self._from_sequence(uniques, dtype=self.dtype)

def searchsorted(self, value, side="left", sorter=None):
def searchsorted(
self,
value: NumpyValueArrayLike | ExtensionArray,
side: Literal["left", "right"] = "left",
sorter: NumpySorter = None,
) -> npt.NDArray[np.intp] | np.intp:
"""
Find indices where elements should be inserted to maintain order.

@@ -838,8 +849,9 @@ def searchsorted(self, value, side="left", sorter=None):

Returns
-------
array of ints
Array of insertion points with the same shape as `value`.
array of ints or int
If value is array-like, array of insertion points.
If value is scalar, a single integer.

See Also
--------
@@ -1304,7 +1316,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
# ------------------------------------------------------------------------
# Non-Optimized Default Methods

def delete(self: ExtensionArrayT, loc) -> ExtensionArrayT:
def delete(self: ExtensionArrayT, loc: PositionalIndexer) -> ExtensionArrayT:
Review comment (Member): This is in the 'public' base EA class; it was added in #39405. @jbrockmendel, is this method public? Is it part of the EA interface? Does it need a docstring?

Reply (Member): Good catch, I'll make a note to add a docstring.

indexer = np.delete(np.arange(len(self)), loc)
return self.take(indexer)

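The default delete() above needs no dtype-specific logic: it builds the surviving positions with numpy and reuses take(). The same idea in plain numpy, for illustration only:

import numpy as np

values = np.array(["a", "b", "c", "d", "e"], dtype=object)
loc = [1, 3]  # positions to drop; a scalar or slice works the same way

# np.delete on an arange yields the positions that survive ...
indexer = np.delete(np.arange(len(values)), loc)
print(indexer)          # [0 2 4]

# ... and taking them is equivalent to deleting `loc` from the original.
print(values[indexer])  # ['a' 'c' 'e']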
21 changes: 18 additions & 3 deletions pandas/core/arrays/period.py
@@ -41,6 +41,7 @@
AnyArrayLike,
Dtype,
NpDtype,
npt,
)
from pandas.util._decorators import (
cache_readonly,
@@ -71,11 +72,20 @@

import pandas.core.algorithms as algos
from pandas.core.arrays import datetimelike as dtl
from pandas.core.arrays.base import ExtensionArray
import pandas.core.common as com

if TYPE_CHECKING:
from typing import Literal

from pandas._typing import (
NumpySorter,
NumpyValueArrayLike,
)

from pandas.core.arrays import DatetimeArray


_shared_doc_kwargs = {
"klass": "PeriodArray",
}
@@ -644,12 +654,17 @@ def astype(self, dtype, copy: bool = True):
return self.asfreq(dtype.freq)
return super().astype(dtype, copy=copy)

def searchsorted(self, value, side="left", sorter=None) -> np.ndarray:
value = self._validate_searchsorted_value(value).view("M8[ns]")
def searchsorted(
self,
value: NumpyValueArrayLike | ExtensionArray,
side: Literal["left", "right"] = "left",
sorter: NumpySorter = None,
) -> npt.NDArray[np.intp] | np.intp:
npvalue = self._validate_searchsorted_value(value).view("M8[ns]")

# Cast to M8 to get datetime-like NaT placement
m8arr = self._ndarray.view("M8[ns]")
return m8arr.searchsorted(value, side=side, sorter=sorter)
return m8arr.searchsorted(npvalue, side=side, sorter=sorter)

def fillna(self, value=None, method=None, limit=None) -> PeriodArray:
if method is not None:
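PeriodArray keeps int64 ordinals in _ndarray, so both the backing data and the validated value are reinterpreted as datetime64[ns] before searching; per the inline comment, the cast exists to get datetime-like NaT placement. A reduced numpy sketch of the zero-copy view trick (the ordinals here are arbitrary, not real period ordinals):

import numpy as np

ordinals = np.arange(5, dtype="i8")   # stand-in for the int64 backing array
m8arr = ordinals.view("M8[ns]")       # same buffer, reinterpreted as datetime64[ns]

target = np.datetime64(3, "ns")       # the search value must be in M8 terms as well
print(m8arr.searchsorted(target, side="left"))  # 3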
16 changes: 15 additions & 1 deletion pandas/core/arrays/sparse/array.py
@@ -7,6 +7,7 @@
import numbers
import operator
from typing import (
TYPE_CHECKING,
Any,
Callable,
Sequence,
@@ -25,9 +26,11 @@
)
from pandas._libs.tslibs import NaT
from pandas._typing import (
ArrayLike,
Dtype,
NpDtype,
Scalar,
npt,
)
from pandas.compat.numpy import function as nv
from pandas.errors import PerformanceWarning
@@ -77,6 +80,11 @@

import pandas.io.formats.printing as printing

if TYPE_CHECKING:
from typing import Literal

from pandas._typing import NumpySorter

# ----------------------------------------------------------------------------
# Array

@@ -992,7 +1000,13 @@ def _take_without_fill(self, indices) -> np.ndarray | SparseArray:

return taken

def searchsorted(self, v, side="left", sorter=None):
def searchsorted(
self,
v: ArrayLike | object,
side: Literal["left", "right"] = "left",
sorter: NumpySorter = None,
) -> npt.NDArray[np.intp] | np.intp:

msg = "searchsorted requires high memory usage."
warnings.warn(msg, PerformanceWarning, stacklevel=2)
if not is_scalar(v):
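Because the sparse values have to be densified before numpy can binary-search them, the method warns unconditionally about memory usage. Callers that accept the cost can silence it; a small usage sketch, assuming the installed pandas still emits the warning:

import warnings

import pandas as pd
from pandas.errors import PerformanceWarning

sp = pd.arrays.SparseArray([0, 0, 0, 1, 2, 3])

with warnings.catch_warnings():
    warnings.simplefilter("ignore", PerformanceWarning)
    print(sp.searchsorted(2))  # 4; the array is densified under the hood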
13 changes: 12 additions & 1 deletion pandas/core/base.py
@@ -66,8 +66,14 @@

if TYPE_CHECKING:

from pandas._typing import (
NumpySorter,
NumpyValueArrayLike,
)

from pandas import Categorical


_shared_docs: dict[str, str] = {}
_indexops_doc_kwargs = {
"klass": "IndexOpsMixin",
@@ -1222,7 +1228,12 @@ def factorize(self, sort: bool = False, na_sentinel: int | None = -1):
"""

@doc(_shared_docs["searchsorted"], klass="Index")
def searchsorted(self, value, side="left", sorter=None) -> npt.NDArray[np.intp]:
def searchsorted(
self,
value: NumpyValueArrayLike,
side: Literal["left", "right"] = "left",
sorter: NumpySorter = None,
) -> npt.NDArray[np.intp] | np.intp:
return algorithms.searchsorted(self._values, value, side=side, sorter=sorter)

def drop_duplicates(self, keep="first"):
2 changes: 2 additions & 0 deletions pandas/core/computation/pytables.py
@@ -11,6 +11,7 @@
Timedelta,
Timestamp,
)
from pandas._typing import npt
from pandas.compat.chainmap import DeepChainMap

from pandas.core.dtypes.common import is_list_like
@@ -223,6 +224,7 @@ def stringify(value):
return TermValue(int(v), v, kind)
elif meta == "category":
metadata = extract_array(self.metadata, extract_numpy=True)
result: npt.NDArray[np.intp] | np.intp | int
if v not in metadata:
result = -1
else:
3 changes: 1 addition & 2 deletions pandas/core/generic.py
@@ -8251,8 +8251,7 @@ def last(self: FrameOrSeries, offset) -> FrameOrSeries:

start_date = self.index[-1] - offset
start = self.index.searchsorted(start_date, side="right")
# error: Slice index must be an integer or None
return self.iloc[start:] # type: ignore[misc]
return self.iloc[start:]

@final
def rank(
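The ignore could be dropped because a scalar searchsorted lookup is now typed as np.intp, and numpy integer scalars implement __index__, so they are valid slice bounds. A quick check:

import numpy as np

data = list(range(10))
start = np.intp(4)     # the kind of value a scalar searchsorted lookup returns
print(data[start:])    # [4, 5, 6, 7, 8, 9]; numpy ints slice like Python ints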
8 changes: 5 additions & 3 deletions pandas/core/indexes/base.py
@@ -3730,7 +3730,7 @@ def _get_fill_indexer_searchsorted(
"if index and target are monotonic"
)

side = "left" if method == "pad" else "right"
side: Literal["left", "right"] = "left" if method == "pad" else "right"

# find exact matches first (this simplifies the algorithm)
indexer = self.get_indexer(target)
@@ -6063,7 +6063,7 @@ def _maybe_cast_slice_bound(self, label, side: str_t, kind=no_default):

return label

def _searchsorted_monotonic(self, label, side: str_t = "left"):
def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"):
if self.is_monotonic_increasing:
return self.searchsorted(label, side=side)
elif self.is_monotonic_decreasing:
@@ -6077,7 +6077,9 @@ def _searchsorted_monotonic(self, label, side: str_t = "left"):

raise ValueError("index must be monotonic increasing or decreasing")

def get_slice_bound(self, label, side: str_t, kind=no_default) -> int:
def get_slice_bound(
self, label, side: Literal["left", "right"], kind=no_default
) -> int:
"""
Calculate slice bound that corresponds to given label.

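The explicit Literal annotation on the local variable stops mypy from widening the conditional expression to str, which would no longer match the side parameters narrowed elsewhere in this PR. A minimal sketch of the pattern, with made-up names:

from typing import Literal

SortSide = Literal["left", "right"]


def resolve_side(method: str) -> SortSide:
    # Without the annotation, mypy infers str for the conditional expression
    # and rejects using it where Literal["left", "right"] is required.
    side: SortSide = "left" if method == "pad" else "right"
    return side


print(resolve_side("pad"))       # left
print(resolve_side("backfill"))  # right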
6 changes: 2 additions & 4 deletions pandas/core/indexes/datetimelike.py
@@ -674,8 +674,7 @@ def _fast_union(self: _T, other: _T, sort=None) -> _T:
left, right = self, other
left_start = left[0]
loc = right.searchsorted(left_start, side="left")
# error: Slice index must be an integer or None
right_chunk = right._values[:loc] # type: ignore[misc]
right_chunk = right._values[:loc]
dates = concat_compat((left._values, right_chunk))
# With sort being False, we can't infer that result.freq == self.freq
# TODO: no tests rely on the _with_freq("infer"); needed?
@@ -691,8 +690,7 @@ def _fast_union(self: _T, other: _T, sort=None) -> _T:
# concatenate
if left_end < right_end:
loc = right.searchsorted(left_end, side="right")
# error: Slice index must be an integer or None
right_chunk = right._values[loc:] # type: ignore[misc]
right_chunk = right._values[loc:]
dates = concat_compat([left._values, right_chunk])
# The can_fast_union check ensures that the result.freq
# should match self.freq