Skip to content

Commit 9ea9fde

Browse files
jbrockmendel authored and proost committed
REF: Simplify _cython_functions lookup (pandas-dev#29246)
1 parent 32e7891 commit 9ea9fde

File tree

2 files changed

+19
-62
lines changed

2 files changed

+19
-62
lines changed

pandas/_libs/groupby.pyx

+7-3
Original file line number | Diff line number | Diff line change
@@ -937,7 +937,7 @@ def group_last(rank_t[:, :] out,
937937
def group_nth(rank_t[:, :] out,
938938
int64_t[:] counts,
939939
rank_t[:, :] values,
940-
const int64_t[:] labels, int64_t rank,
940+
const int64_t[:] labels, int64_t rank=1,
941941
Py_ssize_t min_count=-1):
942942
"""
943943
Only aggregates on axis=0
@@ -1028,8 +1028,9 @@ def group_nth(rank_t[:, :] out,
10281028
def group_rank(float64_t[:, :] out,
10291029
rank_t[:, :] values,
10301030
const int64_t[:] labels,
1031-
bint is_datetimelike, object ties_method,
1032-
bint ascending, bint pct, object na_option):
1031+
int ngroups,
1032+
bint is_datetimelike, object ties_method="average",
1033+
bint ascending=True, bint pct=False, object na_option="keep"):
10331034
"""
10341035
Provides the rank of values within each group.
10351036
@@ -1039,6 +1040,9 @@ def group_rank(float64_t[:, :] out,
10391040
values : array of rank_t values to be ranked
10401041
labels : array containing unique label for each group, with its ordering
10411042
matching up to the corresponding record in `values`
1043+
ngroups : int
1044+
This parameter is not used, is needed to match signatures of other
1045+
groupby functions.
10421046
is_datetimelike : bool, default False
10431047
unused in this method but provided for call compatibility with other
10441048
Cython transformations

pandas/core/groupby/ops.py

+12-59
Original file line number | Diff line number | Diff line change
@@ -319,12 +319,9 @@ def get_group_levels(self):
319319
"min": "group_min",
320320
"max": "group_max",
321321
"mean": "group_mean",
322-
"median": {"name": "group_median"},
322+
"median": "group_median",
323323
"var": "group_var",
324-
"first": {
325-
"name": "group_nth",
326-
"f": lambda func, a, b, c, d, e: func(a, b, c, d, 1, -1),
327-
},
324+
"first": "group_nth",
328325
"last": "group_last",
329326
"ohlc": "group_ohlc",
330327
},
@@ -333,19 +330,7 @@ def get_group_levels(self):
333330
"cumsum": "group_cumsum",
334331
"cummin": "group_cummin",
335332
"cummax": "group_cummax",
336-
"rank": {
337-
"name": "group_rank",
338-
"f": lambda func, a, b, c, d, e, **kwargs: func(
339-
a,
340-
b,
341-
c,
342-
e,
343-
kwargs.get("ties_method", "average"),
344-
kwargs.get("ascending", True),
345-
kwargs.get("pct", False),
346-
kwargs.get("na_option", "keep"),
347-
),
348-
},
333+
"rank": "group_rank",
349334
},
350335
}
351336

@@ -391,21 +376,7 @@ def get_func(fname):
391376

392377
ftype = self._cython_functions[kind][how]
393378

394-
if isinstance(ftype, dict):
395-
func = afunc = get_func(ftype["name"])
396-
397-
# a sub-function
398-
f = ftype.get("f")
399-
if f is not None:
400-
401-
def wrapper(*args, **kwargs):
402-
return f(afunc, *args, **kwargs)
403-
404-
# need to curry our sub-function
405-
func = wrapper
406-
407-
else:
408-
func = get_func(ftype)
379+
func = get_func(ftype)
409380

410381
if func is None:
411382
raise NotImplementedError(
@@ -517,14 +488,7 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1, **kwargs):
517488
)
518489
counts = np.zeros(self.ngroups, dtype=np.int64)
519490
result = self._aggregate(
520-
result,
521-
counts,
522-
values,
523-
labels,
524-
func,
525-
is_numeric,
526-
is_datetimelike,
527-
min_count,
491+
result, counts, values, labels, func, is_datetimelike, min_count
528492
)
529493
elif kind == "transform":
530494
result = _maybe_fill(
@@ -533,7 +497,7 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1, **kwargs):
533497

534498
# TODO: min_count
535499
result = self._transform(
536-
result, values, labels, func, is_numeric, is_datetimelike, **kwargs
500+
result, values, labels, func, is_datetimelike, **kwargs
537501
)
538502

539503
if is_integer_dtype(result) and not is_datetimelike:
@@ -574,33 +538,22 @@ def transform(self, values, how, axis=0, **kwargs):
574538
return self._cython_operation("transform", values, how, axis, **kwargs)
575539

576540
def _aggregate(
577-
self,
578-
result,
579-
counts,
580-
values,
581-
comp_ids,
582-
agg_func,
583-
is_numeric,
584-
is_datetimelike,
585-
min_count=-1,
541+
self, result, counts, values, comp_ids, agg_func, is_datetimelike, min_count=-1
586542
):
587543
if values.ndim > 2:
588544
# punting for now
589545
raise NotImplementedError("number of dimensions is currently limited to 2")
546+
elif agg_func is libgroupby.group_nth:
547+
# different signature from the others
548+
# TODO: should we be using min_count instead of hard-coding it?
549+
agg_func(result, counts, values, comp_ids, rank=1, min_count=-1)
590550
else:
591551
agg_func(result, counts, values, comp_ids, min_count)
592552

593553
return result
594554

595555
def _transform(
596-
self,
597-
result,
598-
values,
599-
comp_ids,
600-
transform_func,
601-
is_numeric,
602-
is_datetimelike,
603-
**kwargs
556+
self, result, values, comp_ids, transform_func, is_datetimelike, **kwargs
604557
):
605558

606559
comp_ids, _, ngroups = self.group_info

0 commit comments

Comments
 (0)