-
-
Notifications
You must be signed in to change notification settings - Fork 18.6k
ENH: Implemented MultiIndex.searchsorted method (GH14833) #61435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 14 commits
cffb863
1ba7ff8
ac70f3e
275b0e2
0e0b9b5
4747609
9ac62ab
e2c2c5e
e88da57
94f7c44
1f4a1c9
ffd99d8
5e2caa4
6b0d0ab
73e308b
1342657
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
Any, | ||
Literal, | ||
cast, | ||
overload, | ||
) | ||
import warnings | ||
|
||
|
@@ -44,6 +45,15 @@ | |
Shape, | ||
npt, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from pandas._typing import ( | ||
NumpySorter, | ||
NumpyValueArrayLike, | ||
ScalarLike_co, | ||
) | ||
|
||
|
||
from pandas.compat.numpy import function as nv | ||
from pandas.errors import ( | ||
InvalidIndexError, | ||
|
@@ -3778,6 +3788,99 @@ def _reorder_indexer( | |
ind = np.lexsort(keys) | ||
return indexer[ind] | ||
|
||
@overload
def searchsorted(  # type: ignore[overload-overlap]
    self,
    value: ScalarLike_co,
    side: Literal["left", "right"] = ...,
    sorter: NumpySorter = ...,
) -> np.intp: ...


@overload
def searchsorted(
    self,
    value: npt.ArrayLike | ExtensionArray,
    side: Literal["left", "right"] = ...,
    sorter: NumpySorter = ...,
) -> npt.NDArray[np.intp]: ...


def searchsorted(
    self,
    value: NumpyValueArrayLike | ExtensionArray,
    side: Literal["left", "right"] = "left",
    sorter: npt.NDArray[np.intp] | None = None,
) -> npt.NDArray[np.intp] | np.intp:
    """
    Find the indices where elements should be inserted to maintain order.

    The MultiIndex (after applying ``sorter``, if given) is assumed to be
    sorted in ascending lexicographic order; the result is unspecified
    otherwise.

    Parameters
    ----------
    value : tuple or list of tuples
        The key(s) to search for. A single tuple is treated as one key; a
        list is treated as a sequence of keys.
    side : {'left', 'right'}, default 'left'
        If 'left', the index of the first suitable location found is given.
        If 'right', return the last such index. The two only differ when
        `value` is already present in the MultiIndex.
    sorter : 1-D array-like, optional
        Optional array of integer indices that sort the MultiIndex into
        ascending order. Returned positions refer to the sorted order.

    Returns
    -------
    np.intp or npt.NDArray[np.intp]
        Insertion position(s). A scalar is returned when exactly one key
        was supplied (a tuple or a one-element list); otherwise an array
        with one position per key.

    Raises
    ------
    TypeError
        If `value` is not a tuple or a list, or if a key cannot be ordered
        against the entries of the MultiIndex (for example a missing value
        compared with a string).
    ValueError
        If `value` is empty or `side` is not 'left' or 'right'.

    See Also
    --------
    Index.searchsorted : Search for insertion point in a 1-D index.

    Examples
    --------
    >>> mi = pd.MultiIndex.from_arrays([["a", "b", "c"], ["x", "y", "z"]])
    >>> mi.searchsorted(("b", "y"))
    np.int64(1)
    >>> mi.searchsorted(("b", "y"), side="right")
    np.int64(2)
    >>> mi.searchsorted([("a", "a"), ("d", "a")])
    array([0, 3])
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import section.
    from bisect import (
        bisect_left,
        bisect_right,
    )

    # Check the container type explicitly instead of relying on truthiness:
    # ``not value`` raises an unrelated "ambiguous truth value" error for
    # multi-element arrays.
    if not isinstance(value, (tuple, list)):
        raise TypeError("value must be a tuple or list")
    if len(value) == 0:
        raise ValueError("searchsorted requires a non-empty value")
    if side not in ("left", "right"):
        raise ValueError("side must be either 'left' or 'right'")

    keys = [value] if isinstance(value, tuple) else value

    # Materialise the index as Python tuples once and bisect over them.
    # Review feedback on the earlier approach (casting ``self.values`` to a
    # NumPy structured dtype built from the levels) flagged that it would
    # probably not preserve pandas nullable extension types. Comparing the
    # tuples directly avoids structured dtypes altogether, and it also
    # handles duplicate entries and ``sorter`` uniformly, which the
    # ``get_indexer`` shortcut did not.
    entries = self.tolist()
    if sorter is not None:
        entries = [entries[pos] for pos in sorter]

    locate = bisect_left if side == "left" else bisect_right
    result = [np.intp(locate(entries, key)) for key in keys]

    # Existing behaviour: a single result is collapsed to a scalar.
    if len(result) == 1:
        return result[0]
    return np.array(result, dtype=np.intp)
|
||
def truncate(self, before=None, after=None) -> MultiIndex: | ||
""" | ||
Slice index between two labels / tuples, return new MultiIndex. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These types do not match what the type annotation suggests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the comment, you're absolutely right.
The current runtime check assumes value must be a tuple or a list of tuples, but the type annotation allows array-like or ExtensionArray inputs, which may not satisfy that condition.
I will update this.