Skip to content

Commit 86d5980

Browse files
authored
REF/PERF: deduplicate kth_smallest (#40559)
1 parent 79f1801 commit 86d5980

File tree

4 files changed

+78
-73
lines changed

4 files changed

+78
-73
lines changed

pandas/_libs/algos.pxd

+1-18
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,4 @@
11
from pandas._libs.util cimport numeric
22

33

4-
cdef inline Py_ssize_t swap(numeric *a, numeric *b) nogil:
5-
cdef:
6-
numeric t
7-
8-
# cython doesn't allow pointer dereference so use array syntax
9-
t = a[0]
10-
a[0] = b[0]
11-
b[0] = t
12-
return 0
13-
14-
15-
cdef enum TiebreakEnumType:
16-
TIEBREAK_AVERAGE
17-
TIEBREAK_MIN,
18-
TIEBREAK_MAX
19-
TIEBREAK_FIRST
20-
TIEBREAK_FIRST_DESCENDING
21-
TIEBREAK_DENSE
4+
cdef numeric kth_smallest_c(numeric* arr, Py_ssize_t k, Py_ssize_t n) nogil

pandas/_libs/algos.pyx

+72-23
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ cdef:
6464
float64_t NaN = <float64_t>np.NaN
6565
int64_t NPY_NAT = get_nat()
6666

67+
cdef enum TiebreakEnumType:
68+
TIEBREAK_AVERAGE
69+
TIEBREAK_MIN,
70+
TIEBREAK_MAX
71+
TIEBREAK_FIRST
72+
TIEBREAK_FIRST_DESCENDING
73+
TIEBREAK_DENSE
74+
6775
tiebreakers = {
6876
"average": TIEBREAK_AVERAGE,
6977
"min": TIEBREAK_MIN,
@@ -237,34 +245,75 @@ def groupsort_indexer(const int64_t[:] index, Py_ssize_t ngroups):
237245
return indexer, counts
238246

239247

248+
cdef inline Py_ssize_t swap(numeric *a, numeric *b) nogil:
249+
cdef:
250+
numeric t
251+
252+
# cython doesn't allow pointer dereference so use array syntax
253+
t = a[0]
254+
a[0] = b[0]
255+
b[0] = t
256+
return 0
257+
258+
259+
cdef inline numeric kth_smallest_c(numeric* arr, Py_ssize_t k, Py_ssize_t n) nogil:
260+
"""
261+
See kth_smallest.__doc__. The additional parameter n specifies the maximum
262+
number of elements considered in arr, needed for compatibility with usage
263+
in groupby.pyx
264+
"""
265+
cdef:
266+
Py_ssize_t i, j, l, m
267+
numeric x
268+
269+
l = 0
270+
m = n - 1
271+
272+
while l < m:
273+
x = arr[k]
274+
i = l
275+
j = m
276+
277+
while 1:
278+
while arr[i] < x: i += 1
279+
while x < arr[j]: j -= 1
280+
if i <= j:
281+
swap(&arr[i], &arr[j])
282+
i += 1; j -= 1
283+
284+
if i > j: break
285+
286+
if j < k: l = i
287+
if k < i: m = j
288+
return arr[k]
289+
290+
240291
@cython.boundscheck(False)
241292
@cython.wraparound(False)
242-
def kth_smallest(numeric[:] a, Py_ssize_t k) -> numeric:
293+
def kth_smallest(numeric[::1] arr, Py_ssize_t k) -> numeric:
294+
"""
295+
Compute the kth smallest value in arr. Note that the input
296+
array will be modified.
297+
298+
Parameters
299+
----------
300+
arr : numeric[::1]
301+
Array to compute the kth smallest value for, must be
302+
contiguous
303+
k : Py_ssize_t
304+
305+
Returns
306+
-------
307+
numeric
308+
The kth smallest value in arr
309+
"""
243310
cdef:
244-
Py_ssize_t i, j, l, m, n = a.shape[0]
245-
numeric x
311+
numeric result
246312

247313
with nogil:
248-
l = 0
249-
m = n - 1
250-
251-
while l < m:
252-
x = a[k]
253-
i = l
254-
j = m
255-
256-
while 1:
257-
while a[i] < x: i += 1
258-
while x < a[j]: j -= 1
259-
if i <= j:
260-
swap(&a[i], &a[j])
261-
i += 1; j -= 1
262-
263-
if i > j: break
264-
265-
if j < k: l = i
266-
if k < i: m = j
267-
return a[k]
314+
result = kth_smallest_c(&arr[0], k, arr.shape[0])
315+
316+
return result
268317

269318

270319
# ----------------------------------------------------------------------

pandas/_libs/groupby.pyx

+2-31
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ from numpy.math cimport NAN
3030

3131
cnp.import_array()
3232

33-
from pandas._libs.algos cimport swap
33+
from pandas._libs.algos cimport kth_smallest_c
3434
from pandas._libs.util cimport (
3535
get_nat,
3636
numeric,
@@ -88,7 +88,7 @@ cdef inline float64_t median_linear(float64_t* a, int n) nogil:
8888
n -= na_count
8989

9090
if n % 2:
91-
result = kth_smallest_c( a, n // 2, n)
91+
result = kth_smallest_c(a, n // 2, n)
9292
else:
9393
result = (kth_smallest_c(a, n // 2, n) +
9494
kth_smallest_c(a, n // 2 - 1, n)) / 2
@@ -99,35 +99,6 @@ cdef inline float64_t median_linear(float64_t* a, int n) nogil:
9999
return result
100100

101101

102-
# TODO: Is this redundant with algos.kth_smallest
103-
cdef inline float64_t kth_smallest_c(float64_t* a,
104-
Py_ssize_t k,
105-
Py_ssize_t n) nogil:
106-
cdef:
107-
Py_ssize_t i, j, l, m
108-
float64_t x, t
109-
110-
l = 0
111-
m = n - 1
112-
while l < m:
113-
x = a[k]
114-
i = l
115-
j = m
116-
117-
while 1:
118-
while a[i] < x: i += 1
119-
while x < a[j]: j -= 1
120-
if i <= j:
121-
swap(&a[i], &a[j])
122-
i += 1; j -= 1
123-
124-
if i > j: break
125-
126-
if j < k: l = i
127-
if k < i: m = j
128-
return a[k]
129-
130-
131102
@cython.boundscheck(False)
132103
@cython.wraparound(False)
133104
def group_median_float64(ndarray[float64_t, ndim=2] out,

pandas/core/algorithms.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1307,7 +1307,9 @@ def compute(self, method: str) -> Series:
13071307
narr = len(arr)
13081308
n = min(n, narr)
13091309

1310-
kth_val = algos.kth_smallest(arr.copy(), n - 1)
1310+
# arr passed into kth_smallest must be contiguous. We copy
1311+
# here because kth_smallest will modify its input
1312+
kth_val = algos.kth_smallest(arr.copy(order="C"), n - 1)
13111313
(ns,) = np.nonzero(arr <= kth_val)
13121314
inds = ns[arr[ns].argsort(kind="mergesort")]
13131315

0 commit comments

Comments
 (0)