Skip to content

Commit 9a44ced

Browse files
jbrockmendelquintusdias
authored andcommitted
CLN: simplify join take call (pandas-dev#27531)
1 parent a4f0092 commit 9a44ced

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

pandas/_libs/join.pyx

+11-9
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ from numpy cimport (ndarray,
88
uint32_t, uint64_t, float32_t, float64_t)
99
cnp.import_array()
1010

11-
from pandas._libs.algos import groupsort_indexer, ensure_platform_int
12-
from pandas.core.algorithms import take_nd
11+
from pandas._libs.algos import (
12+
groupsort_indexer, ensure_platform_int, take_1d_int64_int64
13+
)
1314

1415

1516
def inner_join(const int64_t[:] left, const int64_t[:] right,
@@ -67,8 +68,8 @@ def left_outer_join(const int64_t[:] left, const int64_t[:] right,
6768
Py_ssize_t max_groups, sort=True):
6869
cdef:
6970
Py_ssize_t i, j, k, count = 0
70-
ndarray[int64_t] left_count, right_count
71-
ndarray left_sorter, right_sorter, rev
71+
ndarray[int64_t] left_count, right_count, left_sorter, right_sorter
72+
ndarray rev
7273
ndarray[int64_t] left_indexer, right_indexer
7374
int64_t lc, rc
7475

@@ -124,10 +125,8 @@ def left_outer_join(const int64_t[:] left, const int64_t[:] right,
124125
# no multiple matches for any row on the left
125126
# this is a short-cut to avoid groupsort_indexer
126127
# otherwise, the `else` path also works in this case
127-
left_sorter = ensure_platform_int(left_sorter)
128-
129128
rev = np.empty(len(left), dtype=np.intp)
130-
rev.put(left_sorter, np.arange(len(left)))
129+
rev.put(ensure_platform_int(left_sorter), np.arange(len(left)))
131130
else:
132131
rev, _ = groupsort_indexer(left_indexer, len(left))
133132

@@ -201,9 +200,12 @@ def full_outer_join(const int64_t[:] left, const int64_t[:] right,
201200
_get_result_indexer(right_sorter, right_indexer))
202201

203202

204-
def _get_result_indexer(sorter, indexer):
203+
cdef _get_result_indexer(ndarray[int64_t] sorter, ndarray[int64_t] indexer):
205204
if len(sorter) > 0:
206-
res = take_nd(sorter, indexer, fill_value=-1)
205+
# cython-only equivalent to
206+
# `res = algos.take_nd(sorter, indexer, fill_value=-1)`
207+
res = np.empty(len(indexer), dtype=np.int64)
208+
take_1d_int64_int64(sorter, indexer, res, -1)
207209
else:
208210
# length-0 case
209211
res = np.empty(len(indexer), dtype=np.int64)

0 commit comments

Comments
 (0)