Skip to content

TYP: Groupby sum|prod|min|max|first|last methods #32302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 116 additions & 92 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class providing the base-class of operations.
from pandas._libs import Timestamp
import pandas._libs.groupby as libgroupby
from pandas._typing import FrameOrSeries, Scalar
from pandas.compat import set_function_name
from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError
from pandas.util._decorators import Appender, Substitution, cache_readonly, doc
Expand Down Expand Up @@ -192,6 +191,24 @@ class providing the base-class of operations.
""",
)

_groupby_agg_method_template = """
Compute {fname} of group values.

Parameters
----------
numeric_only : bool, default {no}
Include only float, int, boolean columns. If None, will attempt to use
everything, then use only numeric data.
min_count : int, default {mc}
The required number of valid values to perform the operation. If fewer
than ``min_count`` non-NA values are present the result will be NA.

Returns
-------
Series or DataFrame
Computed {fname} of values within each group.
"""

_pipe_template = """
Apply a function `func` with arguments to this %(klass)s object and return
the function's result.
Expand Down Expand Up @@ -945,6 +962,37 @@ def _wrap_transformed_output(self, output: Mapping[base.OutputKey, np.ndarray]):
def _wrap_applied_output(self, keys, values, not_indexed_same: bool = False):
raise AbstractMethodError(self)

def _agg_general(
self,
numeric_only: bool = True,
min_count: int = -1,
*,
alias: str,
npfunc: Callable,
):
self._set_group_selection()

# try a cython aggregation if we can
try:
return self._cython_agg_general(
how=alias, alt=npfunc, numeric_only=numeric_only, min_count=min_count,
)
except DataError:
pass
except NotImplementedError as err:
if "function is not implemented for this dtype" in str(
err
) or "category dtype not supported" in str(err):
# raised in _get_cython_function, in some cases can
# be trimmed by implementing cython funcs for more dtypes
pass
else:
raise

# apply a non-cython aggregation
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
return result

def _cython_agg_general(
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
):
Expand Down Expand Up @@ -1117,6 +1165,27 @@ def _apply_filter(self, indices, dropna):
OutputFrameOrSeries = TypeVar("OutputFrameOrSeries", bound=NDFrame)


def get_loc_notna(obj: "Series", *, loc: int):
"""Find the value in position ``loc`` after filtering ``obj`` for nan values.

if ``obj`` is empty or has only nan values, np.nan er returned.

Examples
--------
>>> ser = pd.Series([np.nan, np.nan, 1, 2, np.nan])
>>> get_loc_notna(ser, loc=0) # get first non-na
1.0
>>> get_loc_notna(ser, loc=-1) # get last non-na
2.0
"""
x = obj.to_numpy()
x = x[notna(x)]

if len(x) == 0:
return np.nan
return x[loc]


class GroupBy(_GroupBy[FrameOrSeries]):
"""
Class for grouping and aggregating relational data.
Expand Down Expand Up @@ -1438,105 +1507,63 @@ def size(self):
result = self._obj_1d_constructor(result)
return self._reindex_output(result, fill_value=0)

@classmethod
def _add_numeric_operations(cls):
"""
Add numeric operations to the GroupBy generically.
"""

def groupby_function(
name: str,
alias: str,
npfunc,
numeric_only: bool = True,
min_count: int = -1,
):
@doc(_groupby_agg_method_template, fname="sum", no=True, mc=0)
def sum(self, numeric_only: bool = True, min_count: int = 0):
return self._agg_general(
numeric_only=numeric_only, min_count=min_count, alias="add", npfunc=np.sum
)

_local_template = """
Compute %(f)s of group values.

Parameters
----------
numeric_only : bool, default %(no)s
Include only float, int, boolean columns. If None, will attempt to use
everything, then use only numeric data.
min_count : int, default %(mc)s
The required number of valid values to perform the operation. If fewer
than ``min_count`` non-NA values are present the result will be NA.

Returns
-------
Series or DataFrame
Computed %(f)s of values within each group.
"""
@doc(_groupby_agg_method_template, fname="prod", no=True, mc=0)
def prod(self, numeric_only: bool = True, min_count: int = 0):
return self._agg_general(
numeric_only=numeric_only, min_count=min_count, alias="prod", npfunc=np.prod
)

@Substitution(name="groupby", f=name, no=numeric_only, mc=min_count)
@Appender(_common_see_also)
@Appender(_local_template)
def func(self, numeric_only=numeric_only, min_count=min_count):
self._set_group_selection()
@doc(_groupby_agg_method_template, fname="min", no=False, mc=-1)
def min(self, numeric_only: bool = False, min_count: int = -1):
return self._agg_general(
numeric_only=numeric_only, min_count=min_count, alias="min", npfunc=np.min
)

# try a cython aggregation if we can
try:
return self._cython_agg_general(
how=alias,
alt=npfunc,
numeric_only=numeric_only,
min_count=min_count,
)
except DataError:
pass
except NotImplementedError as err:
if "function is not implemented for this dtype" in str(
err
) or "category dtype not supported" in str(err):
# raised in _get_cython_function, in some cases can
# be trimmed by implementing cython funcs for more dtypes
pass
else:
raise

# apply a non-cython aggregation
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
return result

set_function_name(func, name, cls)

return func

def first_compat(x, axis=0):
def first(x):
x = x.to_numpy()

x = x[notna(x)]
if len(x) == 0:
return np.nan
return x[0]
@doc(_groupby_agg_method_template, fname="max", no=False, mc=-1)
def max(self, numeric_only: bool = False, min_count: int = -1):
return self._agg_general(
numeric_only=numeric_only, min_count=min_count, alias="max", npfunc=np.max
)

@doc(_groupby_agg_method_template, fname="first", no=False, mc=-1)
def first(self, numeric_only: bool = False, min_count: int = -1):
def first_compat(x, axis: int = 0):
"""Helper function for first item that isn't NA.
"""
if isinstance(x, DataFrame):
return x.apply(first, axis=axis)
return x.apply(get_loc_notna, axis=axis, loc=0)
else:
return first(x)
return get_loc_notna(x, loc=0)

def last_compat(x, axis=0):
def last(x):
x = x.to_numpy()
x = x[notna(x)]
if len(x) == 0:
return np.nan
return x[-1]
return self._agg_general(
numeric_only=numeric_only,
min_count=min_count,
alias="first",
npfunc=first_compat,
)

@doc(_groupby_agg_method_template, fname="last", no=False, mc=-1)
def last(self, numeric_only: bool = False, min_count: int = -1):
def last_compat(x, axis: int = 0):
"""Helper function for last item that isn't NA.
"""
if isinstance(x, DataFrame):
return x.apply(last, axis=axis)
return x.apply(get_loc_notna, axis=axis, loc=-1)
else:
return last(x)
return get_loc_notna(x, loc=-1)

cls.sum = groupby_function("sum", "add", np.sum, min_count=0)
cls.prod = groupby_function("prod", "prod", np.prod, min_count=0)
cls.min = groupby_function("min", "min", np.min, numeric_only=False)
cls.max = groupby_function("max", "max", np.max, numeric_only=False)
cls.first = groupby_function("first", "first", first_compat, numeric_only=False)
cls.last = groupby_function("last", "last", last_compat, numeric_only=False)
return self._agg_general(
numeric_only=numeric_only,
min_count=min_count,
alias="last",
npfunc=last_compat,
)

@Substitution(name="groupby")
@Appender(_common_see_also)
Expand Down Expand Up @@ -2636,9 +2663,6 @@ def _reindex_output(
return output.reset_index(drop=True)


GroupBy._add_numeric_operations()


@doc(GroupBy)
def get_groupby(
obj: NDFrame,
Expand Down
6 changes: 3 additions & 3 deletions pandas/util/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def wrapper(*args, **kwargs) -> Callable[..., Any]:
return decorate


def doc(*args: Union[str, Callable], **kwargs: str) -> Callable[[F], F]:
def doc(*args: Union[str, Callable], **kwargs: Any) -> Callable[[F], F]:
"""
A decorator take docstring templates, concatenate them and perform string
substitution on it.
Expand All @@ -345,8 +345,8 @@ def doc(*args: Union[str, Callable], **kwargs: str) -> Callable[[F], F]:
*args : str or callable
The string / docstring / docstring template to be appended in order
after default docstring under function.
**kwargs : str
The string which would be used to format docstring template.
**kwargs : Any
The objects which would be used to format docstring template.
"""

def decorator(func: F) -> F:
Expand Down