-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: Allow numba aggregations to return non-float64 results #53444
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bcd93e0
e22d783
5be4d9e
9f2f70d
6f12756
00ce652
64ecaec
4d58a47
405a71c
c6d4ffe
8f076e7
d05ebdf
5b4f7fc
e67bbeb
6f103ab
6d75ce4
b0d22db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
import functools | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
Callable, | ||
) | ||
|
||
|
@@ -15,8 +16,86 @@ | |
|
||
|
||
@functools.cache | ||
def make_looper(func, result_dtype, nopython, nogil, parallel): | ||
if TYPE_CHECKING: | ||
import numba | ||
mroeschke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
numba = import_optional_dependency("numba") | ||
|
||
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) | ||
def column_looper( | ||
values: np.ndarray, | ||
start: np.ndarray, | ||
end: np.ndarray, | ||
min_periods: int, | ||
*args, | ||
): | ||
result = np.empty((values.shape[0], len(start)), dtype=result_dtype) | ||
na_positions = {} | ||
for i in numba.prange(values.shape[0]): | ||
output, na_pos = func( | ||
values[i], result_dtype, start, end, min_periods, *args | ||
) | ||
result[i] = output | ||
if len(na_pos) > 0: | ||
na_positions[i] = np.array(na_pos) | ||
return result, na_positions | ||
|
||
return column_looper | ||
|
||
|
||
default_dtype_mapping: dict[np.dtype, Any] = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Curious, could we not just define signatures for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We allocate arrays inside the function and need to pass a dtype there as well. Not sure how to access the signature from inside the func. |
||
np.dtype("int8"): np.int64, | ||
np.dtype("int16"): np.int64, | ||
np.dtype("int32"): np.int64, | ||
np.dtype("int64"): np.int64, | ||
np.dtype("uint8"): np.uint64, | ||
np.dtype("uint16"): np.uint64, | ||
np.dtype("uint32"): np.uint64, | ||
np.dtype("uint64"): np.uint64, | ||
np.dtype("float32"): np.float64, | ||
np.dtype("float64"): np.float64, | ||
np.dtype("complex64"): np.complex128, | ||
np.dtype("complex128"): np.complex128, | ||
} | ||
|
||
|
||
# TODO: Preserve complex dtypes | ||
|
||
float_dtype_mapping: dict[np.dtype, Any] = { | ||
np.dtype("int8"): np.float64, | ||
np.dtype("int16"): np.float64, | ||
np.dtype("int32"): np.float64, | ||
np.dtype("int64"): np.float64, | ||
np.dtype("uint8"): np.float64, | ||
np.dtype("uint16"): np.float64, | ||
np.dtype("uint32"): np.float64, | ||
np.dtype("uint64"): np.float64, | ||
np.dtype("float32"): np.float64, | ||
np.dtype("float64"): np.float64, | ||
np.dtype("complex64"): np.float64, | ||
np.dtype("complex128"): np.float64, | ||
} | ||
|
||
identity_dtype_mapping: dict[np.dtype, Any] = { | ||
np.dtype("int8"): np.int8, | ||
np.dtype("int16"): np.int16, | ||
np.dtype("int32"): np.int32, | ||
np.dtype("int64"): np.int64, | ||
np.dtype("uint8"): np.uint8, | ||
np.dtype("uint16"): np.uint16, | ||
np.dtype("uint32"): np.uint32, | ||
np.dtype("uint64"): np.uint64, | ||
np.dtype("float32"): np.float32, | ||
np.dtype("float64"): np.float64, | ||
np.dtype("complex64"): np.complex64, | ||
np.dtype("complex128"): np.complex128, | ||
} | ||
|
||
|
||
def generate_shared_aggregator( | ||
func: Callable[..., Scalar], | ||
dtype_mapping: dict[np.dtype, np.dtype], | ||
nopython: bool, | ||
nogil: bool, | ||
parallel: bool, | ||
|
@@ -29,6 +108,9 @@ def generate_shared_aggregator( | |
---------- | ||
func : function | ||
aggregation function to be applied to each column | ||
dtype_mapping: dict or None | ||
If not None, maps a dtype to a result dtype. | ||
Otherwise, will fall back to default mapping. | ||
nopython : bool | ||
nopython to be passed into numba.jit | ||
nogil : bool | ||
|
@@ -40,22 +122,35 @@ def generate_shared_aggregator( | |
------- | ||
Numba function | ||
""" | ||
if TYPE_CHECKING: | ||
import numba | ||
else: | ||
numba = import_optional_dependency("numba") | ||
|
||
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) | ||
def column_looper( | ||
values: np.ndarray, | ||
start: np.ndarray, | ||
end: np.ndarray, | ||
min_periods: int, | ||
*args, | ||
): | ||
result = np.empty((len(start), values.shape[1]), dtype=np.float64) | ||
for i in numba.prange(values.shape[1]): | ||
result[:, i] = func(values[:, i], start, end, min_periods, *args) | ||
# A wrapper around the looper function, | ||
# to dispatch based on dtype since numba is unable to do that in nopython mode | ||
|
||
# It also post-processes the values by inserting nans where number of observations | ||
# is less than min_periods | ||
# Cannot do this in numba nopython mode | ||
# (you'll run into type-unification error when you cast int -> float) | ||
def looper_wrapper(values, start, end, min_periods, **kwargs): | ||
result_dtype = dtype_mapping[values.dtype] | ||
column_looper = make_looper(func, result_dtype, nopython, nogil, parallel) | ||
# Need to unpack kwargs since numba only supports *args | ||
result, na_positions = column_looper( | ||
values, start, end, min_periods, *kwargs.values() | ||
) | ||
if result.dtype.kind == "i": | ||
# Look if na_positions is not empty | ||
# If so, convert the whole block | ||
# This is OK since int dtype cannot hold nan, | ||
# so if min_periods not satisfied for 1 col, it is not satisfied for | ||
# all columns at that index | ||
for na_pos in na_positions.values(): | ||
if len(na_pos) > 0: | ||
result = result.astype("float64") | ||
break | ||
# TODO: Optimize this | ||
for i, na_pos in na_positions.items(): | ||
if len(na_pos) > 0: | ||
result[i, na_pos] = np.nan | ||
rhshadrach marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return result | ||
|
||
return column_looper | ||
return looper_wrapper |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Disabled min/max because it's reaaaallllly sloooooow.
It takes 20s (as opposed to milliseconds for the other kernels) to run, and can time out the ASVs sometimes(causing flakiness).
Best guess is that the list operations are slowing it down. Snakeviz tells me most (99% of the time) is spent in the numba kernel, and I can't profile into there.
I'm planning on splitting groupby stuff from the Window numba kernels in the future, so hopefully this doesn't stay commented for long.