@@ -218,7 +218,7 @@ def flatten(gen, level=0, shape_axis=0):
218
218
219
219
return flatten (self ._generator_factory (data ), shape_axis = self .axis )
220
220
221
- def apply (self , func ):
221
+ def apply (self , func , * args , ** kwargs ):
222
222
"""
223
223
Apply function and combine results together in an intelligent way. The
224
224
split-apply-combine combination rules attempt to be as common sense
@@ -255,16 +255,16 @@ def apply(self, func):
255
255
-------
256
256
applied : type depending on grouped object and function
257
257
"""
258
- return self ._python_apply_general (func )
258
+ return self ._python_apply_general (func , * args , ** kwargs )
259
259
260
- def aggregate (self , func ):
260
+ def aggregate (self , func , * args , ** kwargs ):
261
261
raise NotImplementedError
262
262
263
- def agg (self , func ):
263
+ def agg (self , func , * args , ** kwargs ):
264
264
"""
265
265
See docstring for aggregate
266
266
"""
267
- return self .aggregate (func )
267
+ return self .aggregate (func , * args , ** kwargs )
268
268
269
269
def _get_names (self ):
270
270
axes = [ping .group_index for ping in self .groupings ]
@@ -275,18 +275,18 @@ def _get_names(self):
275
275
def _iterate_slices (self ):
276
276
yield self .name , self .obj
277
277
278
- def transform (self , func ):
278
+ def transform (self , func , * args , ** kwargs ):
279
279
raise NotImplementedError
280
280
281
- def mean (self , axis = None ):
281
+ def mean (self ):
282
282
"""
283
283
Compute mean of groups, excluding missing values
284
284
285
285
For multiple groupings, the result index will be a MultiIndex
286
286
"""
287
287
return self ._cython_agg_general ('mean' )
288
288
289
- def sum (self , axis = None ):
289
+ def sum (self ):
290
290
"""
291
291
Compute sum of values, excluding missing values
292
292
@@ -330,7 +330,7 @@ def _get_group_levels(self, mask):
330
330
name_list = self ._get_names ()
331
331
return [(name , raveled [mask ]) for name , raveled in name_list ]
332
332
333
- def _python_agg_general (self , arg ):
333
+ def _python_agg_general (self , func , * args , ** kwargs ):
334
334
group_shape = self ._group_shape
335
335
counts = np .zeros (group_shape , dtype = int )
336
336
@@ -342,7 +342,7 @@ def _doit(reschunk, ctchunk, gen, shape_axis=0):
342
342
ctchunk [i ] = size
343
343
if size == 0 :
344
344
continue
345
- reschunk [i ] = arg (subgen )
345
+ reschunk [i ] = func (subgen , * args , ** kwargs )
346
346
else :
347
347
_doit (reschunk [i ], ctchunk [i ], subgen ,
348
348
shape_axis = shape_axis )
@@ -378,14 +378,14 @@ def _doit(reschunk, ctchunk, gen, shape_axis=0):
378
378
379
379
return self ._wrap_aggregated_output (output , mask )
380
380
381
- def _python_apply_general (self , arg ):
381
+ def _python_apply_general (self , func , * args , ** kwargs ):
382
382
result_keys = []
383
383
result_values = []
384
384
385
385
not_indexed_same = False
386
386
for key , group in self :
387
387
group .name = key
388
- res = arg (group )
388
+ res = func (group , * args , ** kwargs )
389
389
if not _is_indexed_like (res , group ):
390
390
not_indexed_same = True
391
391
@@ -588,7 +588,7 @@ class SeriesGroupBy(GroupBy):
588
588
def _agg_stride_shape (self ):
589
589
return ()
590
590
591
- def aggregate (self , func_or_funcs ):
591
+ def aggregate (self , func_or_funcs , * args , ** kwargs ):
592
592
"""
593
593
Apply aggregation function or functions to groups, yielding most likely
594
594
Series but in some cases DataFrame depending on the output of the
@@ -640,18 +640,18 @@ def aggregate(self, func_or_funcs):
640
640
Series or DataFrame
641
641
"""
642
642
if isinstance (func_or_funcs , basestring ):
643
- return getattr (self , func_or_funcs )()
644
-
645
- if len (self .groupings ) > 1 :
646
- return self ._python_agg_general (func_or_funcs )
643
+ return getattr (self , func_or_funcs )(* args , ** kwargs )
647
644
648
645
if hasattr (func_or_funcs ,'__iter__' ):
649
646
ret = self ._aggregate_multiple_funcs (func_or_funcs )
650
647
else :
648
+ if len (self .groupings ) > 1 :
649
+ return self ._python_agg_general (func_or_funcs , * args , ** kwargs )
650
+
651
651
try :
652
- result = self ._aggregate_simple (func_or_funcs )
652
+ result = self ._aggregate_simple (func_or_funcs , * args , ** kwargs )
653
653
except Exception :
654
- result = self ._aggregate_named (func_or_funcs )
654
+ result = self ._aggregate_named (func_or_funcs , * args , ** kwargs )
655
655
656
656
if len (result ) > 0 :
657
657
if isinstance (result .values ()[0 ], Series ):
@@ -711,34 +711,30 @@ def _aggregate_multiple_funcs(self, arg):
711
711
results = {}
712
712
713
713
for name , func in arg .iteritems ():
714
- try :
715
- result = func (self )
716
- except Exception :
717
- result = self .aggregate (func )
718
- results [name ] = result
714
+ results [name ] = self .aggregate (func )
719
715
720
716
return DataFrame (results )
721
717
722
- def _aggregate_simple (self , arg ):
718
+ def _aggregate_simple (self , func , * args , ** kwargs ):
723
719
values = self .obj .values
724
720
result = {}
725
721
for k , v in self .primary .indices .iteritems ():
726
- result [k ] = arg (values .take (v ))
722
+ result [k ] = func (values .take (v ), * args , ** kwargs )
727
723
728
724
return result
729
725
730
- def _aggregate_named (self , arg ):
726
+ def _aggregate_named (self , func , * args , ** kwargs ):
731
727
result = {}
732
728
733
729
for name in self .primary :
734
730
grp = self .get_group (name )
735
731
grp .name = name
736
- output = arg (grp )
732
+ output = func (grp , * args , ** kwargs )
737
733
result [name ] = output
738
734
739
735
return result
740
736
741
- def transform (self , func ):
737
+ def transform (self , func , * args , ** kwargs ):
742
738
"""
743
739
Call function producing a like-indexed Series on each group and return
744
740
a Series with the transformed values
@@ -760,7 +756,7 @@ def transform(self, func):
760
756
761
757
for name , group in self :
762
758
group .name = name
763
- res = func (group )
759
+ res = func (group , * args , ** kwargs )
764
760
indexer , _ = self .obj .index .get_indexer (group .index )
765
761
np .put (result , indexer , res )
766
762
@@ -817,7 +813,7 @@ def _obj_with_exclusions(self):
817
813
else :
818
814
return self .obj
819
815
820
- def aggregate (self , arg ):
816
+ def aggregate (self , arg , * args , ** kwargs ):
821
817
"""
822
818
Aggregate using input function or dict of {column -> function}
823
819
@@ -843,26 +839,27 @@ def aggregate(self, arg):
843
839
result = DataFrame (result )
844
840
else :
845
841
if len (self .groupings ) > 1 :
846
- return self ._python_agg_general (arg )
847
- result = self ._aggregate_generic (arg , axis = self . axis )
842
+ return self ._python_agg_general (arg , * args , ** kwargs )
843
+ result = self ._aggregate_generic (arg , * args , ** kwargs )
848
844
849
845
return result
850
846
851
- def _aggregate_generic (self , agger , axis = 0 ):
847
+ def _aggregate_generic (self , func , * args , ** kwargs ):
852
848
result = {}
853
-
849
+ axis = self . axis
854
850
obj = self ._obj_with_exclusions
855
851
856
852
try :
857
853
for name in self .primary :
858
854
data = self .get_group (name , obj = obj )
859
855
try :
860
- result [name ] = agger (data )
856
+ result [name ] = func (data , * args , ** kwargs )
861
857
except Exception :
862
- result [name ] = data .apply (agger , axis = axis )
858
+ wrapper = lambda x : func (x , * args , ** kwargs )
859
+ result [name ] = data .apply (wrapper , axis = axis )
863
860
except Exception , e1 :
864
861
if axis == 0 :
865
- return self ._aggregate_item_by_item (agger )
862
+ return self ._aggregate_item_by_item (func , * args , ** kwargs )
866
863
else :
867
864
raise e1
868
865
@@ -872,7 +869,7 @@ def _aggregate_generic(self, agger, axis=0):
872
869
873
870
return result
874
871
875
- def _aggregate_item_by_item (self , agger ):
872
+ def _aggregate_item_by_item (self , func , * args , ** kwargs ):
876
873
# only for axis==0
877
874
878
875
obj = self ._obj_with_exclusions
@@ -881,7 +878,7 @@ def _aggregate_item_by_item(self, agger):
881
878
cannot_agg = []
882
879
for item in obj :
883
880
try :
884
- result [item ] = self [item ].agg (agger )
881
+ result [item ] = self [item ].agg (func , * args , ** kwargs )
885
882
except (ValueError , TypeError ):
886
883
cannot_agg .append (item )
887
884
continue
@@ -954,15 +951,15 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
954
951
955
952
return DataFrame (stacked_values , index = index , columns = columns )
956
953
957
- def transform (self , func ):
954
+ def transform (self , func , * args , ** kwargs ):
958
955
"""
959
956
Call function producing a like-indexed DataFrame on each group and
960
957
return a DataFrame having the same indexes as the original object
961
958
filled with the transformed values
962
959
963
960
Parameters
964
961
----------
965
- func : function
962
+ f : function
966
963
Function to apply to each subframe
967
964
968
965
Note
@@ -982,9 +979,10 @@ def transform(self, func):
982
979
group .name = name
983
980
984
981
try :
985
- res = group .apply (func , axis = self .axis )
982
+ wrapper = lambda x : func (x , * args , ** kwargs )
983
+ res = group .apply (wrapper , axis = self .axis )
986
984
except Exception : # pragma: no cover
987
- res = func (group )
985
+ res = func (group , * args , ** kwargs )
988
986
989
987
# broadcasting
990
988
if isinstance (res , Series ):
@@ -1081,7 +1079,7 @@ def _all_indexes_same(indexes):
1081
1079
1082
1080
class PanelGroupBy (GroupBy ):
1083
1081
1084
- def aggregate (self , func ):
1082
+ def aggregate (self , func , * args , ** kwargs ):
1085
1083
"""
1086
1084
Aggregate using input function or dict of {column -> function}
1087
1085
@@ -1096,19 +1094,22 @@ def aggregate(self, func):
1096
1094
-------
1097
1095
aggregated : Panel
1098
1096
"""
1099
- return self ._aggregate_generic (func , axis = self . axis )
1097
+ return self ._aggregate_generic (func , * args , ** kwargs )
1100
1098
1101
- def _aggregate_generic (self , agger , axis = 0 ):
1099
+ def _aggregate_generic (self , func , * args , ** kwargs ):
1102
1100
result = {}
1103
1101
1102
+ axis = self .axis
1103
+
1104
1104
obj = self ._obj_with_exclusions
1105
1105
1106
1106
for name in self .primary :
1107
1107
data = self .get_group (name , obj = obj )
1108
1108
try :
1109
- result [name ] = agger (data )
1109
+ result [name ] = func (data , * args , ** kwargs )
1110
1110
except Exception :
1111
- result [name ] = data .apply (agger , axis = axis )
1111
+ wrapper = lambda x : func (x , * args , ** kwargs )
1112
+ result [name ] = data .apply (wrapper , axis = axis )
1112
1113
1113
1114
result = Panel .fromDict (result , intersect = False )
1114
1115
0 commit comments