Skip to content

Commit f3d4817

Browse files
authored
PERF: Index.get_loc (#43705)
1 parent b2b1c53 commit f3d4817

File tree

3 files changed

+54
-5
lines changed

3 files changed

+54
-5
lines changed

pandas/_libs/index.pyx

+26-5
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,7 @@ cdef class IndexEngine:
8787
values = self.values
8888

8989
self._check_type(val)
90-
try:
91-
loc = _bin_search(values, val) # .searchsorted(val, side='left')
92-
except TypeError:
93-
# GH#35788 e.g. val=None with float64 values
94-
raise KeyError(val)
90+
loc = self._searchsorted_left(val)
9591
if loc >= len(values):
9692
raise KeyError(val)
9793
if values[loc] != val:
@@ -110,6 +106,17 @@ cdef class IndexEngine:
110106
# GH#41775 OverflowError e.g. if we are uint64 and val is -1
111107
raise KeyError(val)
112108

109+
cdef Py_ssize_t _searchsorted_left(self, val) except? -1:
    """
    Return the left-side insertion position of `val` in `self.values`.

    Delegates to ``self.values.searchsorted(val, side="left")``.
    Subclasses may override this when ``searchsorted`` is not appropriate
    for their values.

    Raises
    ------
    KeyError
        If ``searchsorted`` raises TypeError because `val` cannot be
        compared with the values, e.g. val=None with float64 values
        (GH#35788). The original TypeError is attached as the cause.
    """
    try:
        loc = self.values.searchsorted(val, side="left")
    except TypeError as err:
        # GH#35788 e.g. val=None with float64 values
        # Chain the TypeError so the underlying reason stays visible in the
        # traceback, matching the error handling of the ObjectEngine override.
        raise KeyError(val) from err
    return loc
119+
113120
cdef inline _get_loc_duplicates(self, object val):
114121
# -> Py_ssize_t | slice | ndarray[bool]
115122
cdef:
@@ -373,6 +380,11 @@ cdef class IndexEngine:
373380

374381

375382
cdef Py_ssize_t _bin_search(ndarray values, object val) except -1:
383+
# GH#1757 ndarray.searchsorted is not safe to use with array of tuples
384+
# (treats a tuple `val` as a sequence of keys instead of a single key),
385+
# so we implement something similar.
386+
# This is equivalent to the stdlib's bisect.bisect_left
387+
376388
cdef:
377389
Py_ssize_t mid = 0, lo = 0, hi = len(values) - 1
378390
object pval
@@ -405,6 +417,15 @@ cdef class ObjectEngine(IndexEngine):
405417
cdef _make_hash_table(self, Py_ssize_t n):
    # Build the hash table used by this engine: a PyObjectHashTable
    # constructed with the requested size `n`.
    table = _hash.PyObjectHashTable(n)
    return table
407419

420+
cdef Py_ssize_t _searchsorted_left(self, val) except? -1:
    """
    Return the left-side insertion position of `val` in `self.values`.

    Uses `_bin_search` instead of ``self.values.searchsorted`` because
    ``searchsorted`` would treat a tuple `val` as a sequence of keys
    rather than as a single key.

    Raises
    ------
    KeyError
        If `_bin_search` raises TypeError; the TypeError is attached as
        the cause.
    """
    try:
        loc = _bin_search(self.values, val)
    except TypeError as err:
        raise KeyError(val) from err
    return loc
428+
408429

409430
cdef class DatetimeEngine(Int64Engine):
410431

pandas/tests/indexes/base_class/test_indexing.py

+20
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
import pandas as pd
45
from pandas import Index
56
import pandas._testing as tm
67

@@ -36,3 +37,22 @@ def test_get_indexer_non_unique_dtype_mismatch(self):
3637
indexes, missing = Index(["A", "B"]).get_indexer_non_unique(Index([0]))
3738
tm.assert_numpy_array_equal(np.array([-1], dtype=np.intp), indexes)
3839
tm.assert_numpy_array_equal(np.array([0], dtype=np.intp), missing)
40+
41+
42+
class TestGetLoc:
    @pytest.mark.slow  # to_flat_index takes a while
    def test_get_loc_tuple_monotonic_above_size_cutoff(self):
        """
        Look up a tuple key in a large, flattened MultiIndex.

        Intended to go through the libindex path for which using
        _bin_search vs ndarray.searchsorted makes a difference.
        """
        letters = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
        dates = pd.date_range("2016-01-01", periods=100)

        # Flatten the product of the three levels into an Index of tuples.
        flat = pd.MultiIndex.from_product(
            [letters, range(10 ** 3), dates]
        ).to_flat_index()

        # Pick the middle element as the key; its position is the
        # expected result of the lookup.
        expected = len(flat) // 2
        key = flat[expected]

        assert flat.get_loc(key) == expected

pandas/tests/indexes/numeric/test_indexing.py

+8
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,14 @@ def test_get_loc_float_index_nan_with_method(self, vals, method):
152152
with tm.assert_produces_warning(FutureWarning, match="deprecated"):
153153
idx.get_loc(np.nan, method=method)
154154

155+
@pytest.mark.parametrize("dtype", ["f8", "i8", "u8"])
def test_get_loc_numericindex_none_raises(self, dtype):
    """Looking up None in a large numeric Index raises KeyError."""
    # case that goes through searchsorted and key is non-comparable to values
    idx = Index(np.arange(10 ** 7, dtype=dtype))
    with pytest.raises(KeyError, match="None"):
        idx.get_loc(None)
162+
155163

156164
class TestGetIndexer:
157165
def test_get_indexer(self):

0 commit comments

Comments
 (0)