Skip to content

Commit 34ebad8

Browse files
PERF: improve MultiIndex get_loc performance (#16346)
* PERF: improve hash collision check for single MI labels
* PERF: specialized hash function for single tuples
1 parent a3021ea commit 34ebad8

File tree

7 files changed

+113
-7
lines changed

7 files changed

+113
-7
lines changed

asv_bench/benchmarks/indexing.py

+12
Original file line numberDiff line numberDiff line change
@@ -227,12 +227,24 @@ def time_multiindex_get_indexer(self):
227227
def time_multiindex_large_get_loc(self):
228228
self.mi_large.get_loc((999, 19, 'Z'))
229229

230+
def time_multiindex_large_get_loc_warm(self):
231+
for _ in range(1000):
232+
self.mi_large.get_loc((999, 19, 'Z'))
233+
230234
def time_multiindex_med_get_loc(self):
231235
self.mi_med.get_loc((999, 9, 'A'))
232236

237+
def time_multiindex_med_get_loc_warm(self):
238+
for _ in range(1000):
239+
self.mi_med.get_loc((999, 9, 'A'))
240+
233241
def time_multiindex_string_get_loc(self):
234242
self.mi_small.get_loc((99, 'A', 'A'))
235243

244+
def time_multiindex_small_get_loc_warm(self):
245+
for _ in range(1000):
246+
self.mi_small.get_loc((99, 'A', 'A'))
247+
236248
def time_is_monotonic(self):
237249
self.miint.is_monotonic
238250

doc/source/whatsnew/v0.20.2.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@ Performance Improvements
2727
~~~~~~~~~~~~~~~~~~~~~~~~
2828

2929
- Performance regression fix when indexing with a list-like (:issue:`16285`)
30-
- Performance regression fix for small MultiIndexes (:issuse:`16319`)
30+
- Performance regression fix for MultiIndexes (:issue:`16319`, :issue:`16346`)
3131
- Improved performance of ``.clip()`` with scalar arguments (:issue:`15400`)
3232

33+
3334
.. _whatsnew_0202.bug_fixes:
3435

3536
Bug Fixes

pandas/_libs/hashtable.pxd

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ cdef class MultiIndexHashTable(HashTable):
3838

3939
cpdef get_item(self, object val)
4040
cpdef set_item(self, object key, Py_ssize_t val)
41+
cdef inline void _check_for_collision(self, Py_ssize_t loc, object label)
42+
4143

4244
cdef class StringHashTable(HashTable):
4345
cdef kh_str_t *table

pandas/_libs/hashtable_class_helper.pxi.in

+17-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ Template for each `dtype` helper function for hashtable
44
WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
55
"""
66

7+
from lib cimport is_null_datetimelike
8+
9+
710
#----------------------------------------------------------------------
811
# VectorData
912
#----------------------------------------------------------------------
@@ -921,6 +924,19 @@ cdef class MultiIndexHashTable(HashTable):
921924
"hash collision\nlocs:\n{}\n"
922925
"result:\n{}\nmi:\n{}".format(alocs, result, mi))
923926

927+
cdef inline void _check_for_collision(self, Py_ssize_t loc, object label):
928+
# validate that the loc maps to the actual value
929+
# version of _check_for_collisions above for single label (tuple)
930+
931+
result = self.mi[loc]
932+
933+
if not all(l == r or (is_null_datetimelike(l)
934+
and is_null_datetimelike(r))
935+
for l, r in zip(result, label)):
936+
raise AssertionError(
937+
"hash collision\nloc:\n{}\n"
938+
"result:\n{}\nmi:\n{}".format(loc, result, label))
939+
924940
def __contains__(self, object key):
925941
try:
926942
self.get_item(key)
@@ -939,8 +955,7 @@ cdef class MultiIndexHashTable(HashTable):
939955
k = kh_get_uint64(self.table, value)
940956
if k != self.table.n_buckets:
941957
loc = self.table.vals[k]
942-
locs = np.array([loc], dtype=np.int64)
943-
self._check_for_collisions(locs, key)
958+
self._check_for_collision(loc, key)
944959
return loc
945960
else:
946961
raise KeyError(key)

pandas/core/indexes/multi.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def _hashed_indexing_key(self, key):
748748
we need to stringify if we have mixed levels
749749
750750
"""
751-
from pandas.core.util.hashing import hash_tuples
751+
from pandas.core.util.hashing import hash_tuples, hash_tuple
752752

753753
if not isinstance(key, tuple):
754754
return hash_tuples(key)
@@ -762,7 +762,7 @@ def f(k, stringify):
762762
return k
763763
key = tuple([f(k, stringify)
764764
for k, stringify in zip(key, self._have_mixed_levels)])
765-
return hash_tuples(key)
765+
return hash_tuple(key)
766766

767767
@Appender(base._shared_docs['duplicated'] % _index_doc_kwargs)
768768
def duplicated(self, keep='first'):

pandas/core/util/hashing.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,17 @@
44
import itertools
55

66
import numpy as np
7-
from pandas._libs import hashing
7+
from pandas._libs import hashing, tslib
88
from pandas.core.dtypes.generic import (
99
ABCMultiIndex,
1010
ABCIndexClass,
1111
ABCSeries,
1212
ABCDataFrame)
1313
from pandas.core.dtypes.common import (
1414
is_categorical_dtype, is_list_like)
15+
from pandas.core.dtypes.missing import isnull
16+
from pandas.core.dtypes.cast import infer_dtype_from_scalar
17+
1518

1619
# 16 byte long hashing key
1720
_default_hash_key = '0123456789123456'
@@ -164,6 +167,29 @@ def hash_tuples(vals, encoding='utf8', hash_key=None):
164167
return h
165168

166169

170+
def hash_tuple(val, encoding='utf8', hash_key=None):
171+
"""
172+
Hash a single tuple efficiently
173+
174+
Parameters
175+
----------
176+
val : single tuple
177+
encoding : string, default 'utf8'
178+
hash_key : string key to encode, default to _default_hash_key
179+
180+
Returns
181+
-------
182+
hash
183+
184+
"""
185+
hashes = (_hash_scalar(v, encoding=encoding, hash_key=hash_key)
186+
for v in val)
187+
188+
h = _combine_hash_arrays(hashes, len(val))[0]
189+
190+
return h
191+
192+
167193
def _hash_categorical(c, encoding, hash_key):
168194
"""
169195
Hash a Categorical by hashing its categories, and then mapping the codes
@@ -276,3 +302,31 @@ def hash_array(vals, encoding='utf8', hash_key=None, categorize=True):
276302
vals *= np.uint64(0x94d049bb133111eb)
277303
vals ^= vals >> 31
278304
return vals
305+
306+
307+
def _hash_scalar(val, encoding='utf8', hash_key=None):
308+
"""
309+
Hash scalar value
310+
311+
Returns
312+
-------
313+
1d uint64 numpy array of hash value, of length 1
314+
"""
315+
316+
if isnull(val):
317+
# this is to be consistent with the _hash_categorical implementation
318+
return np.array([np.iinfo(np.uint64).max], dtype='u8')
319+
320+
if getattr(val, 'tzinfo', None) is not None:
321+
# for tz-aware datetimes, we need the underlying naive UTC value and
322+
# not the tz aware object or pd extension type (as
323+
# infer_dtype_from_scalar would do)
324+
if not isinstance(val, tslib.Timestamp):
325+
val = tslib.Timestamp(val)
326+
val = val.tz_convert(None)
327+
328+
dtype, val = infer_dtype_from_scalar(val)
329+
vals = np.array([val], dtype=dtype)
330+
331+
return hash_array(vals, hash_key=hash_key, encoding=encoding,
332+
categorize=False)

pandas/tests/util/test_hashing.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import pytest
2+
import datetime
23

34
from warnings import catch_warnings
45
import numpy as np
56
import pandas as pd
67

78
from pandas import DataFrame, Series, Index, MultiIndex
89
from pandas.util import hash_array, hash_pandas_object
9-
from pandas.core.util.hashing import hash_tuples
10+
from pandas.core.util.hashing import hash_tuples, hash_tuple, _hash_scalar
1011
import pandas.util.testing as tm
1112

1213

@@ -79,6 +80,27 @@ def test_hash_tuples(self):
7980
result = hash_tuples(tups[0])
8081
assert result == expected[0]
8182

83+
def test_hash_tuple(self):
84+
# test equivalence between hash_tuples and hash_tuple
85+
for tup in [(1, 'one'), (1, np.nan), (1.0, pd.NaT, 'A'),
86+
('A', pd.Timestamp("2012-01-01"))]:
87+
result = hash_tuple(tup)
88+
expected = hash_tuples([tup])[0]
89+
assert result == expected
90+
91+
def test_hash_scalar(self):
92+
for val in [1, 1.4, 'A', b'A', u'A', pd.Timestamp("2012-01-01"),
93+
pd.Timestamp("2012-01-01", tz='Europe/Brussels'),
94+
datetime.datetime(2012, 1, 1),
95+
pd.Timestamp("2012-01-01", tz='EST').to_pydatetime(),
96+
pd.Timedelta('1 days'), datetime.timedelta(1),
97+
pd.Period('2012-01-01', freq='D'), pd.Interval(0, 1),
98+
np.nan, pd.NaT, None]:
99+
result = _hash_scalar(val)
100+
expected = hash_array(np.array([val], dtype=object),
101+
categorize=True)
102+
assert result[0] == expected[0]
103+
82104
def test_hash_tuples_err(self):
83105

84106
for val in [5, 'foo', pd.Timestamp('20130101')]:

0 commit comments

Comments (0)