Skip to content

Commit f29ea3e

Browse files
committed
ENH: Add separate numba kernels for groupby aggregations
1 parent 870a504 commit f29ea3e

File tree

8 files changed

+400
-55
lines changed

8 files changed

+400
-55
lines changed

asv_bench/benchmarks/groupby.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -591,12 +591,8 @@ class GroupByCythonAgg:
591591
[
592592
"sum",
593593
"prod",
594-
# TODO: uncomment min/max
595-
# Currently, min/max implemented very inefficiently
596-
# because it re-uses the Window min/max kernel
597-
# so it will time out ASVs
598-
# "min",
599-
# "max",
594+
"min",
595+
"max",
600596
"mean",
601597
"median",
602598
"var",

pandas/core/_numba/executor.py

+55-24
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,50 @@
1616

1717

1818
@functools.cache
19-
def make_looper(func, result_dtype, nopython, nogil, parallel):
19+
def make_looper(func, result_dtype, is_grouped_kernel, nopython, nogil, parallel):
2020
if TYPE_CHECKING:
2121
import numba
2222
else:
2323
numba = import_optional_dependency("numba")
2424

25-
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
26-
def column_looper(
27-
values: np.ndarray,
28-
start: np.ndarray,
29-
end: np.ndarray,
30-
min_periods: int,
31-
*args,
32-
):
33-
result = np.empty((values.shape[0], len(start)), dtype=result_dtype)
34-
na_positions = {}
35-
for i in numba.prange(values.shape[0]):
36-
output, na_pos = func(
37-
values[i], result_dtype, start, end, min_periods, *args
38-
)
39-
result[i] = output
40-
if len(na_pos) > 0:
41-
na_positions[i] = np.array(na_pos)
42-
return result, na_positions
25+
if is_grouped_kernel:
26+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
27+
def column_looper(
28+
values: np.ndarray,
29+
labels: np.ndarray,
30+
ngroups: int,
31+
min_periods: int,
32+
*args,
33+
):
34+
result = np.empty((values.shape[0], ngroups), dtype=result_dtype)
35+
na_positions = {}
36+
for i in numba.prange(values.shape[0]):
37+
output, na_pos = func(
38+
values[i], result_dtype, labels, ngroups, min_periods, *args
39+
)
40+
result[i] = output
41+
if len(na_pos) > 0:
42+
na_positions[i] = np.array(na_pos)
43+
return result, na_positions
44+
else:
45+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
46+
def column_looper(
47+
values: np.ndarray,
48+
start: np.ndarray,
49+
end: np.ndarray,
50+
min_periods: int,
51+
*args,
52+
):
53+
result = np.empty((values.shape[0], len(start)), dtype=result_dtype)
54+
na_positions = {}
55+
for i in numba.prange(values.shape[0]):
56+
output, na_pos = func(
57+
values[i], result_dtype, start, end, min_periods, *args
58+
)
59+
result[i] = output
60+
if len(na_pos) > 0:
61+
na_positions[i] = np.array(na_pos)
62+
return result, na_positions
4363

4464
return column_looper
4565

@@ -96,6 +116,7 @@ def column_looper(
96116
def generate_shared_aggregator(
97117
func: Callable[..., Scalar],
98118
dtype_mapping: dict[np.dtype, np.dtype],
119+
is_grouped_kernel: bool,
99120
nopython: bool,
100121
nogil: bool,
101122
parallel: bool,
@@ -111,6 +132,11 @@ def generate_shared_aggregator(
111132
dtype_mapping: dict or None
112133
If not None, maps a dtype to a result dtype.
113134
Otherwise, will fall back to default mapping.
135+
is_grouped_kernel: bool, default False
136+
Whether func operates using the group labels (True)
137+
or using starts/ends arrays
138+
139+
If true, you also need to pass the number of groups to this function
114140
nopython : bool
115141
nopython to be passed into numba.jit
116142
nogil : bool
@@ -130,13 +156,18 @@ def generate_shared_aggregator(
130156
# is less than min_periods
131157
# Cannot do this in numba nopython mode
132158
# (you'll run into type-unification error when you cast int -> float)
133-
def looper_wrapper(values, start, end, min_periods, **kwargs):
159+
def looper_wrapper(values, start=None, end=None, labels=None, ngroups=None, min_periods=0, **kwargs):
134160
result_dtype = dtype_mapping[values.dtype]
135-
column_looper = make_looper(func, result_dtype, nopython, nogil, parallel)
161+
column_looper = make_looper(func, result_dtype, is_grouped_kernel, nopython, nogil, parallel)
136162
# Need to unpack kwargs since numba only supports *args
137-
result, na_positions = column_looper(
138-
values, start, end, min_periods, *kwargs.values()
139-
)
163+
if is_grouped_kernel:
164+
result, na_positions = column_looper(
165+
values, labels, ngroups, min_periods, *kwargs.values()
166+
)
167+
else:
168+
result, na_positions = column_looper(
169+
values, start, end, min_periods, *kwargs.values()
170+
)
140171
if result.dtype.kind == "i":
141172
# Look if na_positions is not empty
142173
# If so, convert the whole block
+26-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,27 @@
1-
from pandas.core._numba.kernels.mean_ import sliding_mean
2-
from pandas.core._numba.kernels.min_max_ import sliding_min_max
3-
from pandas.core._numba.kernels.sum_ import sliding_sum
4-
from pandas.core._numba.kernels.var_ import sliding_var
1+
from pandas.core._numba.kernels.mean_ import (
2+
grouped_mean,
3+
sliding_mean,
4+
)
5+
from pandas.core._numba.kernels.min_max_ import (
6+
grouped_min_max,
7+
sliding_min_max,
8+
)
9+
from pandas.core._numba.kernels.sum_ import (
10+
grouped_sum,
11+
sliding_sum,
12+
)
13+
from pandas.core._numba.kernels.var_ import (
14+
grouped_var,
15+
sliding_var,
16+
)
517

6-
__all__ = ["sliding_mean", "sliding_sum", "sliding_var", "sliding_min_max"]
18+
__all__ = [
19+
"sliding_mean",
20+
"grouped_mean",
21+
"sliding_sum",
22+
"grouped_sum",
23+
"sliding_var",
24+
"grouped_var",
25+
"sliding_min_max",
26+
"grouped_min_max",
27+
]

pandas/core/_numba/kernels/mean_.py

+76
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import numpy as np
1313

1414
from pandas.core._numba.kernels.shared import is_monotonic_increasing
15+
from pandas.core._numba.kernels.sum_ import add_sum
1516

1617

1718
@numba.jit(nopython=True, nogil=True, parallel=False)
@@ -153,3 +154,78 @@ def sliding_mean(
153154
# empty list of ints on its own
154155
na_pos = [0 for i in range(0)]
155156
return output, na_pos
157+
158+
159+
@numba.jit(nopython=True, nogil=True, parallel=False)
160+
def grouped_mean(
161+
values: np.ndarray,
162+
result_dtype: np.dtype,
163+
labels: np.ndarray,
164+
ngroups: int,
165+
min_periods: int,
166+
) -> tuple[np.ndarray, list[int]]:
167+
N = len(labels)
168+
169+
nobs_arr = np.zeros(ngroups, dtype=np.int64)
170+
comp_arr = np.zeros(ngroups, dtype=values.dtype)
171+
consecutive_counts = np.zeros(ngroups, dtype=np.int64)
172+
prev_vals = np.zeros(ngroups, dtype=values.dtype)
173+
output = np.zeros(ngroups, dtype=result_dtype)
174+
175+
for i in range(N):
176+
lab = labels[i]
177+
val = values[i]
178+
179+
if lab < 0:
180+
continue
181+
182+
sum_x = output[lab]
183+
nobs = nobs_arr[lab]
184+
compensation_add = comp_arr[lab]
185+
num_consecutive_same_value = consecutive_counts[lab]
186+
prev_value = prev_vals[lab]
187+
188+
(
189+
nobs,
190+
sum_x,
191+
compensation_add,
192+
num_consecutive_same_value,
193+
prev_value,
194+
) = add_sum(
195+
val,
196+
nobs,
197+
sum_x,
198+
compensation_add,
199+
num_consecutive_same_value,
200+
prev_value,
201+
)
202+
203+
output[lab] = sum_x
204+
consecutive_counts[lab] = num_consecutive_same_value
205+
prev_vals[lab] = prev_value
206+
comp_arr[lab] = compensation_add
207+
nobs_arr[lab] = nobs
208+
209+
# Post-processing, replace sums that don't satisfy min_periods
210+
for lab in range(ngroups):
211+
nobs = nobs_arr[lab]
212+
num_consecutive_same_value = consecutive_counts[lab]
213+
prev_value = prev_vals[lab]
214+
sum_x = output[lab]
215+
if nobs == 0 == min_periods:
216+
result = 0.0
217+
elif nobs >= min_periods:
218+
if num_consecutive_same_value >= nobs:
219+
result = prev_value * nobs
220+
else:
221+
result = sum_x
222+
else:
223+
result = np.nan
224+
result /= nobs
225+
output[lab] = result
226+
227+
# na_position is empty list since float64 can already hold nans
228+
# Do list comprehension, since numba cannot figure out that na_pos is
229+
# empty list of ints on its own
230+
na_pos = [0 for i in range(0)]
231+
return output, na_pos

pandas/core/_numba/kernels/min_max_.py

+46
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,49 @@ def sliding_min_max(
7272
na_pos.append(i)
7373

7474
return output, na_pos
75+
76+
77+
@numba.jit(nopython=True, nogil=True, parallel=False)
78+
def grouped_min_max(
79+
values: np.ndarray,
80+
result_dtype: np.dtype,
81+
labels: np.ndarray,
82+
ngroups: int,
83+
min_periods: int,
84+
is_max: bool,
85+
) -> np.ndarray:
86+
N = len(labels)
87+
nobs = np.empty(ngroups, dtype=np.int64)
88+
na_pos = []
89+
output = np.empty(ngroups, dtype=result_dtype)
90+
91+
for i in range(N):
92+
lab = labels[i]
93+
val = values[i]
94+
if lab < 0:
95+
continue
96+
97+
if values.dtype.kind == "i" or not np.isnan(val):
98+
nobs[lab] += 1
99+
else:
100+
# NaN value cannot be a min/max value
101+
continue
102+
103+
if nobs[lab] == 1:
104+
# First element in group, set output equal to this
105+
output[lab] = val
106+
continue
107+
108+
if is_max:
109+
if val > output[lab]:
110+
output[lab] = val
111+
else:
112+
if val < output[lab]:
113+
output[lab] = val
114+
115+
# Set labels that don't satisfy min_periods as np.nan
116+
for lab, count in enumerate(nobs):
117+
if count < min_periods:
118+
na_pos.append(lab)
119+
120+
return output, na_pos

pandas/core/_numba/kernels/sum_.py

+72
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,75 @@ def sliding_sum(
148148
compensation_remove = 0
149149

150150
return output, na_pos
151+
152+
153+
@numba.jit(nopython=True, nogil=True, parallel=False)
154+
def grouped_sum(
155+
values: np.ndarray,
156+
result_dtype: np.dtype,
157+
labels: np.ndarray,
158+
ngroups: int,
159+
min_periods: int,
160+
) -> np.ndarray:
161+
N = len(labels)
162+
na_pos = []
163+
164+
nobs_arr = np.zeros(ngroups, dtype=np.int64)
165+
comp_arr = np.zeros(ngroups, dtype=values.dtype)
166+
consecutive_counts = np.zeros(ngroups, dtype=np.int64)
167+
prev_vals = np.zeros(ngroups, dtype=values.dtype)
168+
output = np.zeros(ngroups, dtype=result_dtype)
169+
170+
for i in range(N):
171+
lab = labels[i]
172+
val = values[i]
173+
174+
if lab < 0:
175+
continue
176+
177+
sum_x = output[lab]
178+
nobs = nobs_arr[lab]
179+
compensation_add = comp_arr[lab]
180+
num_consecutive_same_value = consecutive_counts[lab]
181+
prev_value = prev_vals[lab]
182+
183+
(
184+
nobs,
185+
sum_x,
186+
compensation_add,
187+
num_consecutive_same_value,
188+
prev_value,
189+
) = add_sum(
190+
val,
191+
nobs,
192+
sum_x,
193+
compensation_add,
194+
num_consecutive_same_value,
195+
prev_value,
196+
)
197+
198+
output[lab] = sum_x
199+
consecutive_counts[lab] = num_consecutive_same_value
200+
prev_vals[lab] = prev_value
201+
comp_arr[lab] = compensation_add
202+
nobs_arr[lab] = nobs
203+
204+
# Post-processing, replace sums that don't satisfy min_periods
205+
for lab in range(ngroups):
206+
nobs = nobs_arr[lab]
207+
num_consecutive_same_value = consecutive_counts[lab]
208+
prev_value = prev_vals[lab]
209+
sum_x = output[lab]
210+
if nobs == 0 == min_periods:
211+
result = 0.0
212+
elif nobs >= min_periods:
213+
if num_consecutive_same_value >= nobs:
214+
result = prev_value * nobs
215+
else:
216+
result = sum_x
217+
else:
218+
result = sum_x # Don't change val, will be replaced by nan later
219+
na_pos.append(lab)
220+
output[lab] = result
221+
222+
return output, na_pos

0 commit comments

Comments
 (0)