
PERF: Allow groupby transform with numba engine to be fully parallelizable #36240


Merged · 13 commits · Sep 13, 2020
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.2.0.rst
@@ -206,7 +206,7 @@ Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~

- Performance improvement in :meth:`GroupBy.agg` with the ``numba`` engine (:issue:`35759`)
-
- Performance improvement in :meth:`GroupBy.transform` with the ``numba`` engine (:issue:`36240`)

.. ---------------------------------------------------------------------------

76 changes: 40 additions & 36 deletions pandas/core/groupby/generic.py
@@ -489,12 +489,25 @@ def _aggregate_named(self, func, *args, **kwargs):
@Substitution(klass="Series")
@Appender(_transform_template)
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):

if maybe_use_numba(engine):
if not callable(func):
raise NotImplementedError(
"Numba engine can only be used with a single function."
)
with group_selection_context(self):
data = self._selected_obj
result = self._transform_with_numba(
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
)
return self.obj._constructor(
result.ravel(), index=data.index, name=data.name
)

func = self._get_cython_func(func) or func

if not isinstance(func, str):
return self._transform_general(
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)
return self._transform_general(func, *args, **kwargs)

elif func not in base.transform_kernel_allowlist:
msg = f"'{func}' is not a valid function name for transform(name)"
@@ -1291,42 +1304,25 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):

return self._reindex_output(result)

def _transform_general(
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
):
def _transform_general(self, func, *args, **kwargs):
from pandas.core.reshape.concat import concat

applied = []
obj = self._obj_with_exclusions
gen = self.grouper.get_iterator(obj, axis=self.axis)
if maybe_use_numba(engine):
numba_func, cache_key = generate_numba_func(
func, engine_kwargs, kwargs, "groupby_transform"
)
else:
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
fast_path, slow_path = self._define_paths(func, *args, **kwargs)

for name, group in gen:
object.__setattr__(group, "name", name)

if maybe_use_numba(engine):
values, index = split_for_numba(group)
res = numba_func(values, index, *args)
if cache_key not in NUMBA_FUNC_CACHE:
NUMBA_FUNC_CACHE[cache_key] = numba_func
# Return the result as a DataFrame for concatenation later
res = self.obj._constructor(
res, index=group.index, columns=group.columns
)
else:
# Try slow path and fast path.
try:
path, res = self._choose_path(fast_path, slow_path, group)
except TypeError:
return self._transform_item_by_item(obj, fast_path)
except ValueError as err:
msg = "transform must return a scalar value for each group"
raise ValueError(msg) from err
# Try slow path and fast path.
try:
path, res = self._choose_path(fast_path, slow_path, group)
except TypeError:
return self._transform_item_by_item(obj, fast_path)
except ValueError as err:
msg = "transform must return a scalar value for each group"
raise ValueError(msg) from err

if isinstance(res, Series):

@@ -1362,13 +1358,23 @@ def _transform_general(
@Appender(_transform_template)
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):

if maybe_use_numba(engine):
if not callable(func):
[Review comment, Contributor]: maybe this check should actually be in _transform_with_numba to keep DRY (you have it above as well)

raise NotImplementedError(
"Numba engine can only be used with a single function."
)
with group_selection_context(self):
data = self._selected_obj
result = self._transform_with_numba(
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
)
return self.obj._constructor(result, index=data.index, columns=data.columns)

# optimized transforms
func = self._get_cython_func(func) or func

if not isinstance(func, str):
return self._transform_general(
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)
return self._transform_general(func, *args, **kwargs)

elif func not in base.transform_kernel_allowlist:
msg = f"'{func}' is not a valid function name for transform(name)"
@@ -1394,9 +1400,7 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
):
return self._transform_fast(result)

return self._transform_general(
func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs
)
return self._transform_general(func, *args, **kwargs)

def _transform_fast(self, result: DataFrame) -> DataFrame:
"""
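With the per-group numba calls removed from `_transform_general`, the whole transform kernel is jitted at once and can be compiled with `parallel=True`. A hedged usage sketch (the `engine_kwargs` keys are the standard `nopython`/`nogil`/`parallel` options consumed by `get_jit_arguments`; the frame and UDF are made up for illustration):

```python
import pandas as pd

def demean(values, index):
    # called once per (group, column); returns an array the same length as the group
    return values - values.mean()

df = pd.DataFrame(
    {"key": ["a", "a", "b", "b"], "x": [1.0, 2.0, 3.0, 4.0], "y": [5.0, 6.0, 7.0, 8.0]}
)
result = df.groupby("key").transform(
    demean,
    engine="numba",
    engine_kwargs={"nopython": True, "nogil": False, "parallel": True},
)
```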
32 changes: 31 additions & 1 deletion pandas/core/groupby/groupby.py
@@ -1056,6 +1056,37 @@ def _cython_agg_general(

return self._wrap_aggregated_output(output, index=self.grouper.result_index)

def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
"""
Perform groupby transform routine with the numba engine.

This routine mimics the data splitting routine of the DataSplitter class
to generate the indices of each group in the sorted data and then passes the
data and indices into a Numba jitted function.
"""
group_keys = self.grouper._get_group_keys()
labels, _, n_groups = self.grouper.group_info
sorted_index = get_group_index_sorter(labels, n_groups)
sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False)
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
starts, ends = lib.generate_slices(sorted_labels, n_groups)
cache_key = (func, "groupby_transform")
if cache_key in NUMBA_FUNC_CACHE:
numba_transform_func = NUMBA_FUNC_CACHE[cache_key]
else:
numba_transform_func = numba_.generate_numba_transform_func(
tuple(args), kwargs, func, engine_kwargs
)
result = numba_transform_func(
sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns)
)
if cache_key not in NUMBA_FUNC_CACHE:
NUMBA_FUNC_CACHE[cache_key] = numba_transform_func

# result values need to be resorted to their original positions since we
# evaluated the data sorted by group
return result.take(np.argsort(sorted_index), axis=0)
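
A rough pure-numpy analogue of the splitting trick used here — standalone code for illustration, not the pandas helpers (`get_group_index_sorter` / `lib.generate_slices`): sort rows by group label once, compute contiguous `[start, end)` slices per group, and undo the sort afterwards:

```python
import numpy as np

labels = np.array([1, 0, 1, 0, 1])                 # group label per row
values = np.array([10.0, 20.0, 30.0, 40.0, 50.0])
n_groups = 2

# stable sort so each group's rows become one contiguous block
sorted_index = np.argsort(labels, kind="stable")
sorted_labels = labels[sorted_index]
sorted_values = values[sorted_index]

# [start, end) slice of each group within the sorted data
starts = np.searchsorted(sorted_labels, np.arange(n_groups), side="left")
ends = np.searchsorted(sorted_labels, np.arange(n_groups), side="right")

out = np.empty_like(sorted_values)
for i in range(n_groups):
    group = sorted_values[starts[i] : ends[i]]
    out[starts[i] : ends[i]] = group - group.mean()  # per-row result

# restore the original row order, mirroring result.take(np.argsort(sorted_index))
result = out[np.argsort(sorted_index)]
```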

def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
"""
Perform groupby aggregation routine with the numba engine.
@@ -1072,7 +1103,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
starts, ends = lib.generate_slices(sorted_labels, n_groups)
cache_key = (func, "groupby_agg")
if cache_key in NUMBA_FUNC_CACHE:
# Return an already compiled version of roll_apply if available
numba_agg_func = NUMBA_FUNC_CACHE[cache_key]
else:
numba_agg_func = numba_.generate_numba_agg_func(
69 changes: 67 additions & 2 deletions pandas/core/groupby/numba_.py
@@ -153,7 +153,7 @@ def generate_numba_agg_func(
loop_range = range

@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
def group_apply(
def group_agg(
values: np.ndarray,
index: np.ndarray,
begin: np.ndarray,
@@ -169,4 +169,69 @@ def group_apply(
result[i, j] = numba_func(group, group_index, *args)
return result

return group_apply
return group_agg


def generate_numba_transform_func(
args: Tuple,
kwargs: Dict[str, Any],
func: Callable[..., Scalar],
engine_kwargs: Optional[Dict[str, bool]],
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]:
"""
Generate a numba jitted transform function specified by values from engine_kwargs.

1. jit the user's function
2. Return a groupby transform function with the jitted function inline

Configurations specified in engine_kwargs apply to both the user's
function _AND_ the groupby transform function.

Parameters
----------
args : tuple
*args to be passed into the function
kwargs : dict
**kwargs to be passed into the function
func : function
function to be applied to each group and will be JITed
engine_kwargs : dict
dictionary of arguments to be passed into numba.jit

Returns
-------
Numba function
"""
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

check_kwargs_and_nopython(kwargs, nopython)

validate_udf(func)

numba_func = jit_user_function(func, nopython, nogil, parallel)

numba = import_optional_dependency("numba")

if parallel:
loop_range = numba.prange
else:
loop_range = range

@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
def group_transform(
values: np.ndarray,
index: np.ndarray,
begin: np.ndarray,
end: np.ndarray,
num_groups: int,
num_columns: int,
) -> np.ndarray:
result = np.empty((len(values), num_columns))
for i in loop_range(num_groups):
group_index = index[begin[i] : end[i]]
for j in loop_range(num_columns):
group = values[begin[i] : end[i], j]
result[begin[i] : end[i], j] = numba_func(group, group_index, *args)
return result

return group_transform
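
To make the contrast with `group_agg` concrete, here is a plain-Python analogue of the two loop bodies (no numba, made-up helper names): the agg kernel writes one value per (group, column), while the transform kernel fills every row of each group's slice, which is why its result has `len(values)` rows:

```python
import numpy as np

def agg_like(values, index, begin, end, num_groups, num_columns, udf):
    # one output row per group; udf returns a scalar
    result = np.empty((num_groups, num_columns))
    for i in range(num_groups):
        group_index = index[begin[i] : end[i]]
        for j in range(num_columns):
            group = values[begin[i] : end[i], j]
            result[i, j] = udf(group, group_index)
    return result

def transform_like(values, index, begin, end, num_groups, num_columns, udf):
    # one output row per input row; udf returns a scalar or an array of group length
    result = np.empty((len(values), num_columns))
    for i in range(num_groups):
        group_index = index[begin[i] : end[i]]
        for j in range(num_columns):
            group = values[begin[i] : end[i], j]
            result[begin[i] : end[i], j] = udf(group, group_index)
    return result
```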
16 changes: 16 additions & 0 deletions pandas/tests/groupby/transform/test_numba.py
@@ -127,3 +127,19 @@ def func_1(values, index):
with option_context("compute.use_numba", True):
result = grouped.transform(func_1, engine=None)
tm.assert_frame_equal(expected, result)


@td.skip_if_no("numba", "0.46.0")
@pytest.mark.parametrize(
"agg_func", [["min", "max"], "min", {"B": ["min", "max"], "C": "sum"}],
)
def test_multifunc_notimplimented(agg_func):
data = DataFrame(
{0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
)
grouped = data.groupby(0)
with pytest.raises(NotImplementedError, match="Numba engine can"):
grouped.transform(agg_func, engine="numba")

with pytest.raises(NotImplementedError, match="Numba engine can"):
grouped[1].transform(agg_func, engine="numba")
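
Beyond the NotImplementedError checks above, a hedged sketch of a caching test in the same style (it relies on the file's existing imports and assumes the shared cache is importable as `pandas.core.util.numba_.NUMBA_FUNC_CACHE`; the test name and UDF are hypothetical):

```python
@td.skip_if_no("numba", "0.46.0")
def test_transform_kernel_is_cached():
    # assumption: NUMBA_FUNC_CACHE is the module-level cache keyed by
    # (func, "groupby_transform"), as used in _transform_with_numba
    from pandas.core.util.numba_ import NUMBA_FUNC_CACHE

    def func_2(values, index):
        return values * 2

    data = DataFrame(
        {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
    )
    data.groupby(0).transform(func_2, engine="numba")
    assert (func_2, "groupby_transform") in NUMBA_FUNC_CACHE
```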