Skip to content

Commit 321f9e1

Browse files
committed
BUG: multi-key groupby now works with multiple functions. Starting to add *args, **kwargs. Addresses GH #133.
1 parent 0397d37 commit 321f9e1

File tree

2 files changed

+58
-49
lines changed

2 files changed

+58
-49
lines changed

pandas/core/groupby.py

+50-49
Original file line number | Diff line number | Diff line change
@@ -218,7 +218,7 @@ def flatten(gen, level=0, shape_axis=0):
218218

219219
return flatten(self._generator_factory(data), shape_axis=self.axis)
220220

221-
def apply(self, func):
221+
def apply(self, func, *args, **kwargs):
222222
"""
223223
Apply function and combine results together in an intelligent way. The
224224
split-apply-combine combination rules attempt to be as common sense
@@ -255,16 +255,16 @@ def apply(self, func):
255255
-------
256256
applied : type depending on grouped object and function
257257
"""
258-
return self._python_apply_general(func)
258+
return self._python_apply_general(func, *args, **kwargs)
259259

260-
def aggregate(self, func):
260+
def aggregate(self, func, *args, **kwargs):
261261
raise NotImplementedError
262262

263-
def agg(self, func):
263+
def agg(self, func, *args, **kwargs):
264264
"""
265265
See docstring for aggregate
266266
"""
267-
return self.aggregate(func)
267+
return self.aggregate(func, *args, **kwargs)
268268

269269
def _get_names(self):
270270
axes = [ping.group_index for ping in self.groupings]
@@ -275,18 +275,18 @@ def _get_names(self):
275275
def _iterate_slices(self):
276276
yield self.name, self.obj
277277

278-
def transform(self, func):
278+
def transform(self, func, *args, **kwargs):
279279
raise NotImplementedError
280280

281-
def mean(self, axis=None):
281+
def mean(self):
282282
"""
283283
Compute mean of groups, excluding missing values
284284
285285
For multiple groupings, the result index will be a MultiIndex
286286
"""
287287
return self._cython_agg_general('mean')
288288

289-
def sum(self, axis=None):
289+
def sum(self):
290290
"""
291291
Compute sum of values, excluding missing values
292292
@@ -330,7 +330,7 @@ def _get_group_levels(self, mask):
330330
name_list = self._get_names()
331331
return [(name, raveled[mask]) for name, raveled in name_list]
332332

333-
def _python_agg_general(self, arg):
333+
def _python_agg_general(self, func, *args, **kwargs):
334334
group_shape = self._group_shape
335335
counts = np.zeros(group_shape, dtype=int)
336336

@@ -342,7 +342,7 @@ def _doit(reschunk, ctchunk, gen, shape_axis=0):
342342
ctchunk[i] = size
343343
if size == 0:
344344
continue
345-
reschunk[i] = arg(subgen)
345+
reschunk[i] = func(subgen, *args, **kwargs)
346346
else:
347347
_doit(reschunk[i], ctchunk[i], subgen,
348348
shape_axis=shape_axis)
@@ -378,14 +378,14 @@ def _doit(reschunk, ctchunk, gen, shape_axis=0):
378378

379379
return self._wrap_aggregated_output(output, mask)
380380

381-
def _python_apply_general(self, arg):
381+
def _python_apply_general(self, func, *args, **kwargs):
382382
result_keys = []
383383
result_values = []
384384

385385
not_indexed_same = False
386386
for key, group in self:
387387
group.name = key
388-
res = arg(group)
388+
res = func(group, *args, **kwargs)
389389
if not _is_indexed_like(res, group):
390390
not_indexed_same = True
391391

@@ -588,7 +588,7 @@ class SeriesGroupBy(GroupBy):
588588
def _agg_stride_shape(self):
589589
return ()
590590

591-
def aggregate(self, func_or_funcs):
591+
def aggregate(self, func_or_funcs, *args, **kwargs):
592592
"""
593593
Apply aggregation function or functions to groups, yielding most likely
594594
Series but in some cases DataFrame depending on the output of the
@@ -640,18 +640,18 @@ def aggregate(self, func_or_funcs):
640640
Series or DataFrame
641641
"""
642642
if isinstance(func_or_funcs, basestring):
643-
return getattr(self, func_or_funcs)()
644-
645-
if len(self.groupings) > 1:
646-
return self._python_agg_general(func_or_funcs)
643+
return getattr(self, func_or_funcs)(*args, **kwargs)
647644

648645
if hasattr(func_or_funcs,'__iter__'):
649646
ret = self._aggregate_multiple_funcs(func_or_funcs)
650647
else:
648+
if len(self.groupings) > 1:
649+
return self._python_agg_general(func_or_funcs, *args, **kwargs)
650+
651651
try:
652-
result = self._aggregate_simple(func_or_funcs)
652+
result = self._aggregate_simple(func_or_funcs, *args, **kwargs)
653653
except Exception:
654-
result = self._aggregate_named(func_or_funcs)
654+
result = self._aggregate_named(func_or_funcs, *args, **kwargs)
655655

656656
if len(result) > 0:
657657
if isinstance(result.values()[0], Series):
@@ -711,34 +711,30 @@ def _aggregate_multiple_funcs(self, arg):
711711
results = {}
712712

713713
for name, func in arg.iteritems():
714-
try:
715-
result = func(self)
716-
except Exception:
717-
result = self.aggregate(func)
718-
results[name] = result
714+
results[name] = self.aggregate(func)
719715

720716
return DataFrame(results)
721717

722-
def _aggregate_simple(self, arg):
718+
def _aggregate_simple(self, func, *args, **kwargs):
723719
values = self.obj.values
724720
result = {}
725721
for k, v in self.primary.indices.iteritems():
726-
result[k] = arg(values.take(v))
722+
result[k] = func(values.take(v), *args, **kwargs)
727723

728724
return result
729725

730-
def _aggregate_named(self, arg):
726+
def _aggregate_named(self, func, *args, **kwargs):
731727
result = {}
732728

733729
for name in self.primary:
734730
grp = self.get_group(name)
735731
grp.name = name
736-
output = arg(grp)
732+
output = func(grp, *args, **kwargs)
737733
result[name] = output
738734

739735
return result
740736

741-
def transform(self, func):
737+
def transform(self, func, *args, **kwargs):
742738
"""
743739
Call function producing a like-indexed Series on each group and return
744740
a Series with the transformed values
@@ -760,7 +756,7 @@ def transform(self, func):
760756

761757
for name, group in self:
762758
group.name = name
763-
res = func(group)
759+
res = func(group, *args, **kwargs)
764760
indexer, _ = self.obj.index.get_indexer(group.index)
765761
np.put(result, indexer, res)
766762

@@ -817,7 +813,7 @@ def _obj_with_exclusions(self):
817813
else:
818814
return self.obj
819815

820-
def aggregate(self, arg):
816+
def aggregate(self, arg, *args, **kwargs):
821817
"""
822818
Aggregate using input function or dict of {column -> function}
823819
@@ -843,26 +839,27 @@ def aggregate(self, arg):
843839
result = DataFrame(result)
844840
else:
845841
if len(self.groupings) > 1:
846-
return self._python_agg_general(arg)
847-
result = self._aggregate_generic(arg, axis=self.axis)
842+
return self._python_agg_general(arg, *args, **kwargs)
843+
result = self._aggregate_generic(arg, *args, **kwargs)
848844

849845
return result
850846

851-
def _aggregate_generic(self, agger, axis=0):
847+
def _aggregate_generic(self, func, *args, **kwargs):
852848
result = {}
853-
849+
axis = self.axis
854850
obj = self._obj_with_exclusions
855851

856852
try:
857853
for name in self.primary:
858854
data = self.get_group(name, obj=obj)
859855
try:
860-
result[name] = agger(data)
856+
result[name] = func(data, *args, **kwargs)
861857
except Exception:
862-
result[name] = data.apply(agger, axis=axis)
858+
wrapper = lambda x: func(x, *args, **kwargs)
859+
result[name] = data.apply(wrapper, axis=axis)
863860
except Exception, e1:
864861
if axis == 0:
865-
return self._aggregate_item_by_item(agger)
862+
return self._aggregate_item_by_item(func, *args, **kwargs)
866863
else:
867864
raise e1
868865

@@ -872,7 +869,7 @@ def _aggregate_generic(self, agger, axis=0):
872869

873870
return result
874871

875-
def _aggregate_item_by_item(self, agger):
872+
def _aggregate_item_by_item(self, func, *args, **kwargs):
876873
# only for axis==0
877874

878875
obj = self._obj_with_exclusions
@@ -881,7 +878,7 @@ def _aggregate_item_by_item(self, agger):
881878
cannot_agg = []
882879
for item in obj:
883880
try:
884-
result[item] = self[item].agg(agger)
881+
result[item] = self[item].agg(func, *args, **kwargs)
885882
except (ValueError, TypeError):
886883
cannot_agg.append(item)
887884
continue
@@ -954,15 +951,15 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
954951

955952
return DataFrame(stacked_values, index=index, columns=columns)
956953

957-
def transform(self, func):
954+
def transform(self, func, *args, **kwargs):
958955
"""
959956
Call function producing a like-indexed DataFrame on each group and
960957
return a DataFrame having the same indexes as the original object
961958
filled with the transformed values
962959
963960
Parameters
964961
----------
965-
func : function
962+
f : function
966963
Function to apply to each subframe
967964
968965
Note
@@ -982,9 +979,10 @@ def transform(self, func):
982979
group.name = name
983980

984981
try:
985-
res = group.apply(func, axis=self.axis)
982+
wrapper = lambda x: func(x, *args, **kwargs)
983+
res = group.apply(wrapper, axis=self.axis)
986984
except Exception: # pragma: no cover
987-
res = func(group)
985+
res = func(group, *args, **kwargs)
988986

989987
# broadcasting
990988
if isinstance(res, Series):
@@ -1081,7 +1079,7 @@ def _all_indexes_same(indexes):
10811079

10821080
class PanelGroupBy(GroupBy):
10831081

1084-
def aggregate(self, func):
1082+
def aggregate(self, func, *args, **kwargs):
10851083
"""
10861084
Aggregate using input function or dict of {column -> function}
10871085
@@ -1096,19 +1094,22 @@ def aggregate(self, func):
10961094
-------
10971095
aggregated : Panel
10981096
"""
1099-
return self._aggregate_generic(func, axis=self.axis)
1097+
return self._aggregate_generic(func, *args, **kwargs)
11001098

1101-
def _aggregate_generic(self, agger, axis=0):
1099+
def _aggregate_generic(self, func, *args, **kwargs):
11021100
result = {}
11031101

1102+
axis = self.axis
1103+
11041104
obj = self._obj_with_exclusions
11051105

11061106
for name in self.primary:
11071107
data = self.get_group(name, obj=obj)
11081108
try:
1109-
result[name] = agger(data)
1109+
result[name] = func(data, *args, **kwargs)
11101110
except Exception:
1111-
result[name] = data.apply(agger, axis=axis)
1111+
wrapper = lambda x: func(x, *args, **kwargs)
1112+
result[name] = data.apply(wrapper, axis=axis)
11121113

11131114
result = Panel.fromDict(result, intersect=False)
11141115

pandas/tests/test_groupby.py

+8
Original file line number | Diff line number | Diff line change
@@ -475,6 +475,14 @@ def test_multi_func(self):
475475
grouped = df.groupby(['k1', 'k2'])
476476
grouped.agg(np.sum)
477477

478+
def test_multi_key_multiple_functions(self):
479+
grouped = self.df.groupby(['A', 'B'])['C']
480+
481+
agged = grouped.agg([np.mean, np.std])
482+
expected = DataFrame({'mean' : grouped.agg(np.mean),
483+
'std' : grouped.agg(np.std)})
484+
assert_frame_equal(agged, expected)
485+
478486
def test_groupby_multiple_columns(self):
479487
data = self.df
480488
grouped = data.groupby(['A', 'B'])

0 commit comments

Comments
 (0)