Skip to content

Commit 7617ed1

Browse files
TomAugspurger authored and jreback committed
ENH: ExtensionArray.searchsorted (#24350)
1 parent c1af4f5 commit 7617ed1

File tree

8 files changed

+100
-1
lines changed

8 files changed

+100
-1
lines changed

doc/source/whatsnew/v0.24.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your
997997
- :meth:`~pandas.api.types.ExtensionArray.repeat` has been added (:issue:`24349`)
998998
- ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instantiate a registered ``DecimalDtype``; furthermore
999999
the ``ExtensionDtype`` has gained the method ``construct_array_type`` (:issue:`21185`)
1000+
- :meth:`~pandas.api.types.ExtensionArray.searchsorted` has been added (:issue:`24350`)
10001001
- An ``ExtensionArray`` with a boolean dtype now works correctly as a boolean indexer. :meth:`pandas.api.types.is_bool_dtype` now properly considers them boolean (:issue:`22326`)
10011002
- Added ``ExtensionDtype._is_numeric`` for controlling whether an extension dtype is considered numeric (:issue:`22290`).
10021003
- The ``ExtensionArray`` constructor, ``_from_sequence``, now takes the keyword arg ``copy=False`` (:issue:`21185`)

pandas/core/arrays/base.py

+49
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class ExtensionArray(object):
6868
* unique
6969
* factorize / _values_for_factorize
7070
* argsort / _values_for_argsort
71+
* searchsorted
7172
7273
The remaining methods implemented on this class should be performant,
7374
as they only compose abstract methods. Still, a more efficient
@@ -518,6 +519,54 @@ def unique(self):
518519
uniques = unique(self.astype(object))
519520
return self._from_sequence(uniques, dtype=self.dtype)
520521

522+
def searchsorted(self, value, side="left", sorter=None):
    """
    Find indices where elements should be inserted to maintain order.

    .. versionadded:: 0.24.0

    Find the indices into a sorted array `self` such that, if the
    corresponding elements in `value` were inserted before the indices,
    the order of `self` would be preserved.

    Assuming that `self` is sorted:

    ======  ================================
    `side`  returned index `i` satisfies
    ======  ================================
    left    ``self[i-1] < value <= self[i]``
    right   ``self[i-1] <= value < self[i]``
    ======  ================================

    Parameters
    ----------
    value : array_like
        Values to insert into `self`.
    side : {'left', 'right'}, optional
        If 'left', the index of the first suitable location found is given.
        If 'right', return the last such index.  If there is no suitable
        index, return either 0 or N (where N is the length of `self`).
    sorter : 1-D array_like, optional
        Optional array of integer indices that sort `self` into ascending
        order. They are typically the result of argsort.

    Returns
    -------
    indices : array of ints
        Array of insertion points with the same shape as `value`.

    See Also
    --------
    numpy.searchsorted : Similar method from NumPy.
    """
    # Note: the base tests provided by pandas only test the basics.
    # We do not test
    # 1. Values outside the range of the `data_for_sorting` fixture
    # 2. Values between the values in the `data_for_sorting` fixture
    # 3. Missing values.

    # Generic fallback: cast to object dtype and delegate to the
    # ``searchsorted`` of whatever ``astype(object)`` returns (presumably
    # an object-dtype ndarray).
    arr = self.astype(object)
    return arr.searchsorted(value, side=side, sorter=sorter)
569+
521570
def _values_for_factorize(self):
522571
# type: () -> Tuple[ndarray, Any]
523572
"""

pandas/core/arrays/sparse.py

+10
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,16 @@ def _take_without_fill(self, indices):
11691169

11701170
return taken
11711171

1172+
def searchsorted(self, v, side="left", sorter=None):
    """
    Find insertion indices for `v` by searching a NumPy copy of the array.

    A ``PerformanceWarning`` is emitted on every call, because the whole
    array is converted with ``np.asarray`` (using ``self.dtype.subtype``)
    before the search is performed.

    Parameters
    ----------
    v : scalar or array_like
        Values to insert into `self`.
    side : {'left', 'right'}, optional
        Forwarded to the converted array's ``searchsorted``.
    sorter : 1-D array_like, optional
        Forwarded to the converted array's ``searchsorted``.

    Returns
    -------
    The result of ``searchsorted`` on the converted array.
    """
    msg = "searchsorted requires high memory usage."
    warnings.warn(msg, PerformanceWarning, stacklevel=2)
    # ``np.asarray`` accepts scalars and sequences alike, so a single
    # unconditional conversion replaces the former ``is_scalar`` branch
    # that was immediately followed by the same call.
    v = np.asarray(v)
    values = np.asarray(self, dtype=self.dtype.subtype)
    return values.searchsorted(v, side=side, sorter=sorter)
1181+
11721182
def copy(self, deep=False):
11731183
if deep:
11741184
values = self.sp_values.copy()

pandas/core/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1464,7 +1464,7 @@ def factorize(self, sort=False, na_sentinel=-1):
14641464
@Appender(_shared_docs['searchsorted'])
14651465
def searchsorted(self, value, side='left', sorter=None):
14661466
# needs coercion on the key (DatetimeIndex does already)
1467-
return self.values.searchsorted(value, side=side, sorter=sorter)
1467+
return self._values.searchsorted(value, side=side, sorter=sorter)
14681468

14691469
def drop_duplicates(self, keep='first', inplace=False):
14701470
inplace = validate_bool_kwarg(inplace, 'inplace')

pandas/tests/extension/base/methods.py

+25
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,31 @@ def test_hash_pandas_object_works(self, data, as_frame):
242242
b = pd.util.hash_pandas_object(data)
243243
self.assert_equal(a, b)
244244

245+
@pytest.mark.parametrize("as_series", [True, False])
def test_searchsorted(self, data_for_sorting, as_series):
    """
    Check ``searchsorted`` on a three-element array or Series.

    ``data_for_sorting`` is unpacked as ``b, c, a`` and rebuilt as
    ``[a, b, c]``, which the assertions below treat as ascending order.
    Scalar lookups (both sides), an array lookup and the ``sorter``
    argument are exercised.
    """
    b, c, a = data_for_sorting
    arr = type(data_for_sorting)._from_sequence([a, b, c])

    if as_series:
        arr = pd.Series(arr)
    assert arr.searchsorted(a) == 0
    assert arr.searchsorted(a, side="right") == 1

    assert arr.searchsorted(b) == 1
    assert arr.searchsorted(b, side="right") == 2

    assert arr.searchsorted(c) == 2
    assert arr.searchsorted(c, side="right") == 3

    result = arr.searchsorted(arr.take([0, 2]))
    expected = np.array([0, 2], dtype=np.intp)

    tm.assert_numpy_array_equal(result, expected)

    # sorter: the fixture is ordered [b, c, a], so the permutation that
    # puts it in ascending order is [2, 0, 1] (a, then b, then c).  The
    # previous value [1, 2, 0] yielded [c, a, b], which is not sorted and
    # only produced the expected answer by accident.
    sorter = np.array([2, 0, 1])
    assert data_for_sorting.searchsorted(a, sorter=sorter) == 0
245270
@pytest.mark.parametrize("as_frame", [True, False])
246271
def test_where_series(self, data, na_value, as_frame):
247272
assert data[0] != data[1]

pandas/tests/extension/json/test_json.py

+4
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,10 @@ def test_where_series(self, data, na_value):
232232
# with shapes (4,) (4,) (0,)
233233
super().test_where_series(data, na_value)
234234

235+
@pytest.mark.skip(reason="Can't compare dicts.")
@pytest.mark.parametrize("as_series", [True, False])
def test_searchsorted(self, data_for_sorting, as_series):
    """
    Skipped override of the shared searchsorted test.

    The signature and the forwarded arguments mirror the base test
    (which requires ``as_series``), so the call is well-formed should
    the skip marker ever be removed.
    """
    super(TestMethods, self).test_searchsorted(data_for_sorting, as_series)
238+
235239

236240
class TestCasting(BaseJSON, base.BaseCastingTests):
237241
@pytest.mark.skip(reason="failing on np.array(self, dtype=str)")

pandas/tests/extension/test_categorical.py

+4
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ def test_combine_add(self, data_repeated):
189189
def test_fillna_length_mismatch(self, data_missing):
190190
super().test_fillna_length_mismatch(data_missing)
191191

192+
def test_searchsorted(self, data_for_sorting):
    """
    Skip the searchsorted test when the fixture data is not ordered.

    NOTE(review): the base-class test is not invoked here, so for ordered
    data this test passes without asserting anything -- confirm that this
    is intended.
    """
    if not data_for_sorting.ordered:
        # ``pytest.skip`` raises the skip exception itself; wrapping the
        # call in ``raise`` is redundant.
        pytest.skip(reason="searchsorted requires ordered data.")
195+
192196

193197
class TestCasting(base.BaseCastingTests):
194198
pass

pandas/tests/extension/test_sparse.py

+6
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,12 @@ def test_combine_first(self, data):
286286
pytest.skip("TODO(SparseArray.__setitem__ will preserve dtype.")
287287
super(TestMethods, self).test_combine_first(data)
288288

289+
@pytest.mark.parametrize("as_series", [True, False])
def test_searchsorted(self, data_for_sorting, as_series):
    """Run the shared searchsorted test, expecting a PerformanceWarning."""
    # Resolve the parent implementation first; only the call itself needs
    # to run inside the warning check.
    base_test = super(TestMethods, self).test_searchsorted
    with tm.assert_produces_warning(PerformanceWarning):
        base_test(data_for_sorting, as_series=as_series)
294+
289295

290296
class TestCasting(BaseSparseTests, base.BaseCastingTests):
291297
pass

0 commit comments

Comments
 (0)