Skip to content

Commit b0e1130

Browse files
authored
ENH: Add separate numba kernels for groupby aggregations (#53731)
* ENH: Add separate numba kernels for groupby aggregations * add whatsnew * fixes from pre-commit * fix window tests * fix tests? * cleanup * is_grouped_kernel=True in groupby * typing * fix typing? * fix now * Just ignore * remove unnecessary code * remove comment
1 parent 57c7943 commit b0e1130

File tree

11 files changed

+402
-57
lines changed

11 files changed

+402
-57
lines changed

asv_bench/benchmarks/groupby.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -586,12 +586,8 @@ class GroupByCythonAgg:
586586
[
587587
"sum",
588588
"prod",
589-
# TODO: uncomment min/max
590-
# Currently, min/max implemented very inefficiently
591-
# because it re-uses the Window min/max kernel
592-
# so it will time out ASVs
593-
# "min",
594-
# "max",
589+
"min",
590+
"max",
595591
"mean",
596592
"median",
597593
"var",

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,7 @@ Performance improvements
628628
- Performance improvement in :func:`concat` (:issue:`52291`, :issue:`52290`)
629629
- :class:`Period`'s default formatter (`period_format`) is now significantly (~twice) faster. This improves performance of ``str(Period)``, ``repr(Period)``, and :meth:`Period.strftime(fmt=None)`, as well as ``PeriodArray.strftime(fmt=None)``, ``PeriodIndex.strftime(fmt=None)`` and ``PeriodIndex.format(fmt=None)``. Finally, ``to_csv`` operations involving :class:`PeriodArray` or :class:`PeriodIndex` with default ``date_format`` are also significantly accelerated. (:issue:`51459`)
630630
- Performance improvement accessing :attr:`arrays.IntegerArrays.dtype` & :attr:`arrays.FloatingArray.dtype` (:issue:`52998`)
631+
- Performance improvement for :class:`DataFrameGroupBy`/:class:`SeriesGroupBy` aggregations (e.g. :meth:`DataFrameGroupBy.sum`) with ``engine="numba"`` (:issue:`53731`)
631632
- Performance improvement in :class:`DataFrame` reductions with ``axis=None`` and extension dtypes (:issue:`54308`)
632633
- Performance improvement in :class:`MultiIndex` and multi-column operations (e.g. :meth:`DataFrame.sort_values`, :meth:`DataFrame.groupby`, :meth:`Series.unstack`) when index/column values are already sorted (:issue:`53806`)
633634
- Performance improvement in :class:`Series` reductions (:issue:`52341`)

pandas/core/_numba/executor.py

+68-24
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,53 @@
1616

1717

1818
@functools.cache
19-
def make_looper(func, result_dtype, nopython, nogil, parallel):
19+
def make_looper(func, result_dtype, is_grouped_kernel, nopython, nogil, parallel):
2020
if TYPE_CHECKING:
2121
import numba
2222
else:
2323
numba = import_optional_dependency("numba")
2424

25-
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
26-
def column_looper(
27-
values: np.ndarray,
28-
start: np.ndarray,
29-
end: np.ndarray,
30-
min_periods: int,
31-
*args,
32-
):
33-
result = np.empty((values.shape[0], len(start)), dtype=result_dtype)
34-
na_positions = {}
35-
for i in numba.prange(values.shape[0]):
36-
output, na_pos = func(
37-
values[i], result_dtype, start, end, min_periods, *args
38-
)
39-
result[i] = output
40-
if len(na_pos) > 0:
41-
na_positions[i] = np.array(na_pos)
42-
return result, na_positions
25+
if is_grouped_kernel:
26+
27+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
28+
def column_looper(
29+
values: np.ndarray,
30+
labels: np.ndarray,
31+
ngroups: int,
32+
min_periods: int,
33+
*args,
34+
):
35+
result = np.empty((values.shape[0], ngroups), dtype=result_dtype)
36+
na_positions = {}
37+
for i in numba.prange(values.shape[0]):
38+
output, na_pos = func(
39+
values[i], result_dtype, labels, ngroups, min_periods, *args
40+
)
41+
result[i] = output
42+
if len(na_pos) > 0:
43+
na_positions[i] = np.array(na_pos)
44+
return result, na_positions
45+
46+
else:
47+
48+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
49+
def column_looper(
50+
values: np.ndarray,
51+
start: np.ndarray,
52+
end: np.ndarray,
53+
min_periods: int,
54+
*args,
55+
):
56+
result = np.empty((values.shape[0], len(start)), dtype=result_dtype)
57+
na_positions = {}
58+
for i in numba.prange(values.shape[0]):
59+
output, na_pos = func(
60+
values[i], result_dtype, start, end, min_periods, *args
61+
)
62+
result[i] = output
63+
if len(na_pos) > 0:
64+
na_positions[i] = np.array(na_pos)
65+
return result, na_positions
4366

4467
return column_looper
4568

@@ -96,6 +119,7 @@ def column_looper(
96119
def generate_shared_aggregator(
97120
func: Callable[..., Scalar],
98121
dtype_mapping: dict[np.dtype, np.dtype],
122+
is_grouped_kernel: bool,
99123
nopython: bool,
100124
nogil: bool,
101125
parallel: bool,
@@ -111,6 +135,11 @@ def generate_shared_aggregator(
111135
dtype_mapping: dict or None
112136
If not None, maps a dtype to a result dtype.
113137
Otherwise, will fall back to default mapping.
138+
is_grouped_kernel: bool, default False
139+
Whether func operates using the group labels (True)
140+
or using starts/ends arrays
141+
142+
If true, you also need to pass the number of groups to this function
114143
nopython : bool
115144
nopython to be passed into numba.jit
116145
nogil : bool
@@ -130,13 +159,28 @@ def generate_shared_aggregator(
130159
# is less than min_periods
131160
# Cannot do this in numba nopython mode
132161
# (you'll run into type-unification error when you cast int -> float)
133-
def looper_wrapper(values, start, end, min_periods, **kwargs):
162+
def looper_wrapper(
163+
values,
164+
start=None,
165+
end=None,
166+
labels=None,
167+
ngroups=None,
168+
min_periods: int = 0,
169+
**kwargs,
170+
):
134171
result_dtype = dtype_mapping[values.dtype]
135-
column_looper = make_looper(func, result_dtype, nopython, nogil, parallel)
136-
# Need to unpack kwargs since numba only supports *args
137-
result, na_positions = column_looper(
138-
values, start, end, min_periods, *kwargs.values()
172+
column_looper = make_looper(
173+
func, result_dtype, is_grouped_kernel, nopython, nogil, parallel
139174
)
175+
# Need to unpack kwargs since numba only supports *args
176+
if is_grouped_kernel:
177+
result, na_positions = column_looper(
178+
values, labels, ngroups, min_periods, *kwargs.values()
179+
)
180+
else:
181+
result, na_positions = column_looper(
182+
values, start, end, min_periods, *kwargs.values()
183+
)
140184
if result.dtype.kind == "i":
141185
# Look if na_positions is not empty
142186
# If so, convert the whole block
+26-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,27 @@
1-
from pandas.core._numba.kernels.mean_ import sliding_mean
2-
from pandas.core._numba.kernels.min_max_ import sliding_min_max
3-
from pandas.core._numba.kernels.sum_ import sliding_sum
4-
from pandas.core._numba.kernels.var_ import sliding_var
1+
from pandas.core._numba.kernels.mean_ import (
2+
grouped_mean,
3+
sliding_mean,
4+
)
5+
from pandas.core._numba.kernels.min_max_ import (
6+
grouped_min_max,
7+
sliding_min_max,
8+
)
9+
from pandas.core._numba.kernels.sum_ import (
10+
grouped_sum,
11+
sliding_sum,
12+
)
13+
from pandas.core._numba.kernels.var_ import (
14+
grouped_var,
15+
sliding_var,
16+
)
517

6-
__all__ = ["sliding_mean", "sliding_sum", "sliding_var", "sliding_min_max"]
18+
__all__ = [
19+
"sliding_mean",
20+
"grouped_mean",
21+
"sliding_sum",
22+
"grouped_sum",
23+
"sliding_var",
24+
"grouped_var",
25+
"sliding_min_max",
26+
"grouped_min_max",
27+
]

pandas/core/_numba/kernels/mean_.py

+41
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,16 @@
88
"""
99
from __future__ import annotations
1010

11+
from typing import TYPE_CHECKING
12+
1113
import numba
1214
import numpy as np
1315

1416
from pandas.core._numba.kernels.shared import is_monotonic_increasing
17+
from pandas.core._numba.kernels.sum_ import grouped_kahan_sum
18+
19+
if TYPE_CHECKING:
20+
from pandas._typing import npt
1521

1622

1723
@numba.jit(nopython=True, nogil=True, parallel=False)
@@ -153,3 +159,38 @@ def sliding_mean(
153159
# empty list of ints on its own
154160
na_pos = [0 for i in range(0)]
155161
return output, na_pos
162+
163+
164+
@numba.jit(nopython=True, nogil=True, parallel=False)
165+
def grouped_mean(
166+
values: np.ndarray,
167+
result_dtype: np.dtype,
168+
labels: npt.NDArray[np.intp],
169+
ngroups: int,
170+
min_periods: int,
171+
) -> tuple[np.ndarray, list[int]]:
172+
output, nobs_arr, comp_arr, consecutive_counts, prev_vals = grouped_kahan_sum(
173+
values, result_dtype, labels, ngroups
174+
)
175+
176+
# Post-processing, replace sums that don't satisfy min_periods
177+
for lab in range(ngroups):
178+
nobs = nobs_arr[lab]
179+
num_consecutive_same_value = consecutive_counts[lab]
180+
prev_value = prev_vals[lab]
181+
sum_x = output[lab]
182+
if nobs >= min_periods:
183+
if num_consecutive_same_value >= nobs:
184+
result = prev_value * nobs
185+
else:
186+
result = sum_x
187+
else:
188+
result = np.nan
189+
result /= nobs
190+
output[lab] = result
191+
192+
# na_position is empty list since float64 can already hold nans
193+
# Do list comprehension, since numba cannot figure out that na_pos is
194+
# empty list of ints on its own
195+
na_pos = [0 for i in range(0)]
196+
return output, na_pos

pandas/core/_numba/kernels/min_max_.py

+51
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,14 @@
88
"""
99
from __future__ import annotations
1010

11+
from typing import TYPE_CHECKING
12+
1113
import numba
1214
import numpy as np
1315

16+
if TYPE_CHECKING:
17+
from pandas._typing import npt
18+
1419

1520
@numba.jit(nopython=True, nogil=True, parallel=False)
1621
def sliding_min_max(
@@ -72,3 +77,49 @@ def sliding_min_max(
7277
na_pos.append(i)
7378

7479
return output, na_pos
80+
81+
82+
@numba.jit(nopython=True, nogil=True, parallel=False)
83+
def grouped_min_max(
84+
values: np.ndarray,
85+
result_dtype: np.dtype,
86+
labels: npt.NDArray[np.intp],
87+
ngroups: int,
88+
min_periods: int,
89+
is_max: bool,
90+
) -> tuple[np.ndarray, list[int]]:
91+
N = len(labels)
92+
nobs = np.zeros(ngroups, dtype=np.int64)
93+
na_pos = []
94+
output = np.empty(ngroups, dtype=result_dtype)
95+
96+
for i in range(N):
97+
lab = labels[i]
98+
val = values[i]
99+
if lab < 0:
100+
continue
101+
102+
if values.dtype.kind == "i" or not np.isnan(val):
103+
nobs[lab] += 1
104+
else:
105+
# NaN value cannot be a min/max value
106+
continue
107+
108+
if nobs[lab] == 1:
109+
# First element in group, set output equal to this
110+
output[lab] = val
111+
continue
112+
113+
if is_max:
114+
if val > output[lab]:
115+
output[lab] = val
116+
else:
117+
if val < output[lab]:
118+
output[lab] = val
119+
120+
# Set labels that don't satisfy min_periods as np.nan
121+
for lab, count in enumerate(nobs):
122+
if count < min_periods:
123+
na_pos.append(lab)
124+
125+
return output, na_pos

0 commit comments

Comments
 (0)