Skip to content

Commit ff2165c

Browse files
lithomas1 authored and im-vinicius committed
ENH: Allow numba aggregations to return non-float64 results (pandas-dev#53444)
* ENH: non-float64 result support in numba groupby
* refactor & simplify
* fix CI
* maybe green?
* skip unsupported ops in other bench as well
* updates from code review
* remove commented code
* update whatsnew
* debug benchmarks
* Skip min/max benchmarks
1 parent d9d981d commit ff2165c

File tree

13 files changed

+342
-88
lines changed

13 files changed

+342
-88
lines changed

asv_bench/benchmarks/groupby.py

+88-9
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,38 @@
5757
},
5858
}
5959

60+
# These aggregations don't have a kernel implemented for them yet
61+
_numba_unsupported_methods = [
62+
"all",
63+
"any",
64+
"bfill",
65+
"count",
66+
"cumcount",
67+
"cummax",
68+
"cummin",
69+
"cumprod",
70+
"cumsum",
71+
"describe",
72+
"diff",
73+
"ffill",
74+
"first",
75+
"head",
76+
"last",
77+
"median",
78+
"nunique",
79+
"pct_change",
80+
"prod",
81+
"quantile",
82+
"rank",
83+
"sem",
84+
"shift",
85+
"size",
86+
"skew",
87+
"tail",
88+
"unique",
89+
"value_counts",
90+
]
91+
6092

6193
class ApplyDictReturn:
6294
def setup(self):
@@ -453,9 +485,10 @@ class GroupByMethods:
453485
],
454486
["direct", "transformation"],
455487
[1, 5],
488+
["cython", "numba"],
456489
]
457490

458-
def setup(self, dtype, method, application, ncols):
491+
def setup(self, dtype, method, application, ncols, engine):
459492
if method in method_blocklist.get(dtype, {}):
460493
raise NotImplementedError # skip benchmark
461494

@@ -474,6 +507,19 @@ def setup(self, dtype, method, application, ncols):
474507
# DataFrameGroupBy doesn't have these methods
475508
raise NotImplementedError
476509

510+
# Numba currently doesn't support
511+
# multiple transform functions or strs for transform,
512+
# grouping on multiple columns
513+
# and we lack kernels for a bunch of methods
514+
if (
515+
engine == "numba"
516+
and method in _numba_unsupported_methods
517+
or ncols > 1
518+
or application == "transformation"
519+
or dtype == "datetime"
520+
):
521+
raise NotImplementedError
522+
477523
if method == "describe":
478524
ngroups = 20
479525
elif method == "skew":
@@ -505,17 +551,30 @@ def setup(self, dtype, method, application, ncols):
505551
if len(cols) == 1:
506552
cols = cols[0]
507553

554+
# Not everything supports the engine keyword yet
555+
kwargs = {}
556+
if engine == "numba":
557+
kwargs["engine"] = engine
558+
508559
if application == "transformation":
509-
self.as_group_method = lambda: df.groupby("key")[cols].transform(method)
510-
self.as_field_method = lambda: df.groupby(cols)["key"].transform(method)
560+
self.as_group_method = lambda: df.groupby("key")[cols].transform(
561+
method, **kwargs
562+
)
563+
self.as_field_method = lambda: df.groupby(cols)["key"].transform(
564+
method, **kwargs
565+
)
511566
else:
512-
self.as_group_method = getattr(df.groupby("key")[cols], method)
513-
self.as_field_method = getattr(df.groupby(cols)["key"], method)
567+
self.as_group_method = partial(
568+
getattr(df.groupby("key")[cols], method), **kwargs
569+
)
570+
self.as_field_method = partial(
571+
getattr(df.groupby(cols)["key"], method), **kwargs
572+
)
514573

515-
def time_dtype_as_group(self, dtype, method, application, ncols):
574+
def time_dtype_as_group(self, dtype, method, application, ncols, engine):
516575
self.as_group_method()
517576

518-
def time_dtype_as_field(self, dtype, method, application, ncols):
577+
def time_dtype_as_field(self, dtype, method, application, ncols, engine):
519578
self.as_field_method()
520579

521580

@@ -532,8 +591,12 @@ class GroupByCythonAgg:
532591
[
533592
"sum",
534593
"prod",
535-
"min",
536-
"max",
594+
# TODO: uncomment min/max
595+
# Currently, min/max implemented very inefficiently
596+
# because it re-uses the Window min/max kernel
597+
# so it will time out ASVs
598+
# "min",
599+
# "max",
537600
"mean",
538601
"median",
539602
"var",
@@ -554,6 +617,22 @@ def time_frame_agg(self, dtype, method):
554617
self.df.groupby("key").agg(method)
555618

556619

620+
class GroupByNumbaAgg(GroupByCythonAgg):
    """
    Benchmarks targeting the numba grouped-aggregation paths.

    Reuses the GroupByCythonAgg setup (large frame, simple key) so that
    most of the measured time is spent inside the grouped aggregation
    itself rather than in fixture construction.
    """

    def setup(self, dtype, method):
        # Skip methods for which no numba kernel exists yet.
        if method not in _numba_unsupported_methods:
            super().setup(dtype, method)
        else:
            raise NotImplementedError

    def time_frame_agg(self, dtype, method):
        self.df.groupby("key").agg(method, engine="numba")
635+
557636
class GroupByCythonAggEaDtypes:
558637
"""
559638
Benchmarks specifically targeting our cython aggregation algorithms

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ Other enhancements
108108
- :meth:`SeriesGroupby.transform` and :meth:`DataFrameGroupby.transform` now support passing in a string as the function for ``engine="numba"`` (:issue:`53579`)
109109
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
110110
- Added a new parameter ``by_row`` to :meth:`Series.apply`. When set to ``False`` the supplied callables will always operate on the whole Series (:issue:`53400`).
111+
- Groupby aggregations (such as :meth:`.DataFrameGroupBy.sum`) now can preserve the dtype of the input instead of casting to ``float64`` (:issue:`44952`)
111112
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)
112113
- Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
113114
- Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)

pandas/core/_numba/executor.py

+111-16
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import functools
44
from typing import (
55
TYPE_CHECKING,
6+
Any,
67
Callable,
78
)
89

@@ -15,8 +16,86 @@
1516

1617

1718
@functools.cache
def make_looper(func, result_dtype, nopython, nogil, parallel):
    """
    Build (and cache) a jitted looper applying ``func`` to each row.

    The returned ``column_looper`` runs ``func`` over every row of a 2-D
    ``values`` array, writing each row's output into a single result
    array of ``result_dtype`` and collecting, per row, the positions
    where the kernel reported missing values.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def column_looper(
        values: np.ndarray,
        start: np.ndarray,
        end: np.ndarray,
        min_periods: int,
        *args,
    ):
        n_rows = values.shape[0]
        result = np.empty((n_rows, len(start)), dtype=result_dtype)
        na_positions = {}
        for row in numba.prange(n_rows):
            output, na_pos = func(
                values[row], result_dtype, start, end, min_periods, *args
            )
            result[row] = output
            # Only record rows that actually produced NA slots.
            if len(na_pos) > 0:
                na_positions[row] = np.array(na_pos)
        return result, na_positions

    return column_looper
45+
46+
47+
# Map an input dtype to the dtype numba kernels accumulate into:
# signed ints widen to int64, unsigned to uint64, floats to float64,
# complex to complex128.
default_dtype_mapping: dict[np.dtype, Any] = {
    **{np.dtype(f"int{bits}"): np.int64 for bits in (8, 16, 32, 64)},
    **{np.dtype(f"uint{bits}"): np.uint64 for bits in (8, 16, 32, 64)},
    **{np.dtype(f"float{bits}"): np.float64 for bits in (32, 64)},
    **{np.dtype(f"complex{bits}"): np.complex128 for bits in (64, 128)},
}


# TODO: Preserve complex dtypes

# Every supported numeric input dtype is computed in float64.
float_dtype_mapping: dict[np.dtype, Any] = {
    dtype: np.float64 for dtype in default_dtype_mapping
}

# Result dtype matches the input dtype exactly.
identity_dtype_mapping: dict[np.dtype, Any] = {
    dtype: dtype.type for dtype in default_dtype_mapping
}
94+
95+
1896
def generate_shared_aggregator(
1997
func: Callable[..., Scalar],
98+
dtype_mapping: dict[np.dtype, np.dtype],
2099
nopython: bool,
21100
nogil: bool,
22101
parallel: bool,
@@ -29,6 +108,9 @@ def generate_shared_aggregator(
29108
----------
30109
func : function
31110
aggregation function to be applied to each column
111+
dtype_mapping: dict or None
112+
If not None, maps a dtype to a result dtype.
113+
Otherwise, will fall back to default mapping.
32114
nopython : bool
33115
nopython to be passed into numba.jit
34116
nogil : bool
@@ -40,22 +122,35 @@ def generate_shared_aggregator(
40122
-------
41123
Numba function
42124
"""
43-
if TYPE_CHECKING:
44-
import numba
45-
else:
46-
numba = import_optional_dependency("numba")
47125

48-
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
49-
def column_looper(
50-
values: np.ndarray,
51-
start: np.ndarray,
52-
end: np.ndarray,
53-
min_periods: int,
54-
*args,
55-
):
56-
result = np.empty((len(start), values.shape[1]), dtype=np.float64)
57-
for i in numba.prange(values.shape[1]):
58-
result[:, i] = func(values[:, i], start, end, min_periods, *args)
126+
# A wrapper around the looper function,
127+
# to dispatch based on dtype since numba is unable to do that in nopython mode
128+
129+
# It also post-processes the values by inserting nans where number of observations
130+
# is less than min_periods
131+
# Cannot do this in numba nopython mode
132+
# (you'll run into type-unification error when you cast int -> float)
133+
def looper_wrapper(values, start, end, min_periods, **kwargs):
134+
result_dtype = dtype_mapping[values.dtype]
135+
column_looper = make_looper(func, result_dtype, nopython, nogil, parallel)
136+
# Need to unpack kwargs since numba only supports *args
137+
result, na_positions = column_looper(
138+
values, start, end, min_periods, *kwargs.values()
139+
)
140+
if result.dtype.kind == "i":
141+
# Look if na_positions is not empty
142+
# If so, convert the whole block
143+
# This is OK since int dtype cannot hold nan,
144+
# so if min_periods not satisfied for 1 col, it is not satisfied for
145+
# all columns at that index
146+
for na_pos in na_positions.values():
147+
if len(na_pos) > 0:
148+
result = result.astype("float64")
149+
break
150+
# TODO: Optimize this
151+
for i, na_pos in na_positions.items():
152+
if len(na_pos) > 0:
153+
result[i, na_pos] = np.nan
59154
return result
60155

61-
return column_looper
156+
return looper_wrapper

pandas/core/_numba/kernels/mean_.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,11 @@ def remove_mean(
6060
@numba.jit(nopython=True, nogil=True, parallel=False)
6161
def sliding_mean(
6262
values: np.ndarray,
63+
result_dtype: np.dtype,
6364
start: np.ndarray,
6465
end: np.ndarray,
6566
min_periods: int,
66-
) -> np.ndarray:
67+
) -> tuple[np.ndarray, list[int]]:
6768
N = len(start)
6869
nobs = 0
6970
sum_x = 0.0
@@ -75,7 +76,7 @@ def sliding_mean(
7576
start
7677
) and is_monotonic_increasing(end)
7778

78-
output = np.empty(N, dtype=np.float64)
79+
output = np.empty(N, dtype=result_dtype)
7980

8081
for i in range(N):
8182
s = start[i]
@@ -147,4 +148,8 @@ def sliding_mean(
147148
neg_ct = 0
148149
compensation_remove = 0.0
149150

150-
return output
151+
# na_position is empty list since float64 can already hold nans
152+
# Do list comprehension, since numba cannot figure out that na_pos is
153+
# empty list of ints on its own
154+
na_pos = [0 for i in range(0)]
155+
return output, na_pos

pandas/core/_numba/kernels/min_max_.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515
@numba.jit(nopython=True, nogil=True, parallel=False)
1616
def sliding_min_max(
1717
values: np.ndarray,
18+
result_dtype: np.dtype,
1819
start: np.ndarray,
1920
end: np.ndarray,
2021
min_periods: int,
2122
is_max: bool,
22-
) -> np.ndarray:
23+
) -> tuple[np.ndarray, list[int]]:
2324
N = len(start)
2425
nobs = 0
25-
output = np.empty(N, dtype=np.float64)
26+
output = np.empty(N, dtype=result_dtype)
27+
na_pos = []
2628
# Use deque once numba supports it
2729
# https://github.com/numba/numba/issues/7417
2830
Q: list = []
@@ -64,6 +66,9 @@ def sliding_min_max(
6466
if Q and curr_win_size > 0 and nobs >= min_periods:
6567
output[i] = values[Q[0]]
6668
else:
67-
output[i] = np.nan
69+
if values.dtype.kind != "i":
70+
output[i] = np.nan
71+
else:
72+
na_pos.append(i)
6873

69-
return output
74+
return output, na_pos

0 commit comments

Comments
 (0)