Skip to content

Commit 5ea49ef

Browse files
WillAydharisbal
authored and
harisbal
committed
PERF: Cythonize Groupby Rank (pandas-dev#19481)
1 parent 7246381 commit 5ea49ef

File tree

7 files changed

+406
-23
lines changed

7 files changed

+406
-23
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,7 @@ Performance Improvements
581581
- Improved performance of :func:`DataFrame.median` with ``axis=1`` when bottleneck is not installed (:issue:`16468`)
582582
- Improved performance of :func:`MultiIndex.get_loc` for large indexes, at the cost of a reduction in performance for small ones (:issue:`18519`)
583583
- Improved performance of pairwise ``.rolling()`` and ``.expanding()`` with ``.cov()`` and ``.corr()`` operations (:issue:`17917`)
584+
- Improved performance of :func:`DataFrameGroupBy.rank` (:issue:`15779`)
584585

585586
.. _whatsnew_0230.docs:
586587

pandas/_libs/algos.pxd

+8
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,11 @@ cdef inline Py_ssize_t swap(numeric *a, numeric *b) nogil:
1111
a[0] = b[0]
1212
b[0] = t
1313
return 0
14+
15+
cdef enum TiebreakEnumType:
16+
TIEBREAK_AVERAGE
17+
TIEBREAK_MIN,
18+
TIEBREAK_MAX
19+
TIEBREAK_FIRST
20+
TIEBREAK_FIRST_DESCENDING
21+
TIEBREAK_DENSE

pandas/_libs/algos.pyx

-8
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,6 @@ cdef double nan = NaN
3131

3232
cdef int64_t iNaT = get_nat()
3333

34-
cdef:
35-
int TIEBREAK_AVERAGE = 0
36-
int TIEBREAK_MIN = 1
37-
int TIEBREAK_MAX = 2
38-
int TIEBREAK_FIRST = 3
39-
int TIEBREAK_FIRST_DESCENDING = 4
40-
int TIEBREAK_DENSE = 5
41-
4234
tiebreakers = {
4335
'average': TIEBREAK_AVERAGE,
4436
'min': TIEBREAK_MIN,

pandas/_libs/groupby.pyx

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ from numpy cimport (ndarray,
1616
from libc.stdlib cimport malloc, free
1717

1818
from util cimport numeric, get_nat
19-
from algos cimport swap
20-
from algos import take_2d_axis1_float64_float64, groupsort_indexer
19+
from algos cimport (swap, TiebreakEnumType, TIEBREAK_AVERAGE, TIEBREAK_MIN,
20+
TIEBREAK_MAX, TIEBREAK_FIRST, TIEBREAK_DENSE)
21+
from algos import take_2d_axis1_float64_float64, groupsort_indexer, tiebreakers
2122

2223
cdef int64_t iNaT = get_nat()
2324

pandas/_libs/groupby_helper.pxi.in

+165
Original file line numberDiff line numberDiff line change
@@ -444,8 +444,173 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
444444
else:
445445
out[i, j] = resx[i, j]
446446

447+
448+
@cython.boundscheck(False)
449+
@cython.wraparound(False)
450+
def group_rank_{{name}}(ndarray[float64_t, ndim=2] out,
451+
ndarray[{{c_type}}, ndim=2] values,
452+
ndarray[int64_t] labels,
453+
bint is_datetimelike, object ties_method,
454+
bint ascending, bint pct, object na_option):
455+
"""Provides the rank of values within each group
456+
457+
Parameters
458+
----------
459+
out : array of float64_t values which this method will write its results to
460+
values : array of {{c_type}} values to be ranked
461+
labels : array containing unique label for each group, with its ordering
462+
matching up to the corresponding record in `values`
463+
is_datetimelike : bool
464+
unused in this method but provided for call compatibility with other
465+
Cython transformations
466+
na_option : {'keep', 'top', 'bottom'}
467+
* keep: leave NA values where they are
468+
* top: smallest rank if ascending
469+
* bottom: smallest rank if descending
470+
ascending : boolean
471+
False for ranks by high (1) to low (N)
472+
pct : boolean
473+
Compute percentage rank of data within each group
474+
475+
Notes
476+
-----
477+
This method modifies the `out` parameter rather than returning an object
478+
"""
479+
cdef:
480+
TiebreakEnumType tiebreak
481+
Py_ssize_t i, j, N, K, val_start=0, grp_start=0, dups=0, sum_ranks=0
482+
Py_ssize_t grp_vals_seen=1, grp_na_count=0
483+
ndarray[int64_t] _as
484+
ndarray[float64_t, ndim=2] grp_sizes
485+
ndarray[{{c_type}}] masked_vals
486+
ndarray[uint8_t] mask
487+
bint keep_na
488+
{{c_type}} nan_fill_val
489+
490+
tiebreak = tiebreakers[ties_method]
491+
keep_na = na_option == 'keep'
492+
N, K = (<object> values).shape
493+
grp_sizes = np.ones_like(out)
494+
495+
# Copy values into new array in order to fill missing data
496+
# with mask, without obfuscating location of missing data
497+
# in values array
498+
masked_vals = np.array(values[:, 0], copy=True)
499+
{{if name=='int64'}}
500+
mask = (masked_vals == {{nan_val}}).astype(np.uint8)
501+
{{else}}
502+
mask = np.isnan(masked_vals).astype(np.uint8)
503+
{{endif}}
504+
505+
if ascending ^ (na_option == 'top'):
506+
{{if name == 'int64'}}
507+
nan_fill_val = np.iinfo(np.int64).max
508+
{{else}}
509+
nan_fill_val = np.inf
510+
{{endif}}
511+
order = (masked_vals, mask, labels)
512+
else:
513+
{{if name == 'int64'}}
514+
nan_fill_val = np.iinfo(np.int64).min
515+
{{else}}
516+
nan_fill_val = -np.inf
517+
{{endif}}
518+
order = (masked_vals, ~mask, labels)
519+
np.putmask(masked_vals, mask, nan_fill_val)
520+
521+
# lexsort using labels, then mask, then actual values
522+
# each label corresponds to a different group value,
523+
# the mask helps you differentiate missing values before
524+
# performing sort on the actual values
525+
_as = np.lexsort(order)
526+
527+
if not ascending:
528+
_as = _as[::-1]
529+
530+
with nogil:
531+
# Loop over the length of the value array
532+
# each incremental i value can be looked up in the _as array
533+
# that we sorted previously, which gives us the location of
534+
# that sorted value for retrieval back from the original
535+
# values / masked_vals arrays
536+
for i in range(N):
537+
# dups and sum_ranks will be incremented each loop where
538+
# the value / group remains the same, and should be reset
539+
# when either of those change
540+
# Used to calculate tiebreakers
541+
dups += 1
542+
sum_ranks += i - grp_start + 1
543+
544+
# if keep_na, check for missing values and assign back
545+
# to the result where appropriate
546+
if keep_na and masked_vals[_as[i]] == nan_fill_val:
547+
grp_na_count += 1
548+
out[_as[i], 0] = nan
549+
else:
550+
# this implementation is inefficient because it will
551+
# continue overwriting previously encountered dups
552+
# i.e. if 5 duplicated values are encountered it will
553+
# write to the result as follows (assumes avg tiebreaker):
554+
# 1
555+
# .5 .5
556+
# .33 .33 .33
557+
# .25 .25 .25 .25
558+
# .2 .2 .2 .2 .2
559+
#
560+
# could potentially be optimized to only write to the
561+
# result once the last duplicate value is encountered
562+
if tiebreak == TIEBREAK_AVERAGE:
563+
for j in range(i - dups + 1, i + 1):
564+
out[_as[j], 0] = sum_ranks / <float64_t>dups
565+
elif tiebreak == TIEBREAK_MIN:
566+
for j in range(i - dups + 1, i + 1):
567+
out[_as[j], 0] = i - grp_start - dups + 2
568+
elif tiebreak == TIEBREAK_MAX:
569+
for j in range(i - dups + 1, i + 1):
570+
out[_as[j], 0] = i - grp_start + 1
571+
elif tiebreak == TIEBREAK_FIRST:
572+
for j in range(i - dups + 1, i + 1):
573+
if ascending:
574+
out[_as[j], 0] = j + 1 - grp_start
575+
else:
576+
out[_as[j], 0] = 2 * i - j - dups + 2 - grp_start
577+
elif tiebreak == TIEBREAK_DENSE:
578+
for j in range(i - dups + 1, i + 1):
579+
out[_as[j], 0] = grp_vals_seen
580+
581+
# look forward to the next value (using the sorting in _as)
582+
# if the value does not equal the current value then we need to
583+
# reset the dups and sum_ranks, knowing that a new value is coming
584+
# up. the conditional also needs to handle nan equality and the
585+
# end of iteration
586+
if (i == N - 1 or (
587+
(masked_vals[_as[i]] != masked_vals[_as[i+1]]) and not
588+
(mask[_as[i]] and mask[_as[i+1]]))):
589+
dups = sum_ranks = 0
590+
val_start = i
591+
grp_vals_seen += 1
592+
593+
# Similar to the previous conditional, check now if we are moving
594+
# to a new group. If so, keep track of the index where the new
595+
# group occurs, so the tiebreaker calculations can decrement that
596+
# from their position. fill in the size of each group encountered
597+
# (used by pct calculations later). also be sure to reset any of
598+
# the items helping to calculate dups
599+
if i == N - 1 or labels[_as[i]] != labels[_as[i+1]]:
600+
for j in range(grp_start, i + 1):
601+
grp_sizes[_as[j], 0] = i - grp_start + 1 - grp_na_count
602+
dups = sum_ranks = 0
603+
grp_na_count = 0
604+
val_start = i + 1
605+
grp_start = i + 1
606+
grp_vals_seen = 1
607+
608+
if pct:
609+
for i in range(N):
610+
out[i, 0] = out[i, 0] / grp_sizes[i, 0]
447611
{{endfor}}
448612

613+
449614
#----------------------------------------------------------------------
450615
# group_min, group_max
451616
#----------------------------------------------------------------------

pandas/core/groupby.py

+63-13
Original file line numberDiff line numberDiff line change
@@ -994,20 +994,24 @@ def _transform_should_cast(self, func_nm):
994994
return (self.size().fillna(0) > 0).any() and (func_nm not in
995995
_cython_cast_blacklist)
996996

997-
def _cython_transform(self, how, numeric_only=True):
997+
def _cython_transform(self, how, numeric_only=True, **kwargs):
998998
output = collections.OrderedDict()
999999
for name, obj in self._iterate_slices():
10001000
is_numeric = is_numeric_dtype(obj.dtype)
10011001
if numeric_only and not is_numeric:
10021002
continue
10031003

10041004
try:
1005-
result, names = self.grouper.transform(obj.values, how)
1005+
result, names = self.grouper.transform(obj.values, how,
1006+
**kwargs)
10061007
except NotImplementedError:
10071008
continue
10081009
except AssertionError as e:
10091010
raise GroupByError(str(e))
1010-
output[name] = self._try_cast(result, obj)
1011+
if self._transform_should_cast(how):
1012+
output[name] = self._try_cast(result, obj)
1013+
else:
1014+
output[name] = result
10111015

10121016
if len(output) == 0:
10131017
raise DataError('No numeric types to aggregate')
@@ -1768,6 +1772,37 @@ def cumcount(self, ascending=True):
17681772
cumcounts = self._cumcount_array(ascending=ascending)
17691773
return Series(cumcounts, index)
17701774

1775+
@Substitution(name='groupby')
1776+
@Appender(_doc_template)
1777+
def rank(self, method='average', ascending=True, na_option='keep',
1778+
pct=False, axis=0):
1779+
"""Provides the rank of values within each group
1780+
1781+
Parameters
1782+
----------
1783+
method : {'average', 'min', 'max', 'first', 'dense'}, default 'average'
1784+
* average: average rank of group
1785+
* min: lowest rank in group
1786+
* max: highest rank in group
1787+
* first: ranks assigned in order they appear in the array
1788+
* dense: like 'min', but rank always increases by 1 between groups
1789+
na_option : {'keep', 'top', 'bottom'}, default 'keep'
1790+
* keep: leave NA values where they are
1791+
* top: smallest rank if ascending
1792+
* bottom: smallest rank if descending
1793+
ascending : boolean, default True
1794+
False for ranks by high (1) to low (N)
1795+
pct : boolean, default False
1796+
Compute percentage rank of data within each group
1797+
1798+
Returns
1799+
-------
1800+
DataFrame with ranking of values within each group
1801+
"""
1802+
return self._cython_transform('rank', numeric_only=False,
1803+
ties_method=method, ascending=ascending,
1804+
na_option=na_option, pct=pct, axis=axis)
1805+
17711806
@Substitution(name='groupby')
17721807
@Appender(_doc_template)
17731808
def cumprod(self, axis=0, *args, **kwargs):
@@ -2183,6 +2218,16 @@ def get_group_levels(self):
21832218
'cumsum': 'group_cumsum',
21842219
'cummin': 'group_cummin',
21852220
'cummax': 'group_cummax',
2221+
'rank': {
2222+
'name': 'group_rank',
2223+
'f': lambda func, a, b, c, d, **kwargs: func(
2224+
a, b, c, d,
2225+
kwargs.get('ties_method', 'average'),
2226+
kwargs.get('ascending', True),
2227+
kwargs.get('pct', False),
2228+
kwargs.get('na_option', 'keep')
2229+
)
2230+
}
21862231
}
21872232
}
21882233

@@ -2242,7 +2287,8 @@ def wrapper(*args, **kwargs):
22422287
(how, dtype_str))
22432288
return func
22442289

2245-
def _cython_operation(self, kind, values, how, axis, min_count=-1):
2290+
def _cython_operation(self, kind, values, how, axis, min_count=-1,
2291+
**kwargs):
22462292
assert kind in ['transform', 'aggregate']
22472293

22482294
# can we do this operation with our cython functions
@@ -2314,10 +2360,13 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1):
23142360
else:
23152361
raise
23162362

2317-
if is_numeric:
2318-
out_dtype = '%s%d' % (values.dtype.kind, values.dtype.itemsize)
2363+
if how == 'rank':
2364+
out_dtype = 'float'
23192365
else:
2320-
out_dtype = 'object'
2366+
if is_numeric:
2367+
out_dtype = '%s%d' % (values.dtype.kind, values.dtype.itemsize)
2368+
else:
2369+
out_dtype = 'object'
23212370

23222371
labels, _, _ = self.group_info
23232372

@@ -2334,7 +2383,8 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1):
23342383

23352384
# TODO: min_count
23362385
result = self._transform(
2337-
result, values, labels, func, is_numeric, is_datetimelike)
2386+
result, values, labels, func, is_numeric, is_datetimelike,
2387+
**kwargs)
23382388

23392389
if is_integer_dtype(result) and not is_datetimelike:
23402390
mask = result == iNaT
@@ -2373,8 +2423,8 @@ def aggregate(self, values, how, axis=0, min_count=-1):
23732423
return self._cython_operation('aggregate', values, how, axis,
23742424
min_count=min_count)
23752425

2376-
def transform(self, values, how, axis=0):
2377-
return self._cython_operation('transform', values, how, axis)
2426+
def transform(self, values, how, axis=0, **kwargs):
2427+
return self._cython_operation('transform', values, how, axis, **kwargs)
23782428

23792429
def _aggregate(self, result, counts, values, comp_ids, agg_func,
23802430
is_numeric, is_datetimelike, min_count=-1):
@@ -2394,7 +2444,7 @@ def _aggregate(self, result, counts, values, comp_ids, agg_func,
23942444
return result
23952445

23962446
def _transform(self, result, values, comp_ids, transform_func,
2397-
is_numeric, is_datetimelike):
2447+
is_numeric, is_datetimelike, **kwargs):
23982448

23992449
comp_ids, _, ngroups = self.group_info
24002450
if values.ndim > 3:
@@ -2406,9 +2456,9 @@ def _transform(self, result, values, comp_ids, transform_func,
24062456

24072457
chunk = chunk.squeeze()
24082458
transform_func(result[:, :, i], values,
2409-
comp_ids, is_datetimelike)
2459+
comp_ids, is_datetimelike, **kwargs)
24102460
else:
2411-
transform_func(result, values, comp_ids, is_datetimelike)
2461+
transform_func(result, values, comp_ids, is_datetimelike, **kwargs)
24122462

24132463
return result
24142464

0 commit comments

Comments
 (0)