Skip to content

Commit 669ddfb

Browse files
authored
Implement hash_join for merges (#57970)
1 parent e51039a commit 669ddfb

File tree

7 files changed

+116
-19
lines changed

7 files changed

+116
-19
lines changed

asv_bench/benchmarks/join_merge.py

+17
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,23 @@ def time_i8merge(self, how):
328328
merge(self.left, self.right, how=how)
329329

330330

331+
class UniqueMerge:
    """
    Time an inner merge where the right-hand key column holds no duplicates.

    The benchmark is parametrized over ``unique_elements``, the exclusive
    upper bound of the randomly drawn keys.
    """

    params = [4_000_000, 1_000_000]
    param_names = ["unique_elements"]

    def setup(self, unique_elements):
        """Build the two single-column frames used by the timed merge."""
        n_rows = 1_000_000

        # Left keys are drawn before right keys; keep this order so results
        # stay reproducible under a fixed seed.
        left_keys = np.random.randint(1, unique_elements, (n_rows,))
        right_keys = np.random.randint(1, unique_elements, (n_rows,))
        self.left = DataFrame({"a": left_keys})
        self.right = DataFrame({"a": right_keys})

        # De-duplicate the right keys, then top the column back up to
        # ``n_rows`` with 0, -1, -2, ...  The random keys are all >= 1, so
        # the padding never collides with them and the column stays unique.
        distinct = self.right.a.drop_duplicates()
        n_padding = n_rows - len(distinct)
        padding = Series(np.arange(0, -n_padding, -1))
        self.right["a"] = concat([distinct, padding], ignore_index=True)

    def time_unique_merge(self, unique_elements):
        """Run the inner merge on the shared column ``a``."""
        merge(self.left, self.right, how="inner")
346+
347+
331348
class MergeDatetime:
332349
params = [
333350
[

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ Performance improvements
286286
- Performance improvement in :meth:`RangeIndex.join` returning a :class:`RangeIndex` instead of a :class:`Index` when possible. (:issue:`57651`, :issue:`57752`)
287287
- Performance improvement in :meth:`RangeIndex.reindex` returning a :class:`RangeIndex` instead of a :class:`Index` when possible. (:issue:`57647`, :issue:`57752`)
288288
- Performance improvement in :meth:`RangeIndex.take` returning a :class:`RangeIndex` instead of a :class:`Index` when possible. (:issue:`57445`, :issue:`57752`)
289+
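- Performance improvement in :func:`merge` when a hash join can be used, i.e. for unsorted inner joins on numeric or boolean keys that are unique in the right-hand object (:issue:`57970`)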
289290
- Performance improvement in ``DataFrameGroupBy.__len__`` and ``SeriesGroupBy.__len__`` (:issue:`57595`)
290291
- Performance improvement in indexing operations for string dtypes (:issue:`56997`)
291292
- Performance improvement in unary methods on a :class:`RangeIndex` returning a :class:`RangeIndex` instead of a :class:`Index` when possible. (:issue:`57825`)

pandas/_libs/hashtable.pyi

+7-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def unique_label_indices(
1616
class Factorizer:
1717
count: int
1818
uniques: Any
19-
def __init__(self, size_hint: int) -> None: ...
19+
def __init__(self, size_hint: int, uses_mask: bool = False) -> None: ...
2020
def get_count(self) -> int: ...
2121
def factorize(
2222
self,
@@ -25,6 +25,9 @@ class Factorizer:
2525
na_value=...,
2626
mask=...,
2727
) -> npt.NDArray[np.intp]: ...
28+
def hash_inner_join(
29+
self, values: np.ndarray, mask=...
30+
) -> tuple[np.ndarray, np.ndarray]: ...
2831

2932
class ObjectFactorizer(Factorizer):
3033
table: PyObjectHashTable
@@ -216,6 +219,9 @@ class HashTable:
216219
mask=...,
217220
ignore_na: bool = True,
218221
) -> tuple[np.ndarray, npt.NDArray[np.intp]]: ... # np.ndarray[subclass-specific]
222+
def hash_inner_join(
223+
self, values: np.ndarray, mask=...
224+
) -> tuple[np.ndarray, np.ndarray]: ...
219225

220226
class Complex128HashTable(HashTable): ...
221227
class Complex64HashTable(HashTable): ...

pandas/_libs/hashtable.pyx

+5-2
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ cdef class Factorizer:
7070
cdef readonly:
7171
Py_ssize_t count
7272

73-
def __cinit__(self, size_hint: int):
73+
def __cinit__(self, size_hint: int, uses_mask: bool = False):
7474
self.count = 0
7575

7676
def get_count(self) -> int:
@@ -79,13 +79,16 @@ cdef class Factorizer:
7979
def factorize(self, values, na_sentinel=-1, na_value=None, mask=None) -> np.ndarray:
8080
raise NotImplementedError
8181

82+
def hash_inner_join(self, values, mask=None):
83+
raise NotImplementedError
84+
8285

8386
cdef class ObjectFactorizer(Factorizer):
8487
cdef public:
8588
PyObjectHashTable table
8689
ObjectVector uniques
8790

88-
def __cinit__(self, size_hint: int):
91+
def __cinit__(self, size_hint: int, uses_mask: bool = False):
8992
self.table = PyObjectHashTable(size_hint)
9093
self.uniques = ObjectVector()
9194

pandas/_libs/hashtable_class_helper.pxi.in

+48-2
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,49 @@ cdef class {{name}}HashTable(HashTable):
557557
self.table.vals[k] = i
558558
self.na_position = na_position
559559

560+
@cython.wraparound(False)
@cython.boundscheck(False)
def hash_inner_join(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> tuple[ndarray, ndarray]:
    """
    Inner-join ``values`` against the keys already stored in this table.

    Every element of ``values`` is probed in the hash table. Probes that
    are not found are dropped; probes that are found contribute one entry
    to each of the two returned arrays.

    Parameters
    ----------
    values : memoryview of {{dtype}}
        Keys to look up.
    mask : memoryview of uint8, optional
        Truthy entries mark the corresponding element of ``values`` as
        missing. Must be given when the table was created with
        ``uses_mask``.

    Returns
    -------
    tuple[ndarray, ndarray]
        ``(self_locs, locs)``. ``locs[j]`` is the index into ``values`` of
        the j-th match and ``self_locs[j]`` is the value stored in the
        table for that key. For a masked probe, ``self_locs[j]`` is the
        table's ``na_position``; masked probes are skipped entirely when
        ``na_position`` is -1.

    Raises
    ------
    NotImplementedError
        If the table uses a mask and ``mask`` is None.
    """
    cdef:
        Py_ssize_t i, n = len(values)
        {{c_type}} val
        khiter_t k
        Int64Vector locs = Int64Vector()
        Int64Vector self_locs = Int64Vector()
        Int64VectorData *l
        Int64VectorData *sl

    # Raw pointers to the vectors' data structs so that appends can run
    # without the GIL.
    l = &locs.data
    sl = &self_locs.data

    if self.uses_mask and mask is None:
        raise NotImplementedError  # pragma: no cover

    with nogil:
        for i in range(n):
            if self.uses_mask and mask[i]:
                if self.na_position == -1:
                    # No NA stored in the table: a missing probe has
                    # nothing to match against.
                    continue
                if needs_resize(l.size, l.capacity):
                    with gil:
                        # Grow each vector from its *own* capacity. Using
                        # locs' capacity for self_locs after locs was
                        # already grown would over-allocate self_locs.
                        locs.resize(locs.data.capacity * 4)
                        self_locs.resize(self_locs.data.capacity * 4)
                append_data_int64(l, i)
                # Read the attribute directly rather than through a
                # narrower local, so the stored position cannot be
                # truncated before it lands in the int64 vector.
                append_data_int64(sl, self.na_position)
            else:
                val = {{to_c_type}}(values[i])
                k = kh_get_{{dtype}}(self.table, val)
                if k != self.table.n_buckets:
                    if needs_resize(l.size, l.capacity):
                        with gil:
                            locs.resize(locs.data.capacity * 4)
                            self_locs.resize(self_locs.data.capacity * 4)
                    append_data_int64(l, i)
                    append_data_int64(sl, self.table.vals[k])

    return self_locs.to_array(), locs.to_array()

602+
560603
@cython.boundscheck(False)
561604
def lookup(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> ndarray:
562605
# -> np.ndarray[np.intp]
@@ -879,8 +922,8 @@ cdef class {{name}}Factorizer(Factorizer):
879922
{{name}}HashTable table
880923
{{name}}Vector uniques
881924

882-
def __cinit__(self, size_hint: int):
883-
self.table = {{name}}HashTable(size_hint)
925+
def __cinit__(self, size_hint: int, uses_mask: bool = False):
926+
self.table = {{name}}HashTable(size_hint, uses_mask=uses_mask)
884927
self.uniques = {{name}}Vector()
885928

886929
def factorize(self, const {{c_type}}[:] values,
@@ -911,6 +954,9 @@ cdef class {{name}}Factorizer(Factorizer):
911954
self.count = len(self.uniques)
912955
return labels
913956

957+
def hash_inner_join(self, const {{c_type}}[:] values, const uint8_t[:] mask = None) -> tuple[np.ndarray, np.ndarray]:
958+
return self.table.hash_inner_join(values, mask)
959+
914960
{{endfor}}
915961

916962

pandas/core/reshape/merge.py

+37-14
Original file line numberDiff line numberDiff line change
@@ -1780,7 +1780,10 @@ def get_join_indexers_non_unique(
17801780
np.ndarray[np.intp]
17811781
Indexer into right.
17821782
"""
1783-
lkey, rkey, count = _factorize_keys(left, right, sort=sort)
1783+
lkey, rkey, count = _factorize_keys(left, right, sort=sort, how=how)
1784+
if count == -1:
1785+
# hash join
1786+
return lkey, rkey
17841787
if how == "left":
17851788
lidx, ridx = libjoin.left_outer_join(lkey, rkey, count, sort=sort)
17861789
elif how == "right":
@@ -2385,7 +2388,10 @@ def _left_join_on_index(
23852388

23862389

23872390
def _factorize_keys(
2388-
lk: ArrayLike, rk: ArrayLike, sort: bool = True
2391+
lk: ArrayLike,
2392+
rk: ArrayLike,
2393+
sort: bool = True,
2394+
how: str | None = None,
23892395
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
23902396
"""
23912397
Encode left and right keys as enumerated types.
@@ -2401,6 +2407,9 @@ def _factorize_keys(
24012407
sort : bool, defaults to True
24022408
If True, the encoding is done such that the unique elements in the
24032409
keys are sorted.
2410+
how : str, optional
2411+
Used to determine if we can use hash-join. If not given, then just factorize
2412+
keys.
24042413
24052414
Returns
24062415
-------
@@ -2409,7 +2418,8 @@ def _factorize_keys(
24092418
np.ndarray[np.intp]
24102419
Right (resp. left if called with `key='right'`) labels, as enumerated type.
24112420
int
2412-
Number of unique elements in union of left and right labels.
2421+
Number of unique elements in union of left and right labels. -1 if we used
2422+
a hash-join.
24132423
24142424
See Also
24152425
--------
@@ -2527,28 +2537,41 @@ def _factorize_keys(
25272537

25282538
klass, lk, rk = _convert_arrays_and_get_rizer_klass(lk, rk)
25292539

2530-
rizer = klass(max(len(lk), len(rk)))
2540+
rizer = klass(
2541+
max(len(lk), len(rk)),
2542+
uses_mask=isinstance(rk, (BaseMaskedArray, ArrowExtensionArray)),
2543+
)
25312544

25322545
if isinstance(lk, BaseMaskedArray):
25332546
assert isinstance(rk, BaseMaskedArray)
2534-
llab = rizer.factorize(lk._data, mask=lk._mask)
2535-
rlab = rizer.factorize(rk._data, mask=rk._mask)
2547+
lk_data, lk_mask = lk._data, lk._mask
2548+
rk_data, rk_mask = rk._data, rk._mask
25362549
elif isinstance(lk, ArrowExtensionArray):
25372550
assert isinstance(rk, ArrowExtensionArray)
25382551
# we can only get here with numeric dtypes
25392552
# TODO: Remove when we have a Factorizer for Arrow
2540-
llab = rizer.factorize(
2541-
lk.to_numpy(na_value=1, dtype=lk.dtype.numpy_dtype), mask=lk.isna()
2542-
)
2543-
rlab = rizer.factorize(
2544-
rk.to_numpy(na_value=1, dtype=lk.dtype.numpy_dtype), mask=rk.isna()
2545-
)
2553+
lk_data = lk.to_numpy(na_value=1, dtype=lk.dtype.numpy_dtype)
2554+
rk_data = rk.to_numpy(na_value=1, dtype=lk.dtype.numpy_dtype)
2555+
lk_mask, rk_mask = lk.isna(), rk.isna()
25462556
else:
25472557
# Argument 1 to "factorize" of "ObjectFactorizer" has incompatible type
25482558
# "Union[ndarray[Any, dtype[signedinteger[_64Bit]]],
25492559
# ndarray[Any, dtype[object_]]]"; expected "ndarray[Any, dtype[object_]]"
2550-
llab = rizer.factorize(lk) # type: ignore[arg-type]
2551-
rlab = rizer.factorize(rk) # type: ignore[arg-type]
2560+
lk_data, rk_data = lk, rk # type: ignore[assignment]
2561+
lk_mask, rk_mask = None, None
2562+
2563+
hash_join_available = how == "inner" and not sort and lk.dtype.kind in "iufb"
2564+
if hash_join_available:
2565+
rlab = rizer.factorize(rk_data, mask=rk_mask)
2566+
if rizer.get_count() == len(rlab):
2567+
ridx, lidx = rizer.hash_inner_join(lk_data, lk_mask)
2568+
return lidx, ridx, -1
2569+
else:
2570+
llab = rizer.factorize(lk_data, mask=lk_mask)
2571+
else:
2572+
llab = rizer.factorize(lk_data, mask=lk_mask)
2573+
rlab = rizer.factorize(rk_data, mask=rk_mask)
2574+
25522575
assert llab.dtype == np.dtype(np.intp), llab.dtype
25532576
assert rlab.dtype == np.dtype(np.intp), rlab.dtype
25542577

scripts/run_stubtest.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
"pandas._libs.hashtable.HashTable.set_na",
4545
"pandas._libs.hashtable.HashTable.sizeof",
4646
"pandas._libs.hashtable.HashTable.unique",
47+
"pandas._libs.hashtable.HashTable.hash_inner_join",
4748
# stubtest might be too sensitive
4849
"pandas._libs.lib.NoDefault",
4950
"pandas._libs.lib._NoDefault.no_default",

0 commit comments

Comments
 (0)