Skip to content

Commit c4a1aad

Browse files
committed
CLN/TYP: Groupby agg methods
1 parent f25ed6f commit c4a1aad

File tree

2 files changed

+160
-99
lines changed

2 files changed

+160
-99
lines changed

pandas/core/groupby/generic.py

+82
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,28 @@
7979
if TYPE_CHECKING:
8080
from pandas.core.internals import Block
8181

82+
_agg_template = """
83+
Compute %(f)s of group values.
84+
85+
Parameters
86+
----------
87+
numeric_only : bool, default %(no)s
88+
Include only float, int, boolean columns. If None, will attempt to use
89+
everything, then use only numeric data.
90+
min_count : int, default %(mc)s
91+
The required number of valid values to perform the operation. If fewer
92+
than ``min_count`` non-NA values are present the result will be NA.
93+
94+
Returns
95+
-------
96+
%(return_type)s
97+
Computed %(f)s of values within each group.
98+
99+
See Also
100+
--------
101+
%(return_type)s.groupby
102+
"""
103+
82104

83105
NamedAgg = namedtuple("NamedAgg", ["column", "aggfunc"])
84106
# TODO(typing) the return value on this callable should be any *scalar*.
@@ -789,6 +811,36 @@ def count(self) -> Series:
789811
)
790812
return self._reindex_output(result, fill_value=0)
791813

814+
@Substitution(f="sum", no=True, mc=0, return_type="Series")
@Appender(_agg_template)
def sum(self, numeric_only: bool = True, min_count: int = 0) -> Series:
    # Thin override: the work is delegated to the parent implementation; this
    # wrapper carries the Series-flavoured docstring and return annotation.
    return super().sum(numeric_only=numeric_only, min_count=min_count)
818+
819+
@Substitution(f="prod", no=True, mc=0, return_type="Series")
@Appender(_agg_template)
def prod(self, numeric_only: bool = True, min_count: int = 0) -> Series:
    # Thin override: the work is delegated to the parent implementation; this
    # wrapper carries the Series-flavoured docstring and return annotation.
    return super().prod(numeric_only=numeric_only, min_count=min_count)
823+
824+
@Substitution(f="min", no=False, mc=-1, return_type="Series")
@Appender(_agg_template)
def min(self, numeric_only: bool = False, min_count: int = -1) -> Series:
    # Thin override: the work is delegated to the parent implementation; this
    # wrapper carries the Series-flavoured docstring and return annotation.
    return super().min(numeric_only=numeric_only, min_count=min_count)
828+
829+
@Substitution(f="max", no=False, mc=-1, return_type="Series")
@Appender(_agg_template)
def max(self, numeric_only: bool = False, min_count: int = -1) -> Series:
    # Thin override: the work is delegated to the parent implementation; this
    # wrapper carries the Series-flavoured docstring and return annotation.
    return super().max(numeric_only=numeric_only, min_count=min_count)
833+
834+
@Substitution(f="first", no=False, mc=-1, return_type="Series")
@Appender(_agg_template)
def first(self, numeric_only: bool = False, min_count: int = -1) -> Series:
    # Thin override: the work is delegated to the parent implementation; this
    # wrapper carries the Series-flavoured docstring and return annotation.
    return super().first(numeric_only=numeric_only, min_count=min_count)
838+
839+
@Substitution(f="last", no=False, mc=-1, return_type="Series")
@Appender(_agg_template)
def last(self, numeric_only: bool = False, min_count: int = -1) -> Series:
    # Thin override: the work is delegated to the parent implementation; this
    # wrapper carries the Series-flavoured docstring and return annotation.
    return super().last(numeric_only=numeric_only, min_count=min_count)
843+
792844
def _apply_to_column_groupbys(self, func):
793845
""" return a pass thru """
794846
return func(self)
@@ -1863,6 +1915,36 @@ def groupby_series(obj, col=None):
18631915
results.index = ibase.default_index(len(results))
18641916
return results
18651917

1918+
@Substitution(f="sum", no=True, mc=0, return_type="DataFrame")
@Appender(_agg_template)
def sum(self, numeric_only: bool = True, min_count: int = 0) -> DataFrame:
    # Thin override: the work is delegated to the parent implementation; this
    # wrapper carries the DataFrame-flavoured docstring and return annotation.
    return super().sum(numeric_only=numeric_only, min_count=min_count)
1922+
1923+
@Substitution(f="prod", no=True, mc=0, return_type="DataFrame")
@Appender(_agg_template)
def prod(self, numeric_only: bool = True, min_count: int = 0) -> DataFrame:
    # Thin override: the work is delegated to the parent implementation; this
    # wrapper carries the DataFrame-flavoured docstring and return annotation.
    return super().prod(numeric_only=numeric_only, min_count=min_count)
1927+
1928+
@Substitution(f="min", no=False, mc=-1, return_type="DataFrame")
@Appender(_agg_template)
def min(self, numeric_only: bool = False, min_count: int = -1) -> DataFrame:
    # Thin override: the work is delegated to the parent implementation; this
    # wrapper carries the DataFrame-flavoured docstring and return annotation.
    return super().min(numeric_only=numeric_only, min_count=min_count)
1932+
1933+
@Substitution(f="max", no=False, mc=-1, return_type="DataFrame")
@Appender(_agg_template)
def max(self, numeric_only: bool = False, min_count: int = -1) -> DataFrame:
    # Thin override: the work is delegated to the parent implementation; this
    # wrapper carries the DataFrame-flavoured docstring and return annotation.
    return super().max(numeric_only=numeric_only, min_count=min_count)
1937+
1938+
@Substitution(f="first", no=False, mc=-1, return_type="DataFrame")
@Appender(_agg_template)
def first(self, numeric_only: bool = False, min_count: int = -1) -> DataFrame:
    # Thin override: the work is delegated to the parent implementation; this
    # wrapper carries the DataFrame-flavoured docstring and return annotation.
    return super().first(numeric_only=numeric_only, min_count=min_count)
1942+
1943+
@Substitution(f="last", no=False, mc=-1, return_type="DataFrame")
@Appender(_agg_template)
def last(self, numeric_only: bool = False, min_count: int = -1) -> DataFrame:
    # Thin override: the work is delegated to the parent implementation; this
    # wrapper carries the DataFrame-flavoured docstring and return annotation.
    return super().last(numeric_only=numeric_only, min_count=min_count)
1947+
18661948
boxplot = boxplot_frame_groupby
18671949

18681950

pandas/core/groupby/groupby.py

+78-99
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ class providing the base-class of operations.
3434
from pandas._libs import Timestamp
3535
import pandas._libs.groupby as libgroupby
3636
from pandas._typing import FrameOrSeries, Scalar
37-
from pandas.compat import set_function_name
3837
from pandas.compat.numpy import function as nv
3938
from pandas.errors import AbstractMethodError
4039
from pandas.util._decorators import Appender, Substitution, cache_readonly
@@ -871,6 +870,32 @@ def _wrap_transformed_output(self, output: Mapping[base.OutputKey, np.ndarray]):
871870
def _wrap_applied_output(self, keys, values, not_indexed_same: bool = False):
872871
raise AbstractMethodError(self)
873872

873+
def _agg_general(
    self, numeric_only=True, min_count=-1, *, alias: str, npfunc: Callable
):
    """
    Run one aggregation, trying the cython path before a generic fallback.

    ``alias`` is passed to ``_cython_agg_general`` as ``how`` and ``npfunc``
    as ``alt``. If that call raises ``DataError``, or a
    ``NotImplementedError`` whose message contains one of the two known
    unsupported-dtype fragments, the aggregation is redone through
    ``self.aggregate`` with ``npfunc`` applied along ``self.axis``. Any other
    ``NotImplementedError`` is re-raised.
    """
    tolerated_fragments = (
        "function is not implemented for this dtype",
        "category dtype not supported",
    )
    self._set_group_selection()

    # First choice: a cython aggregation.
    try:
        return self._cython_agg_general(
            how=alias, alt=npfunc, numeric_only=numeric_only, min_count=min_count
        )
    except DataError:
        pass
    except NotImplementedError as err:
        # Raised in _get_cython_function; in some cases this could be trimmed
        # by implementing cython funcs for more dtypes.
        message = str(err)
        if not any(fragment in message for fragment in tolerated_fragments):
            raise

    # Fallback: a non-cython aggregation.
    return self.aggregate(lambda x: npfunc(x, axis=self.axis))
898+
874899
def _cython_agg_general(
875900
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
876901
):
@@ -1336,105 +1361,62 @@ def size(self):
13361361
result.name = self.obj.name
13371362
return self._reindex_output(result, fill_value=0)
13381363

1339-
@classmethod
1340-
def _add_numeric_operations(cls):
1341-
"""
1342-
Add numeric operations to the GroupBy generically.
1364+
def sum(self, numeric_only: bool = True, min_count: int = 0):
    """
    Compute sum of group values.

    Parameters
    ----------
    numeric_only : bool, default True
        Forwarded unchanged to ``_agg_general``.
    min_count : int, default 0
        Forwarded unchanged to ``_agg_general``.

    Returns
    -------
    object
        The result of ``_agg_general`` for the ``"add"`` alias, with
        ``np.sum`` supplied as the non-cython function.
    """
    return self._agg_general(
        numeric_only=numeric_only, min_count=min_count, alias="add", npfunc=np.sum
    )
1368+
1369+
def prod(self, numeric_only: bool = True, min_count: int = 0):
    """
    Compute prod of group values.

    Parameters
    ----------
    numeric_only : bool, default True
        Forwarded unchanged to ``_agg_general``.
    min_count : int, default 0
        Forwarded unchanged to ``_agg_general``.

    Returns
    -------
    object
        The result of ``_agg_general`` for the ``"prod"`` alias, with
        ``np.prod`` supplied as the non-cython function.
    """
    return self._agg_general(
        numeric_only=numeric_only, min_count=min_count, alias="prod", npfunc=np.prod
    )
1373+
1374+
def min(self, numeric_only: bool = False, min_count: int = -1):
    """
    Compute min of group values.

    Parameters
    ----------
    numeric_only : bool, default False
        Forwarded unchanged to ``_agg_general``.
    min_count : int, default -1
        Forwarded unchanged to ``_agg_general``.

    Returns
    -------
    object
        The result of ``_agg_general`` for the ``"min"`` alias, with
        ``np.min`` supplied as the non-cython function.
    """
    return self._agg_general(
        numeric_only=numeric_only, min_count=min_count, alias="min", npfunc=np.min
    )
1378+
1379+
def max(self, numeric_only: bool = False, min_count: int = -1):
    """
    Compute max of group values.

    Parameters
    ----------
    numeric_only : bool, default False
        Forwarded unchanged to ``_agg_general``.
    min_count : int, default -1
        Forwarded unchanged to ``_agg_general``.

    Returns
    -------
    object
        The result of ``_agg_general`` for the ``"max"`` alias, with
        ``np.max`` supplied as the non-cython function.
    """
    return self._agg_general(
        numeric_only=numeric_only, min_count=min_count, alias="max", npfunc=np.max
    )
1383+
1384+
@staticmethod
1385+
def _get_loc(x, axis: int = 0, *, loc: int):
1386+
"""Helper function for first/last item that isn't NA.
13431387
"""
13441388

1345-
def groupby_function(
1346-
name: str,
1347-
alias: str,
1348-
npfunc,
1349-
numeric_only: bool = True,
1350-
min_count: int = -1,
1351-
):
1389+
def get_loc_notna(x, loc: int):
1390+
x = x.to_numpy()
1391+
x = x[notna(x)]
1392+
if len(x) == 0:
1393+
return np.nan
1394+
return x[loc]
13521395

1353-
_local_template = """
1354-
Compute %(f)s of group values.
1355-
1356-
Parameters
1357-
----------
1358-
numeric_only : bool, default %(no)s
1359-
Include only float, int, boolean columns. If None, will attempt to use
1360-
everything, then use only numeric data.
1361-
min_count : int, default %(mc)s
1362-
The required number of valid values to perform the operation. If fewer
1363-
than ``min_count`` non-NA values are present the result will be NA.
1364-
1365-
Returns
1366-
-------
1367-
Series or DataFrame
1368-
Computed %(f)s of values within each group.
1369-
"""
1370-
1371-
@Substitution(name="groupby", f=name, no=numeric_only, mc=min_count)
1372-
@Appender(_common_see_also)
1373-
@Appender(_local_template)
1374-
def func(self, numeric_only=numeric_only, min_count=min_count):
1375-
self._set_group_selection()
1376-
1377-
# try a cython aggregation if we can
1378-
try:
1379-
return self._cython_agg_general(
1380-
how=alias,
1381-
alt=npfunc,
1382-
numeric_only=numeric_only,
1383-
min_count=min_count,
1384-
)
1385-
except DataError:
1386-
pass
1387-
except NotImplementedError as err:
1388-
if "function is not implemented for this dtype" in str(
1389-
err
1390-
) or "category dtype not supported" in str(err):
1391-
# raised in _get_cython_function, in some cases can
1392-
# be trimmed by implementing cython funcs for more dtypes
1393-
pass
1394-
else:
1395-
raise
1396-
1397-
# apply a non-cython aggregation
1398-
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
1399-
return result
1400-
1401-
set_function_name(func, name, cls)
1402-
1403-
return func
1404-
1405-
def first_compat(x, axis=0):
1406-
def first(x):
1407-
x = x.to_numpy()
1408-
1409-
x = x[notna(x)]
1410-
if len(x) == 0:
1411-
return np.nan
1412-
return x[0]
1413-
1414-
if isinstance(x, DataFrame):
1415-
return x.apply(first, axis=axis)
1416-
else:
1417-
return first(x)
1418-
1419-
def last_compat(x, axis=0):
1420-
def last(x):
1421-
x = x.to_numpy()
1422-
x = x[notna(x)]
1423-
if len(x) == 0:
1424-
return np.nan
1425-
return x[-1]
1426-
1427-
if isinstance(x, DataFrame):
1428-
return x.apply(last, axis=axis)
1429-
else:
1430-
return last(x)
1396+
if isinstance(x, DataFrame):
1397+
return x.apply(get_loc_notna, axis=axis, loc=loc)
1398+
else:
1399+
return get_loc_notna(x, loc=loc)
1400+
1401+
def first(self, numeric_only: bool = False, min_count: int = -1):
    """
    Compute first of group values.

    Parameters
    ----------
    numeric_only : bool, default False
        Forwarded unchanged to ``_agg_general``.
    min_count : int, default -1
        Forwarded unchanged to ``_agg_general``.

    Returns
    -------
    object
        The result of ``_agg_general`` for the ``"first"`` alias, with
        ``_get_loc`` bound to ``loc=0`` supplied as the non-cython function.
    """
    return self._agg_general(
        numeric_only=numeric_only,
        min_count=min_count,
        alias="first",
        # loc=0 selects the leading entry that _get_loc keeps.
        npfunc=partial(self._get_loc, loc=0),
    )
1410+
1411+
def last(self, numeric_only: bool = False, min_count: int = -1):
    """
    Compute last of group values.

    Parameters
    ----------
    numeric_only : bool, default False
        Forwarded unchanged to ``_agg_general``.
    min_count : int, default -1
        Forwarded unchanged to ``_agg_general``.

    Returns
    -------
    object
        The result of ``_agg_general`` for the ``"last"`` alias, with
        ``_get_loc`` bound to ``loc=-1`` supplied as the non-cython function.
    """
    return self._agg_general(
        numeric_only=numeric_only,
        min_count=min_count,
        alias="last",
        # loc=-1 selects the trailing entry that _get_loc keeps.
        npfunc=partial(self._get_loc, loc=-1),
    )
14381420

14391421
@Substitution(name="groupby")
14401422
@Appender(_common_see_also)
@@ -2528,9 +2510,6 @@ def _reindex_output(
25282510
return output.reset_index(drop=True)
25292511

25302512

2531-
GroupBy._add_numeric_operations()
2532-
2533-
25342513
@Appender(GroupBy.__doc__)
25352514
def get_groupby(
25362515
obj: NDFrame,

0 commit comments

Comments
 (0)