PERF: Allow groupby transform with numba engine to be fully parallelizable #36240

Merged · 13 commits · Sep 13, 2020
46 changes: 32 additions & 14 deletions asv_bench/benchmarks/groupby.py
@@ -627,49 +627,63 @@ def time_first(self):
 
 
 class TransformEngine:
-    def setup(self):
+
+    param_names = ["parallel"]
+    params = [[True, False]]
+
+    def setup(self, parallel):
         N = 10 ** 3
         data = DataFrame(
             {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N},
             columns=[0, 1],
         )
+        self.parallel = parallel
         self.grouper = data.groupby(0)
 
-    def time_series_numba(self):
+    def time_series_numba(self, parallel):
         def function(values, index):
             return values * 5
 
-        self.grouper[1].transform(function, engine="numba")
+        self.grouper[1].transform(
+            function, engine="numba", engine_kwargs={"parallel": self.parallel}
+        )
 
-    def time_series_cython(self):
+    def time_series_cython(self, parallel):
         def function(values):
             return values * 5
 
         self.grouper[1].transform(function, engine="cython")
 
-    def time_dataframe_numba(self):
+    def time_dataframe_numba(self, parallel):
         def function(values, index):
             return values * 5
 
-        self.grouper.transform(function, engine="numba")
+        self.grouper.transform(
+            function, engine="numba", engine_kwargs={"parallel": self.parallel}
+        )
 
-    def time_dataframe_cython(self):
+    def time_dataframe_cython(self, parallel):
         def function(values):
             return values * 5
 
         self.grouper.transform(function, engine="cython")
 
 
 class AggEngine:
-    def setup(self):
+
+    param_names = ["parallel"]
+    params = [[True, False]]
+
+    def setup(self, parallel):
         N = 10 ** 3
         data = DataFrame(
             {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N},
             columns=[0, 1],
         )
+        self.parallel = parallel
         self.grouper = data.groupby(0)
 
-    def time_series_numba(self):
+    def time_series_numba(self, parallel):
         def function(values, index):
             total = 0
             for i, value in enumerate(values):
@@ -679,9 +693,11 @@ def function(values, index):
                     total += value * 2
             return total
 
-        self.grouper[1].agg(function, engine="numba")
+        self.grouper[1].agg(
+            function, engine="numba", engine_kwargs={"parallel": self.parallel}
+        )
 
-    def time_series_cython(self):
+    def time_series_cython(self, parallel):
         def function(values):
             total = 0
             for i, value in enumerate(values):
@@ -693,7 +709,7 @@ def function(values):
 
         self.grouper[1].agg(function, engine="cython")
 
-    def time_dataframe_numba(self):
+    def time_dataframe_numba(self, parallel):
         def function(values, index):
             total = 0
             for i, value in enumerate(values):
@@ -703,9 +719,11 @@ def function(values, index):
                     total += value * 2
             return total
 
-        self.grouper.agg(function, engine="numba")
+        self.grouper.agg(
+            function, engine="numba", engine_kwargs={"parallel": self.parallel}
+        )
 
-    def time_dataframe_cython(self):
+    def time_dataframe_cython(self, parallel):
         def function(values):
             total = 0
             for i, value in enumerate(values):
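
For reference, a minimal sketch (outside the asv harness, assuming numba is installed) of the call pattern these parametrized benchmarks now time — the same UDF dispatched through the numba engine, with parallel compilation toggled via engine_kwargs:

import pandas as pd

data = pd.DataFrame(
    {0: [str(i) for i in range(100)] * 1000, 1: list(range(100)) * 1000},
    columns=[0, 1],
)
grouper = data.groupby(0)

def function(values, index):
    # Numba-jittable UDF: receives the group's values and index as NumPy arrays
    return values * 5

# The first call pays the JIT compilation cost; repeat calls reuse NUMBA_FUNC_CACHE
result = grouper[1].transform(
    function, engine="numba", engine_kwargs={"parallel": True}
)
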
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.2.0.rst
@@ -207,7 +207,7 @@ Performance improvements
 
 - Performance improvements when creating Series with dtype `str` or :class:`StringDtype` from array with many string elements (:issue:`36304`)
 - Performance improvement in :meth:`GroupBy.agg` with the ``numba`` engine (:issue:`35759`)
--
+- Performance improvement in :meth:`GroupBy.transform` with the ``numba`` engine (:issue:`36240`)
 
 .. ---------------------------------------------------------------------------
 
76 changes: 32 additions & 44 deletions pandas/core/groupby/generic.py
@@ -226,10 +226,6 @@ def apply(self, func, *args, **kwargs):
     def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
 
         if maybe_use_numba(engine):
-            if not callable(func):
-                raise NotImplementedError(
-                    "Numba engine can only be used with a single function."
-                )
             with group_selection_context(self):
                 data = self._selected_obj
                 result, index = self._aggregate_with_numba(
@@ -489,12 +485,21 @@ def _aggregate_named(self, func, *args, **kwargs):
     @Substitution(klass="Series")
     @Appender(_transform_template)
     def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
 
+        if maybe_use_numba(engine):
+            with group_selection_context(self):
+                data = self._selected_obj
+                result = self._transform_with_numba(
+                    data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
+                )
+            return self.obj._constructor(
+                result.ravel(), index=data.index, name=data.name
+            )
+
         func = self._get_cython_func(func) or func
 
         if not isinstance(func, str):
-            return self._transform_general(
-                func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
-            )
+            return self._transform_general(func, *args, **kwargs)
 
         elif func not in base.transform_kernel_allowlist:
             msg = f"'{func}' is not a valid function name for transform(name)"
@@ -938,10 +943,6 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
     def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
 
         if maybe_use_numba(engine):
-            if not callable(func):
-                raise NotImplementedError(
-                    "Numba engine can only be used with a single function."
-                )
             with group_selection_context(self):
                 data = self._selected_obj
                 result, index = self._aggregate_with_numba(
@@ -1290,42 +1291,25 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
 
         return self._reindex_output(result)
 
-    def _transform_general(
-        self, func, *args, engine="cython", engine_kwargs=None, **kwargs
-    ):
+    def _transform_general(self, func, *args, **kwargs):
         from pandas.core.reshape.concat import concat
 
         applied = []
         obj = self._obj_with_exclusions
         gen = self.grouper.get_iterator(obj, axis=self.axis)
-        if maybe_use_numba(engine):
-            numba_func, cache_key = generate_numba_func(
-                func, engine_kwargs, kwargs, "groupby_transform"
-            )
-        else:
-            fast_path, slow_path = self._define_paths(func, *args, **kwargs)
+        fast_path, slow_path = self._define_paths(func, *args, **kwargs)
 
         for name, group in gen:
             object.__setattr__(group, "name", name)
 
-            if maybe_use_numba(engine):
-                values, index = split_for_numba(group)
-                res = numba_func(values, index, *args)
-                if cache_key not in NUMBA_FUNC_CACHE:
-                    NUMBA_FUNC_CACHE[cache_key] = numba_func
-                # Return the result as a DataFrame for concatenation later
-                res = self.obj._constructor(
-                    res, index=group.index, columns=group.columns
-                )
-            else:
-                # Try slow path and fast path.
-                try:
-                    path, res = self._choose_path(fast_path, slow_path, group)
-                except TypeError:
-                    return self._transform_item_by_item(obj, fast_path)
-                except ValueError as err:
-                    msg = "transform must return a scalar value for each group"
-                    raise ValueError(msg) from err
+            # Try slow path and fast path.
+            try:
+                path, res = self._choose_path(fast_path, slow_path, group)
+            except TypeError:
+                return self._transform_item_by_item(obj, fast_path)
+            except ValueError as err:
+                msg = "transform must return a scalar value for each group"
+                raise ValueError(msg) from err
 
             if isinstance(res, Series):
 
@@ -1361,13 +1345,19 @@ def _transform_general(
     @Appender(_transform_template)
     def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
 
+        if maybe_use_numba(engine):
+            with group_selection_context(self):
+                data = self._selected_obj
+                result = self._transform_with_numba(
+                    data, func, *args, engine_kwargs=engine_kwargs, **kwargs
+                )
+            return self.obj._constructor(result, index=data.index, columns=data.columns)
+
         # optimized transforms
         func = self._get_cython_func(func) or func
 
         if not isinstance(func, str):
-            return self._transform_general(
-                func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
-            )
+            return self._transform_general(func, *args, **kwargs)
 
         elif func not in base.transform_kernel_allowlist:
             msg = f"'{func}' is not a valid function name for transform(name)"
@@ -1393,9 +1383,7 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
         ):
             return self._transform_fast(result)
 
-        return self._transform_general(
-            func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs
-        )
+        return self._transform_general(func, *args, **kwargs)
 
     def _transform_fast(self, result: DataFrame) -> DataFrame:
         """
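
To make the new dispatch concrete, here is a small usage sketch (hypothetical data, assuming numba is installed) of the two public entry points changed above. The numba path is now taken up front in transform(), before string aliases are resolved to cython kernels, while plain-Python UDFs fall through to _transform_general without any engine plumbing:

import pandas as pd

df = pd.DataFrame({"key": ["a", "b"] * 50, "x": range(100), "y": range(100)})
grouped = df.groupby("key")

def demean(values, index):
    # signature required by the numba engine: (values, index)
    return values - values.mean()

# DataFrameGroupBy.transform: one jitted call over all selected columns
out_frame = grouped[["x", "y"]].transform(demean, engine="numba")

# SeriesGroupBy.transform: the to_frame()/ravel() round-trip is internal;
# the caller still gets a Series back
out_series = grouped["x"].transform(demean, engine="numba")

# equivalent result through the default engine (UDF takes values only)
out_cython = grouped[["x", "y"]].transform(lambda v: v - v.mean())
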
40 changes: 39 additions & 1 deletion pandas/core/groupby/groupby.py
@@ -1056,6 +1056,41 @@ def _cython_agg_general(
 
         return self._wrap_aggregated_output(output, index=self.grouper.result_index)
 
+    def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
+        """
+        Perform groupby transform routine with the numba engine.
+
+        This routine mimics the data splitting routine of the DataSplitter class
+        to generate the indices of each group in the sorted data and then passes the
+        data and indices into a Numba jitted function.
+        """
+        if not callable(func):
+            raise NotImplementedError(
+                "Numba engine can only be used with a single function."
+            )
+        group_keys = self.grouper._get_group_keys()
+        labels, _, n_groups = self.grouper.group_info
+        sorted_index = get_group_index_sorter(labels, n_groups)
+        sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False)
+        sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
+        starts, ends = lib.generate_slices(sorted_labels, n_groups)
+        cache_key = (func, "groupby_transform")
+        if cache_key in NUMBA_FUNC_CACHE:
+            numba_transform_func = NUMBA_FUNC_CACHE[cache_key]
+        else:
+            numba_transform_func = numba_.generate_numba_transform_func(
+                tuple(args), kwargs, func, engine_kwargs
+            )
+        result = numba_transform_func(
+            sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns)
+        )
+        if cache_key not in NUMBA_FUNC_CACHE:
+            NUMBA_FUNC_CACHE[cache_key] = numba_transform_func
+
+        # result values need to be re-sorted to their original positions since we
+        # evaluated the data sorted by group
+        return result.take(np.argsort(sorted_index), axis=0)
+
     def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
         """
         Perform groupby aggregation routine with the numba engine.
@@ -1064,6 +1099,10 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
         to generate the indices of each group in the sorted data and then passes the
         data and indices into a Numba jitted function.
         """
+        if not callable(func):
+            raise NotImplementedError(
+                "Numba engine can only be used with a single function."
+            )
         group_keys = self.grouper._get_group_keys()
         labels, _, n_groups = self.grouper.group_info
         sorted_index = get_group_index_sorter(labels, n_groups)
@@ -1072,7 +1111,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
         starts, ends = lib.generate_slices(sorted_labels, n_groups)
         cache_key = (func, "groupby_agg")
         if cache_key in NUMBA_FUNC_CACHE:
-            # Return an already compiled version of roll_apply if available
            numba_agg_func = NUMBA_FUNC_CACHE[cache_key]
         else:
             numba_agg_func = numba_.generate_numba_agg_func(
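
The splitting scheme both docstrings describe can be illustrated standalone (a hypothetical NumPy sketch; np.argsort and np.diff stand in for pandas' get_group_index_sorter and lib.generate_slices): sort the rows by group label once so every group becomes a contiguous [start, end) slice, evaluate per slice, then invert the sort.

import numpy as np

labels = np.array([1, 0, 1, 0, 2])  # group label of each row
n_groups = 3

sorted_index = np.argsort(labels, kind="stable")  # rows reordered by group
sorted_labels = labels[sorted_index]              # [0, 0, 1, 1, 2]

# slice boundaries wherever the (sorted) label changes
bounds = np.flatnonzero(np.diff(sorted_labels)) + 1
starts = np.concatenate(([0], bounds))            # [0, 2, 4]
ends = np.concatenate((bounds, [len(labels)]))    # [2, 4, 5]

data = np.array([10.0, 20.0, 30.0, 40.0, 50.0])
sorted_data = data[sorted_index]
result = np.empty_like(sorted_data)
for i in range(n_groups):
    group = sorted_data[starts[i] : ends[i]]
    result[starts[i] : ends[i]] = group - group.mean()  # a transform per slice

# re-sort outputs back to the original row order, as in _transform_with_numba
result = result[np.argsort(sorted_index)]
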
69 changes: 67 additions & 2 deletions pandas/core/groupby/numba_.py
@@ -153,7 +153,7 @@ def generate_numba_agg_func(
         loop_range = range
 
     @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
-    def group_apply(
+    def group_agg(
         values: np.ndarray,
         index: np.ndarray,
         begin: np.ndarray,
@@ -169,4 +169,69 @@ def group_apply(
                 result[i, j] = numba_func(group, group_index, *args)
         return result
 
-    return group_apply
+    return group_agg
+
+
+def generate_numba_transform_func(
+    args: Tuple,
+    kwargs: Dict[str, Any],
+    func: Callable[..., Scalar],
+    engine_kwargs: Optional[Dict[str, bool]],
+) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]:
+    """
+    Generate a numba jitted transform function specified by values from engine_kwargs.
+
+    1. jit the user's function
+    2. Return a groupby transform function with the jitted function inline
+
+    Configurations specified in engine_kwargs apply to both the user's
+    function _AND_ the groupby transform function.
+
+    Parameters
+    ----------
+    args : tuple
+        *args to be passed into the function
+    kwargs : dict
+        **kwargs to be passed into the function
+    func : function
+        function to be applied to each group and will be JITed
+    engine_kwargs : dict
+        dictionary of arguments to be passed into numba.jit
+
+    Returns
+    -------
+    Numba function
+    """
+    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
+
+    check_kwargs_and_nopython(kwargs, nopython)
+
+    validate_udf(func)
+
+    numba_func = jit_user_function(func, nopython, nogil, parallel)
+
+    numba = import_optional_dependency("numba")
+
+    if parallel:
+        loop_range = numba.prange
+    else:
+        loop_range = range
+
+    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
+    def group_transform(
+        values: np.ndarray,
+        index: np.ndarray,
+        begin: np.ndarray,
+        end: np.ndarray,
+        num_groups: int,
+        num_columns: int,
+    ) -> np.ndarray:
+        result = np.empty((len(values), num_columns))
+        for i in loop_range(num_groups):
+            group_index = index[begin[i] : end[i]]
+            for j in loop_range(num_columns):
+                group = values[begin[i] : end[i], j]
+                result[begin[i] : end[i], j] = numba_func(group, group_index, *args)
+        return result
+
+    return group_transform
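
As a final illustration, a minimal runnable sketch (hypothetical demo function, assuming numba is installed) of the pattern group_transform relies on: under parallel=True, numba.prange distributes the outer group loop across threads, which is safe here because each iteration writes a disjoint [begin[i], end[i]) slice. This is also why transform allocates a result with len(values) rows, while group_agg allocates one row per group.

import numba
import numpy as np

@numba.jit(nopython=True, nogil=True, parallel=True)
def demo_group_transform(values, starts, ends, num_groups):
    result = np.empty_like(values)
    for i in numba.prange(num_groups):
        # disjoint slices per iteration -> no write races across threads
        group = values[starts[i] : ends[i]]
        result[starts[i] : ends[i]] = group - group.mean()
    return result

values = np.arange(10.0)
starts = np.array([0, 5])
ends = np.array([5, 10])
print(demo_group_transform(values, starts, ends, 2))  # demeaned within each half
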