Skip to content

Commit 781b9b3

Browse files
committed
Move _groupby_function inside GroupBy
Add support for __qualname__
1 parent 68013bf commit 781b9b3

File tree

2 files changed: +98 −76 lines changed

pandas/core/groupby.py

+72 −68
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,8 @@
1212
)
1313

1414
from pandas import compat
15-
from pandas.compat.numpy import function as nv
16-
from pandas.compat.numpy import _np_version_under1p8
15+
from pandas.compat.numpy import function as nv, _np_version_under1p8
16+
from pandas.compat import set_function_name
1717

1818
from pandas.types.common import (is_numeric_dtype,
1919
is_timedelta64_dtype, is_datetime64_dtype,
@@ -172,64 +172,6 @@
172172
'cummin', 'cummax'])
173173

174174

175-
def _groupby_function(name, alias, npfunc, numeric_only=True,
176-
_convert=False):
177-
178-
_local_template = "Compute %(f)s of group values"
179-
180-
@Substitution(name='groupby', f=name)
181-
@Appender(_doc_template)
182-
@Appender(_local_template)
183-
def f(self, **kwargs):
184-
if 'numeric_only' not in kwargs:
185-
kwargs['numeric_only'] = numeric_only
186-
self._set_group_selection()
187-
try:
188-
return self._cython_agg_general(alias, alt=npfunc, **kwargs)
189-
except AssertionError as e:
190-
raise SpecificationError(str(e))
191-
except Exception:
192-
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
193-
if _convert:
194-
result = result._convert(datetime=True)
195-
return result
196-
197-
f.__name__ = name
198-
199-
return f
200-
201-
202-
def _first_compat(x, axis=0):
203-
204-
def _first(x):
205-
206-
x = np.asarray(x)
207-
x = x[notnull(x)]
208-
if len(x) == 0:
209-
return np.nan
210-
return x[0]
211-
212-
if isinstance(x, DataFrame):
213-
return x.apply(_first, axis=axis)
214-
else:
215-
return _first(x)
216-
217-
218-
def _last_compat(x, axis=0):
219-
def _last(x):
220-
221-
x = np.asarray(x)
222-
x = x[notnull(x)]
223-
if len(x) == 0:
224-
return np.nan
225-
return x[-1]
226-
227-
if isinstance(x, DataFrame):
228-
return x.apply(_last, axis=axis)
229-
else:
230-
return _last(x)
231-
232-
233175
class Grouper(object):
234176
"""
235177
A Grouper allows the user to specify a groupby instruction for a target
@@ -1184,14 +1126,74 @@ def size(self):
11841126
result.name = getattr(self, 'name', None)
11851127
return result
11861128

1187-
sum = _groupby_function('sum', 'add', np.sum)
1188-
prod = _groupby_function('prod', 'prod', np.prod)
1189-
min = _groupby_function('min', 'min', np.min, numeric_only=False)
1190-
max = _groupby_function('max', 'max', np.max, numeric_only=False)
1191-
first = _groupby_function('first', 'first', _first_compat,
1192-
numeric_only=False, _convert=True)
1193-
last = _groupby_function('last', 'last', _last_compat, numeric_only=False,
1194-
_convert=True)
1129+
@classmethod
1130+
def _add_numeric_operations(cls):
1131+
""" add numeric operations to the GroupBy generically """
1132+
1133+
def _groupby_function(name, alias, npfunc,
1134+
numeric_only=True, _convert=False):
1135+
1136+
_local_template = "Compute %(f)s of group values"
1137+
1138+
@Substitution(name='groupby', f=name)
1139+
@Appender(_doc_template)
1140+
@Appender(_local_template)
1141+
def f(self, **kwargs):
1142+
if 'numeric_only' not in kwargs:
1143+
kwargs['numeric_only'] = numeric_only
1144+
self._set_group_selection()
1145+
try:
1146+
return self._cython_agg_general(alias, alt=npfunc, **kwargs)
1147+
except AssertionError as e:
1148+
raise SpecificationError(str(e))
1149+
except Exception:
1150+
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
1151+
if _convert:
1152+
result = result._convert(datetime=True)
1153+
return result
1154+
1155+
set_function_name(f, name, cls)
1156+
1157+
return f
1158+
1159+
def _first_compat(x, axis=0):
1160+
1161+
def _first(x):
1162+
1163+
x = np.asarray(x)
1164+
x = x[notnull(x)]
1165+
if len(x) == 0:
1166+
return np.nan
1167+
return x[0]
1168+
1169+
if isinstance(x, DataFrame):
1170+
return x.apply(_first, axis=axis)
1171+
else:
1172+
return _first(x)
1173+
1174+
1175+
def _last_compat(x, axis=0):
1176+
def _last(x):
1177+
1178+
x = np.asarray(x)
1179+
x = x[notnull(x)]
1180+
if len(x) == 0:
1181+
return np.nan
1182+
return x[-1]
1183+
1184+
if isinstance(x, DataFrame):
1185+
return x.apply(_last, axis=axis)
1186+
else:
1187+
return _last(x)
1188+
1189+
cls.sum = _groupby_function('sum', 'add', np.sum)
1190+
cls.prod = _groupby_function('prod', 'prod', np.prod)
1191+
cls.min = _groupby_function('min', 'min', np.min, numeric_only=False)
1192+
cls.max = _groupby_function('max', 'max', np.max, numeric_only=False)
1193+
cls.first = _groupby_function('first', 'first', _first_compat,
1194+
numeric_only=False, _convert=True)
1195+
cls.last = _groupby_function('last', 'last', _last_compat, numeric_only=False,
1196+
_convert=True)
11951197

11961198
@Substitution(name='groupby')
11971199
@Appender(_doc_template)
@@ -1603,6 +1605,8 @@ def tail(self, n=5):
16031605
mask = self._cumcount_array(ascending=False) < n
16041606
return self._selected_obj[mask]
16051607

1608+
GroupBy._add_numeric_operations()
1609+
16061610

16071611
@Appender(GroupBy.__doc__)
16081612
def groupby(obj, by, **kwds):

pandas/tests/groupby/test_groupby.py

+26 −8
Original file line number | Diff line number | Diff line change
@@ -3849,19 +3849,36 @@ def test_groupby_whitelist(self):
38493849
'nsmallest',
38503850
])
38513851

3852-
names_dont_match_pair = (self.DF_METHOD_NAMES_THAT_DONT_MATCH_ATTRIBUTE,
3853-
self.S_METHOD_NAMES_THAT_DONT_MATCH_ATTRIBUTE)
3854-
for obj, whitelist, names_dont_match in zip((df, s), (df_whitelist, s_whitelist), names_dont_match_pair):
3852+
names_dont_match_pair = (
3853+
self.DF_METHOD_NAMES_THAT_DONT_MATCH_ATTRIBUTE,
3854+
self.S_METHOD_NAMES_THAT_DONT_MATCH_ATTRIBUTE)
3855+
for obj, whitelist, names_dont_match in (
3856+
zip((df, s),
3857+
(df_whitelist, s_whitelist),
3858+
names_dont_match_pair)):
3859+
38553860
gb = obj.groupby(df.letters)
3856-
self.assertEqual(whitelist, gb._apply_whitelist)
3861+
3862+
assert whitelist == gb._apply_whitelist
38573863
for m in whitelist:
38583864
f = getattr(type(gb), m)
3865+
3866+
# name
38593867
try:
38603868
n = f.__name__
38613869
except AttributeError:
38623870
continue
38633871
if m not in names_dont_match:
3864-
self.assertEqual(n, m)
3872+
assert n == m
3873+
3874+
# qualname
3875+
if compat.PY3:
3876+
try:
3877+
n = f.__qualname__
3878+
except AttributeError:
3879+
continue
3880+
if m not in names_dont_match:
3881+
assert n.endswith(m)
38653882

38663883
def test_groupby_method_names_that_dont_match_attribute(self):
38673884
from string import ascii_lowercase
@@ -3873,9 +3890,10 @@ def test_groupby_method_names_that_dont_match_attribute(self):
38733890
gb = df.groupby(df.letters)
38743891
s = df.floats
38753892

3876-
names_dont_match_pair = (self.DF_METHOD_NAMES_THAT_DONT_MATCH_ATTRIBUTE,
3877-
self.S_METHOD_NAMES_THAT_DONT_MATCH_ATTRIBUTE)
3878-
for obj, names_dont_match in zip((df, s), names_dont_match_pair):
3893+
names_dont_match_pair = (
3894+
self.DF_METHOD_NAMES_THAT_DONT_MATCH_ATTRIBUTE,
3895+
self.S_METHOD_NAMES_THAT_DONT_MATCH_ATTRIBUTE)
3896+
for obj, names_dont_match in zip((df, s), names_dont_match_pair):
38793897
gb = obj.groupby(df.letters)
38803898
for m in names_dont_match:
38813899
f = getattr(gb, m)

0 commit comments

Comments
 (0)