-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: Add numba engine to groupby.var/std #44862
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1272,6 +1272,7 @@ def _numba_agg_general( | |
func: Callable, | ||
engine_kwargs: dict[str, bool] | None, | ||
numba_cache_key_str: str, | ||
*aggregator_args, | ||
): | ||
""" | ||
Perform groupby with a standard numerical aggregation function (e.g. mean) | ||
|
@@ -1291,7 +1292,7 @@ def _numba_agg_general( | |
aggregator = executor.generate_shared_aggregator( | ||
func, engine_kwargs, numba_cache_key_str | ||
) | ||
result = aggregator(sorted_data, starts, ends, 0) | ||
result = aggregator(sorted_data, starts, ends, 0, *aggregator_args) | ||
|
||
cache_key = (func, numba_cache_key_str) | ||
if cache_key not in NUMBA_FUNC_CACHE: | ||
|
@@ -1989,7 +1990,12 @@ def median(self, numeric_only: bool | lib.NoDefault = lib.no_default): | |
@final | ||
@Substitution(name="groupby") | ||
@Appender(_common_see_also) | ||
def std(self, ddof: int = 1): | ||
def std( | ||
self, | ||
ddof: int = 1, | ||
engine: str | None = None, | ||
engine_kwargs: dict[str, bool] | None = None, | ||
): | ||
""" | ||
Compute standard deviation of groups, excluding missing values. | ||
|
||
|
@@ -2000,23 +2006,52 @@ def std(self, ddof: int = 1): | |
ddof : int, default 1 | ||
Degrees of freedom. | ||
|
||
engine : str, default None | ||
* ``'cython'`` : Runs the operation through C-extensions from cython. | ||
* ``'numba'`` : Runs the operation through JIT compiled code from numba. | ||
* ``None`` : Defaults to ``'cython'`` or globally setting | ||
``compute.use_numba`` | ||
|
||
.. versionadded:: 1.4.0 | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In a follow-up, can you try to use shared docstrings here (and in the other functions) for the engine/engine_kwargs parameters? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure, I can try to upstream these doc parameters from rolling too. |
||
engine_kwargs : dict, default None | ||
* For ``'cython'`` engine, there are no accepted ``engine_kwargs`` | ||
* For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil`` | ||
and ``parallel`` dictionary keys. The values must either be ``True`` or | ||
``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is | ||
``{{'nopython': True, 'nogil': False, 'parallel': False}}`` | ||
|
||
.. versionadded:: 1.4.0 | ||
|
||
Returns | ||
------- | ||
Series or DataFrame | ||
Standard deviation of values within each group. | ||
""" | ||
return self._get_cythonized_result( | ||
libgroupby.group_var, | ||
needs_counts=True, | ||
cython_dtype=np.dtype(np.float64), | ||
post_processing=lambda vals, inference: np.sqrt(vals), | ||
ddof=ddof, | ||
) | ||
if maybe_use_numba(engine): | ||
from pandas.core._numba.kernels import sliding_var | ||
|
||
return np.sqrt( | ||
self._numba_agg_general(sliding_var, engine_kwargs, "groupby_std", ddof) | ||
) | ||
else: | ||
return self._get_cythonized_result( | ||
libgroupby.group_var, | ||
needs_counts=True, | ||
cython_dtype=np.dtype(np.float64), | ||
post_processing=lambda vals, inference: np.sqrt(vals), | ||
ddof=ddof, | ||
) | ||
|
||
@final | ||
@Substitution(name="groupby") | ||
@Appender(_common_see_also) | ||
def var(self, ddof: int = 1): | ||
def var( | ||
self, | ||
ddof: int = 1, | ||
engine: str | None = None, | ||
engine_kwargs: dict[str, bool] | None = None, | ||
): | ||
""" | ||
Compute variance of groups, excluding missing values. | ||
|
||
|
@@ -2027,20 +2062,46 @@ def var(self, ddof: int = 1): | |
ddof : int, default 1 | ||
Degrees of freedom. | ||
|
||
engine : str, default None | ||
* ``'cython'`` : Runs the operation through C-extensions from cython. | ||
* ``'numba'`` : Runs the operation through JIT compiled code from numba. | ||
* ``None`` : Defaults to ``'cython'`` or globally setting | ||
``compute.use_numba`` | ||
|
||
.. versionadded:: 1.4.0 | ||
|
||
engine_kwargs : dict, default None | ||
* For ``'cython'`` engine, there are no accepted ``engine_kwargs`` | ||
* For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil`` | ||
and ``parallel`` dictionary keys. The values must either be ``True`` or | ||
``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is | ||
``{{'nopython': True, 'nogil': False, 'parallel': False}}`` | ||
|
||
.. versionadded:: 1.4.0 | ||
|
||
Returns | ||
------- | ||
Series or DataFrame | ||
Variance of values within each group. | ||
""" | ||
if ddof == 1: | ||
numeric_only = self._resolve_numeric_only(lib.no_default) | ||
return self._cython_agg_general( | ||
"var", alt=lambda x: Series(x).var(ddof=ddof), numeric_only=numeric_only | ||
if maybe_use_numba(engine): | ||
from pandas.core._numba.kernels import sliding_var | ||
|
||
return self._numba_agg_general( | ||
sliding_var, engine_kwargs, "groupby_var", ddof | ||
) | ||
else: | ||
func = lambda x: x.var(ddof=ddof) | ||
with self._group_selection_context(): | ||
return self._python_agg_general(func) | ||
if ddof == 1: | ||
numeric_only = self._resolve_numeric_only(lib.no_default) | ||
return self._cython_agg_general( | ||
"var", | ||
alt=lambda x: Series(x).var(ddof=ddof), | ||
numeric_only=numeric_only, | ||
) | ||
else: | ||
func = lambda x: x.var(ddof=ddof) | ||
with self._group_selection_context(): | ||
return self._python_agg_general(func) | ||
|
||
@final | ||
@Substitution(name="groupby") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are these not kwargs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Numba functions don't currently play nicely with kwargs; see numba/numba#2916.