Skip to content

Commit 3d61f98

Browse files
topper-123 authored and proost committed
PERF: improve perf. of Categorical.searchsorted (pandas-dev#28795)
1 parent 55538a0 commit 3d61f98

File tree

5 files changed

+34
-14
lines changed

5 files changed

+34
-14
lines changed

asv_bench/benchmarks/categoricals.py

+14
Original file line number | Diff line number | Diff line change
@@ -282,4 +282,18 @@ def time_sort_values(self):
282282
self.index.sort_values(ascending=False)
283283

284284

285+
class SearchSorted:
286+
def setup(self):
287+
N = 10 ** 5
288+
self.ci = tm.makeCategoricalIndex(N).sort_values()
289+
self.c = self.ci.values
290+
self.key = self.ci.categories[1]
291+
292+
def time_categorical_index_contains(self):
293+
self.ci.searchsorted(self.key)
294+
295+
def time_categorical_contains(self):
296+
self.c.searchsorted(self.key)
297+
298+
285299
from .pandas_vb_common import setup # noqa: F401 isort:skip

doc/source/whatsnew/v1.0.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -204,6 +204,7 @@ Performance improvements
204204
- Performance improvement in :meth:`DataFrame.corr` when ``method`` is ``"spearman"`` (:issue:`28139`)
205205
- Performance improvement in :meth:`DataFrame.replace` when provided a list of values to replace (:issue:`28099`)
206206
- Performance improvement in :meth:`DataFrame.select_dtypes` by using vectorization instead of iterating over a loop (:issue:`28317`)
207+
- Performance improvement in :meth:`Categorical.searchsorted` and :meth:`CategoricalIndex.searchsorted` (:issue:`28795`)
207208

208209
.. _whatsnew_1000.bug_fixes:
209210

pandas/core/arrays/categorical.py

+8 -8
Original file line number | Diff line number | Diff line change
@@ -1399,14 +1399,14 @@ def memory_usage(self, deep=False):
13991399
@Substitution(klass="Categorical")
14001400
@Appender(_shared_docs["searchsorted"])
14011401
def searchsorted(self, value, side="left", sorter=None):
1402-
from pandas.core.series import Series
1403-
1404-
codes = _get_codes_for_values(Series(value).values, self.categories)
1405-
if -1 in codes:
1406-
raise KeyError("Value(s) to be inserted must be in categories.")
1407-
1408-
codes = codes[0] if is_scalar(value) else codes
1409-
1402+
# searchsorted is very performance sensitive. By converting codes
1403+
# to same dtype as self.codes, we get much faster performance.
1404+
if is_scalar(value):
1405+
codes = self.categories.get_loc(value)
1406+
codes = self.codes.dtype.type(codes)
1407+
else:
1408+
locs = [self.categories.get_loc(x) for x in value]
1409+
codes = np.array(locs, dtype=self.codes.dtype)
14101410
return self.codes.searchsorted(codes, side=side, sorter=sorter)
14111411

14121412
def isna(self):

pandas/core/indexes/category.py

+7 -1
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from pandas._libs.hashtable import duplicated_int64
1111
import pandas.compat as compat
1212
from pandas.compat.numpy import function as nv
13-
from pandas.util._decorators import Appender, cache_readonly
13+
from pandas.util._decorators import Appender, Substitution, cache_readonly
1414

1515
from pandas.core.dtypes.common import (
1616
ensure_platform_int,
@@ -27,6 +27,7 @@
2727
from pandas.core import accessor
2828
from pandas.core.algorithms import take_1d
2929
from pandas.core.arrays.categorical import Categorical, _recode_for_categories, contains
30+
from pandas.core.base import _shared_docs
3031
import pandas.core.common as com
3132
import pandas.core.indexes.base as ibase
3233
from pandas.core.indexes.base import Index, _index_shared_docs
@@ -555,6 +556,11 @@ def _can_reindex(self, indexer):
555556
""" always allow reindexing """
556557
pass
557558

559+
@Substitution(klass="CategoricalIndex")
560+
@Appender(_shared_docs["searchsorted"])
561+
def searchsorted(self, value, side="left", sorter=None):
562+
return self._data.searchsorted(value, side=side, sorter=sorter)
563+
558564
@Appender(_index_shared_docs["where"])
559565
def where(self, cond, other=None):
560566
# TODO: Investigate an alternative implementation with

pandas/tests/arrays/categorical/test_analytics.py

+4 -5
Original file line number | Diff line number | Diff line change
@@ -113,16 +113,15 @@ def test_searchsorted(self, ordered_fixture):
113113
tm.assert_numpy_array_equal(res_ser, exp)
114114

115115
# Searching for a single value that is not from the Categorical
116-
msg = r"Value\(s\) to be inserted must be in categories"
117-
with pytest.raises(KeyError, match=msg):
116+
with pytest.raises(KeyError, match="cucumber"):
118117
cat.searchsorted("cucumber")
119-
with pytest.raises(KeyError, match=msg):
118+
with pytest.raises(KeyError, match="cucumber"):
120119
ser.searchsorted("cucumber")
121120

122121
# Searching for multiple values, one of which is not from the Categorical
123-
with pytest.raises(KeyError, match=msg):
122+
with pytest.raises(KeyError, match="cucumber"):
124123
cat.searchsorted(["bread", "cucumber"])
125-
with pytest.raises(KeyError, match=msg):
124+
with pytest.raises(KeyError, match="cucumber"):
126125
ser.searchsorted(["bread", "cucumber"])
127126

128127
def test_unique(self):

0 commit comments

Comments
 (0)