Skip to content

Commit 01468d1

Browse files
committed
Allowed kwargs to pass through to Cython func
1 parent dfd1549 commit 01468d1

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

pandas/_libs/groupby_helper.pxi.in

+1-1
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
450450
def group_rank_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
451451
ndarray[{{c_type}}, ndim=2] values,
452452
ndarray[int64_t] labels,
453-
bint is_datetimelike):
453+
bint is_datetimelike, **kwargs):
454454
"""
455455
Only transforms on axis=0
456456
"""

pandas/core/groupby.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -982,15 +982,15 @@ def _transform_should_cast(self, func_nm):
982982
return (self.size().fillna(0) > 0).any() and (func_nm not in
983983
_cython_cast_blacklist)
984984

985-
def _cython_transform(self, how, numeric_only=True):
985+
def _cython_transform(self, how, numeric_only=True, **kwargs):
986986
output = collections.OrderedDict()
987987
for name, obj in self._iterate_slices():
988988
is_numeric = is_numeric_dtype(obj.dtype)
989989
if numeric_only and not is_numeric:
990990
continue
991991

992992
try:
993-
result, names = self.grouper.transform(obj.values, how)
993+
result, names = self.grouper.transform(obj.values, how, **kwargs)
994994
except NotImplementedError:
995995
continue
996996
except AssertionError as e:
@@ -1758,9 +1758,12 @@ def cumcount(self, ascending=True):
17581758

17591759
@Substitution(name='groupby')
17601760
@Appender(_doc_template)
1761-
def rank(self, axis=0, *args, **kwargs):
1761+
def rank(self, ties_method='average', ascending=True, na_option='keep',
1762+
pct=False, axis=0):
17621763
"""Rank within each group"""
1763-
return self._cython_transform('rank', **kwargs)
1764+
return self._cython_transform('rank', ties_method=ties_method,
1765+
ascending=ascending, na_option=na_option,
1766+
pct=pct, axis=axis)
17641767

17651768
@Substitution(name='groupby')
17661769
@Appender(_doc_template)
@@ -2237,7 +2240,8 @@ def wrapper(*args, **kwargs):
22372240
(how, dtype_str))
22382241
return func, dtype_str
22392242

2240-
def _cython_operation(self, kind, values, how, axis, min_count=-1):
2243+
def _cython_operation(self, kind, values, how, axis, min_count=-1,
2244+
**kwargs):
22412245
assert kind in ['transform', 'aggregate']
22422246

22432247
# can we do this operation with our cython functions
@@ -2329,7 +2333,8 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1):
23292333

23302334
# TODO: min_count
23312335
result = self._transform(
2332-
result, values, labels, func, is_numeric, is_datetimelike)
2336+
result, values, labels, func, is_numeric, is_datetimelike,
2337+
**kwargs)
23332338

23342339
if is_integer_dtype(result):
23352340
mask = result == iNaT
@@ -2368,8 +2373,8 @@ def aggregate(self, values, how, axis=0, min_count=-1):
23682373
return self._cython_operation('aggregate', values, how, axis,
23692374
min_count=min_count)
23702375

2371-
def transform(self, values, how, axis=0):
2372-
return self._cython_operation('transform', values, how, axis)
2376+
def transform(self, values, how, axis=0, **kwargs):
2377+
return self._cython_operation('transform', values, how, axis, **kwargs)
23732378

23742379
def _aggregate(self, result, counts, values, comp_ids, agg_func,
23752380
is_numeric, is_datetimelike, min_count=-1):
@@ -2389,7 +2394,7 @@ def _aggregate(self, result, counts, values, comp_ids, agg_func,
23892394
return result
23902395

23912396
def _transform(self, result, values, comp_ids, transform_func,
2392-
is_numeric, is_datetimelike):
2397+
is_numeric, is_datetimelike, **kwargs):
23932398

23942399
comp_ids, _, ngroups = self.group_info
23952400
if values.ndim > 3:
@@ -2403,7 +2408,7 @@ def _transform(self, result, values, comp_ids, transform_func,
24032408
transform_func(result[:, :, i], values,
24042409
comp_ids, is_datetimelike)
24052410
else:
2406-
transform_func(result, values, comp_ids, is_datetimelike)
2411+
transform_func(result, values, comp_ids, is_datetimelike, **kwargs)
24072412

24082413
return result
24092414

0 commit comments

Comments
 (0)