Skip to content

Commit 92c4604

Browse files
committed
CLN/TYP: Groupby agg methods
1 parent 9f746a7 commit 92c4604

File tree

2 files changed

+161
-97
lines changed

2 files changed

+161
-97
lines changed

pandas/core/groupby/generic.py

+82
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,28 @@
8787
if TYPE_CHECKING:
8888
from pandas.core.internals import Block
8989

90+
# Shared docstring template for the groupby aggregation methods in this module.
# The %(f)s, %(no)s, %(mc)s and %(return_type)s placeholders are filled in by
# the @Substitution(...) decorator stacked on top of @Appender(_agg_template).
_agg_template = """
91+
Compute %(f)s of group values.
92+
93+
Parameters
94+
----------
95+
numeric_only : bool, default %(no)s
96+
Include only float, int, boolean columns. If None, will attempt to use
97+
everything, then use only numeric data.
98+
min_count : int, default %(mc)s
99+
The required number of valid values to perform the operation. If fewer
100+
than ``min_count`` non-NA values are present the result will be NA.
101+
102+
Returns
103+
-------
104+
%(return_type)s
105+
Computed %(f)s of values within each group.
106+
107+
See Also
108+
--------
109+
%(return_type)s.groupby
110+
"""
111+
90112

91113
NamedAgg = namedtuple("NamedAgg", ["column", "aggfunc"])
92114
# TODO(typing) the return value on this callable should be any *scalar*.
@@ -805,6 +827,36 @@ def count(self) -> Series:
805827
)
806828
return self._reindex_output(result, fill_value=0)
807829

830+
@Substitution(f="sum", no=True, mc=0, return_type="Series")
@Appender(_agg_template)
def sum(self, numeric_only=True, min_count=0) -> Series:
    # The docstring is supplied by the decorators from _agg_template; this
    # override only pins the Series return annotation and forwards the
    # arguments unchanged to the parent class.
    options = {"numeric_only": numeric_only, "min_count": min_count}
    return super().sum(**options)
834+
835+
@Substitution(f="prod", no=True, mc=0, return_type="Series")
@Appender(_agg_template)
def prod(self, numeric_only=True, min_count=0) -> Series:
    # Typed pass-through: the parent class does the work, and the docstring
    # comes from _agg_template via the decorators.
    options = {"numeric_only": numeric_only, "min_count": min_count}
    return super().prod(**options)
839+
840+
@Substitution(f="min", no=False, mc=-1, return_type="Series")
@Appender(_agg_template)
def min(self, numeric_only=False, min_count=-1) -> Series:
    # Typed pass-through: the parent class does the work, and the docstring
    # comes from _agg_template via the decorators.
    options = {"numeric_only": numeric_only, "min_count": min_count}
    return super().min(**options)
844+
845+
@Substitution(f="max", no=False, mc=-1, return_type="Series")
@Appender(_agg_template)
def max(self, numeric_only=False, min_count=-1) -> Series:
    # Typed pass-through: the parent class does the work, and the docstring
    # comes from _agg_template via the decorators.
    options = {"numeric_only": numeric_only, "min_count": min_count}
    return super().max(**options)
849+
850+
@Substitution(f="first", no=False, mc=-1, return_type="Series")
@Appender(_agg_template)
def first(self, numeric_only=False, min_count=-1) -> Series:
    # Typed pass-through: the parent class does the work, and the docstring
    # comes from _agg_template via the decorators.
    options = {"numeric_only": numeric_only, "min_count": min_count}
    return super().first(**options)
854+
855+
@Substitution(f="last", no=False, mc=-1, return_type="Series")
@Appender(_agg_template)
def last(self, numeric_only=False, min_count=-1) -> Series:
    # Typed pass-through: the parent class does the work, and the docstring
    # comes from _agg_template via the decorators.
    options = {"numeric_only": numeric_only, "min_count": min_count}
    return super().last(**options)
859+
808860
def _apply_to_column_groupbys(self, func):
809861
""" return a pass thru """
810862
return func(self)
@@ -1884,6 +1936,36 @@ def groupby_series(obj, col=None):
18841936
results.index = ibase.default_index(len(results))
18851937
return results
18861938

1939+
@Substitution(f="sum", no=True, mc=0, return_type="DataFrame")
@Appender(_agg_template)
def sum(self, numeric_only=True, min_count=0) -> DataFrame:
    # The docstring is supplied by the decorators from _agg_template; this
    # override only pins the DataFrame return annotation and forwards the
    # arguments unchanged to the parent class.
    options = {"numeric_only": numeric_only, "min_count": min_count}
    return super().sum(**options)
1943+
1944+
@Substitution(f="prod", no=True, mc=0, return_type="DataFrame")
@Appender(_agg_template)
def prod(self, numeric_only=True, min_count=0) -> DataFrame:
    # Typed pass-through: the parent class does the work, and the docstring
    # comes from _agg_template via the decorators.
    options = {"numeric_only": numeric_only, "min_count": min_count}
    return super().prod(**options)
1948+
1949+
@Substitution(f="min", no=False, mc=-1, return_type="DataFrame")
@Appender(_agg_template)
def min(self, numeric_only=False, min_count=-1) -> DataFrame:
    # Typed pass-through: the parent class does the work, and the docstring
    # comes from _agg_template via the decorators.
    options = {"numeric_only": numeric_only, "min_count": min_count}
    return super().min(**options)
1953+
1954+
@Substitution(f="max", no=False, mc=-1, return_type="DataFrame")
@Appender(_agg_template)
def max(self, numeric_only=False, min_count=-1) -> DataFrame:
    # Typed pass-through: the parent class does the work, and the docstring
    # comes from _agg_template via the decorators.
    options = {"numeric_only": numeric_only, "min_count": min_count}
    return super().max(**options)
1958+
1959+
@Substitution(f="first", no=False, mc=-1, return_type="DataFrame")
@Appender(_agg_template)
def first(self, numeric_only=False, min_count=-1) -> DataFrame:
    # Typed pass-through: the parent class does the work, and the docstring
    # comes from _agg_template via the decorators.
    options = {"numeric_only": numeric_only, "min_count": min_count}
    return super().first(**options)
1963+
1964+
@Substitution(f="last", no=False, mc=-1, return_type="DataFrame")
@Appender(_agg_template)
def last(self, numeric_only=False, min_count=-1) -> DataFrame:
    # Typed pass-through: the parent class does the work, and the docstring
    # comes from _agg_template via the decorators.
    options = {"numeric_only": numeric_only, "min_count": min_count}
    return super().last(**options)
1968+
18871969
boxplot = boxplot_frame_groupby
18881970

18891971

pandas/core/groupby/groupby.py

+79-97
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ class providing the base-class of operations.
3636
from pandas._libs import Timestamp
3737
import pandas._libs.groupby as libgroupby
3838
from pandas._typing import FrameOrSeries, Scalar
39-
from pandas.compat import set_function_name
4039
from pandas.compat.numpy import function as nv
4140
from pandas.errors import AbstractMethodError
4241
from pandas.util._decorators import Appender, Substitution, cache_readonly, doc
@@ -942,6 +941,32 @@ def _wrap_transformed_output(self, output: Mapping[base.OutputKey, np.ndarray]):
942941
def _wrap_applied_output(self, keys, values, not_indexed_same: bool = False):
943942
raise AbstractMethodError(self)
944943

944+
def _agg_general(
    self, numeric_only=True, min_count=-1, *, alias: str, npfunc: Callable
):
    """
    Run an aggregation, trying the cython path before a generic fallback.

    Parameters
    ----------
    numeric_only : bool, default True
        Forwarded to ``_cython_agg_general``.
    min_count : int, default -1
        Forwarded to ``_cython_agg_general``.
    alias : str
        Passed to ``_cython_agg_general`` as ``how``.
    npfunc : callable
        Passed to ``_cython_agg_general`` as ``alt`` and, if the cython
        attempt is abandoned, applied through ``self.aggregate`` with
        ``axis=self.axis``.

    Returns
    -------
    The result of ``_cython_agg_general`` when it succeeds, otherwise the
    result of ``self.aggregate``.

    Raises
    ------
    NotImplementedError
        Re-raised when its message is not one of the two recognised
        "unsupported dtype" messages.

    Notes
    -----
    The fallback call does not receive ``numeric_only`` or ``min_count``.
    """
    self._set_group_selection()

    # try a cython aggregation if we can
    try:
        fast_result = self._cython_agg_general(
            how=alias, alt=npfunc, numeric_only=numeric_only, min_count=min_count,
        )
    except DataError:
        # fall through to the non-cython aggregation below
        pass
    except NotImplementedError as err:
        message = str(err)
        # raised in _get_cython_function, in some cases can
        # be trimmed by implementing cython funcs for more dtypes
        unsupported_dtype = (
            "function is not implemented for this dtype" in message
            or "category dtype not supported" in message
        )
        if not unsupported_dtype:
            raise
    else:
        return fast_result

    # apply a non-cython aggregation
    return self.aggregate(lambda values: npfunc(values, axis=self.axis))
969+
945970
def _cython_agg_general(
946971
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
947972
):
@@ -1424,105 +1449,62 @@ def size(self):
14241449
result.name = self.obj.name
14251450
return self._reindex_output(result, fill_value=0)
14261451

1427-
@classmethod
1428-
def _add_numeric_operations(cls):
1429-
"""
1430-
Add numeric operations to the GroupBy generically.
1452+
def sum(self, numeric_only=True, min_count=0):
    # "add" is the name handed to the cython path; np.sum is the numpy
    # callable handed over as the alternative implementation.
    options = {
        "numeric_only": numeric_only,
        "min_count": min_count,
        "alias": "add",
        "npfunc": np.sum,
    }
    return self._agg_general(**options)
1456+
1457+
def prod(self, numeric_only=True, min_count=0):
    # "prod" is the name handed to the cython path; np.prod is the numpy
    # callable handed over as the alternative implementation.
    options = {
        "numeric_only": numeric_only,
        "min_count": min_count,
        "alias": "prod",
        "npfunc": np.prod,
    }
    return self._agg_general(**options)
1461+
1462+
def min(self, numeric_only=False, min_count=-1):
    # "min" is the name handed to the cython path; np.min is the numpy
    # callable handed over as the alternative implementation.
    options = {
        "numeric_only": numeric_only,
        "min_count": min_count,
        "alias": "min",
        "npfunc": np.min,
    }
    return self._agg_general(**options)
1466+
1467+
def max(self, numeric_only=False, min_count=-1):
    # "max" is the name handed to the cython path; np.max is the numpy
    # callable handed over as the alternative implementation.
    options = {
        "numeric_only": numeric_only,
        "min_count": min_count,
        "alias": "max",
        "npfunc": np.max,
    }
    return self._agg_general(**options)
1471+
1472+
@staticmethod
1473+
def _get_loc(x, axis: int = 0, *, loc: int):
1474+
"""Helper function for first/last item that isn't NA.
14311475
"""
14321476

1433-
def groupby_function(
1434-
name: str,
1435-
alias: str,
1436-
npfunc,
1437-
numeric_only: bool = True,
1438-
min_count: int = -1,
1439-
):
1477+
def get_loc_notna(x, loc: int):
1478+
x = x.to_numpy()
1479+
x = x[notna(x)]
1480+
if len(x) == 0:
1481+
return np.nan
1482+
return x[loc]
14401483

1441-
_local_template = """
1442-
Compute %(f)s of group values.
1443-
1444-
Parameters
1445-
----------
1446-
numeric_only : bool, default %(no)s
1447-
Include only float, int, boolean columns. If None, will attempt to use
1448-
everything, then use only numeric data.
1449-
min_count : int, default %(mc)s
1450-
The required number of valid values to perform the operation. If fewer
1451-
than ``min_count`` non-NA values are present the result will be NA.
1452-
1453-
Returns
1454-
-------
1455-
Series or DataFrame
1456-
Computed %(f)s of values within each group.
1457-
"""
1458-
1459-
@Substitution(name="groupby", f=name, no=numeric_only, mc=min_count)
1460-
@Appender(_common_see_also)
1461-
@Appender(_local_template)
1462-
def func(self, numeric_only=numeric_only, min_count=min_count):
1463-
self._set_group_selection()
1464-
1465-
# try a cython aggregation if we can
1466-
try:
1467-
return self._cython_agg_general(
1468-
how=alias,
1469-
alt=npfunc,
1470-
numeric_only=numeric_only,
1471-
min_count=min_count,
1472-
)
1473-
except DataError:
1474-
pass
1475-
except NotImplementedError as err:
1476-
if "function is not implemented for this dtype" in str(
1477-
err
1478-
) or "category dtype not supported" in str(err):
1479-
# raised in _get_cython_function, in some cases can
1480-
# be trimmed by implementing cython funcs for more dtypes
1481-
pass
1482-
else:
1483-
raise
1484-
1485-
# apply a non-cython aggregation
1486-
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
1487-
return result
1488-
1489-
set_function_name(func, name, cls)
1490-
1491-
return func
1492-
1493-
def first_compat(x, axis=0):
1494-
def first(x):
1495-
x = x.to_numpy()
1496-
1497-
x = x[notna(x)]
1498-
if len(x) == 0:
1499-
return np.nan
1500-
return x[0]
1501-
1502-
if isinstance(x, DataFrame):
1503-
return x.apply(first, axis=axis)
1504-
else:
1505-
return first(x)
1506-
1507-
def last_compat(x, axis=0):
1508-
def last(x):
1509-
x = x.to_numpy()
1510-
x = x[notna(x)]
1511-
if len(x) == 0:
1512-
return np.nan
1513-
return x[-1]
1514-
1515-
if isinstance(x, DataFrame):
1516-
return x.apply(last, axis=axis)
1517-
else:
1518-
return last(x)
1519-
1520-
cls.sum = groupby_function("sum", "add", np.sum, min_count=0)
1521-
cls.prod = groupby_function("prod", "prod", np.prod, min_count=0)
1522-
cls.min = groupby_function("min", "min", np.min, numeric_only=False)
1523-
cls.max = groupby_function("max", "max", np.max, numeric_only=False)
1524-
cls.first = groupby_function("first", "first", first_compat, numeric_only=False)
1525-
cls.last = groupby_function("last", "last", last_compat, numeric_only=False)
1484+
if isinstance(x, DataFrame):
1485+
return x.apply(get_loc_notna, axis=axis, loc=loc)
1486+
else:
1487+
return get_loc_notna(x, loc=loc)
1488+
1489+
def first(self, numeric_only=False, min_count=-1):
    # Alternative implementation for the non-cython path: _get_loc bound to
    # position 0, i.e. the first item that isn't NA.
    pick_first = partial(self._get_loc, loc=0)
    options = {
        "numeric_only": numeric_only,
        "min_count": min_count,
        "alias": "first",
        "npfunc": pick_first,
    }
    return self._agg_general(**options)
1498+
1499+
def last(self, numeric_only=False, min_count=-1):
    # Alternative implementation for the non-cython path: _get_loc bound to
    # position -1, i.e. the last item that isn't NA.
    pick_last = partial(self._get_loc, loc=-1)
    options = {
        "numeric_only": numeric_only,
        "min_count": min_count,
        "alias": "last",
        "npfunc": pick_last,
    }
    return self._agg_general(**options)
15261508

15271509
@Substitution(name="groupby")
15281510
@Appender(_common_see_also)

0 commit comments

Comments
 (0)