1
1
import types
2
+ from functools import wraps
2
3
import numpy as np
3
4
4
5
from pandas .compat import (
45
46
"""
46
47
47
48
49
+ # special case to prevent duplicate plots when catching exceptions when
50
+ # forwarding methods from NDFrames
51
+ _plotting_methods = frozenset (['plot' , 'boxplot' , 'hist' ])
52
+
48
53
_apply_whitelist = frozenset (['last' , 'first' ,
49
54
'mean' , 'sum' , 'min' , 'max' ,
50
55
'head' , 'tail' ,
51
56
'cumsum' , 'cumprod' , 'cummin' , 'cummax' ,
52
57
'resample' ,
53
58
'describe' ,
54
59
'rank' , 'quantile' , 'count' ,
55
- 'fillna' , 'dtype' ])
60
+ 'fillna' , 'dtype' ]) | _plotting_methods
61
+
56
62
57
63
58
64
class GroupByError (Exception ):
@@ -180,7 +186,6 @@ class GroupBy(PandasObject):
180
186
len(grouped) : int
181
187
Number of groups
182
188
"""
183
-
184
189
def __init__ (self , obj , keys = None , axis = 0 , level = None ,
185
190
grouper = None , exclusions = None , selection = None , as_index = True ,
186
191
sort = True , group_keys = True , squeeze = False ):
@@ -244,6 +249,9 @@ def _selection_list(self):
244
249
return [self ._selection ]
245
250
return self ._selection
246
251
252
+ def _local_dir (self ):
253
+ return sorted (set (self .obj ._local_dir () + list (_apply_whitelist )))
254
+
247
255
def __getattr__ (self , attr ):
248
256
if attr in self .obj :
249
257
return self [attr ]
@@ -281,9 +289,16 @@ def wrapper(*args, **kwargs):
281
289
282
290
def curried_with_axis (x ):
283
291
return f (x , * args , ** kwargs_with_axis )
292
+ curried_with_axis .__name__ = name
284
293
285
294
def curried (x ):
286
295
return f (x , * args , ** kwargs )
296
+ curried .__name__ = name
297
+
298
+ # special case otherwise extra plots are created when catching the
299
+ # exception below
300
+ if name in _plotting_methods :
301
+ return self .apply (curried )
287
302
288
303
try :
289
304
return self .apply (curried_with_axis )
@@ -348,7 +363,11 @@ def apply(self, func, *args, **kwargs):
348
363
applied : type depending on grouped object and function
349
364
"""
350
365
func = _intercept_function (func )
351
- f = lambda g : func (g , * args , ** kwargs )
366
+
367
+ @wraps (func )
368
+ def f (g ):
369
+ return func (g , * args , ** kwargs )
370
+
352
371
return self ._python_apply_general (f )
353
372
354
373
def _python_apply_general (self , f ):
@@ -598,7 +617,7 @@ def __iter__(self):
598
617
def nkeys (self ):
599
618
return len (self .groupings )
600
619
601
- def get_iterator (self , data , axis = 0 , keep_internal = True ):
620
+ def get_iterator (self , data , axis = 0 ):
602
621
"""
603
622
Groupby iterator
604
623
@@ -607,16 +626,14 @@ def get_iterator(self, data, axis=0, keep_internal=True):
607
626
Generator yielding sequence of (name, subsetted object)
608
627
for each group
609
628
"""
610
- splitter = self ._get_splitter (data , axis = axis ,
611
- keep_internal = keep_internal )
629
+ splitter = self ._get_splitter (data , axis = axis )
612
630
keys = self ._get_group_keys ()
613
631
for key , (i , group ) in zip (keys , splitter ):
614
632
yield key , group
615
633
616
- def _get_splitter (self , data , axis = 0 , keep_internal = True ):
634
+ def _get_splitter (self , data , axis = 0 ):
617
635
comp_ids , _ , ngroups = self .group_info
618
- return get_splitter (data , comp_ids , ngroups , axis = axis ,
619
- keep_internal = keep_internal )
636
+ return get_splitter (data , comp_ids , ngroups , axis = axis )
620
637
621
638
def _get_group_keys (self ):
622
639
if len (self .groupings ) == 1 :
@@ -627,19 +644,19 @@ def _get_group_keys(self):
627
644
mapper = _KeyMapper (comp_ids , ngroups , self .labels , self .levels )
628
645
return [mapper .get_key (i ) for i in range (ngroups )]
629
646
630
- def apply (self , f , data , axis = 0 , keep_internal = False ):
647
+ def apply (self , f , data , axis = 0 ):
631
648
mutated = False
632
- splitter = self ._get_splitter (data , axis = axis ,
633
- keep_internal = keep_internal )
649
+ splitter = self ._get_splitter (data , axis = axis )
634
650
group_keys = self ._get_group_keys ()
635
651
636
652
# oh boy
637
- if hasattr (splitter , 'fast_apply' ) and axis == 0 :
653
+ if (f .__name__ not in _plotting_methods and
654
+ hasattr (splitter , 'fast_apply' ) and axis == 0 ):
638
655
try :
639
656
values , mutated = splitter .fast_apply (f , group_keys )
640
657
return group_keys , values , mutated
641
- except ( Exception ) as detail :
642
- # we detect a mutatation of some kind
658
+ except Exception :
659
+ # we detect a mutation of some kind
643
660
# so take slow path
644
661
pass
645
662
@@ -1043,7 +1060,7 @@ def get_iterator(self, data, axis=0):
1043
1060
inds = lrange (start , n )
1044
1061
yield self .binlabels [- 1 ], data .take (inds , axis = axis )
1045
1062
1046
- def apply (self , f , data , axis = 0 , keep_internal = False ):
1063
+ def apply (self , f , data , axis = 0 ):
1047
1064
result_keys = []
1048
1065
result_values = []
1049
1066
mutated = False
@@ -1617,6 +1634,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
1617
1634
else :
1618
1635
return filtered .reindex (self .obj .index ) # Fill with NaNs.
1619
1636
1637
+
1620
1638
class NDFrameGroupBy (GroupBy ):
1621
1639
1622
1640
def _iterate_slices (self ):
@@ -2268,6 +2286,7 @@ def ohlc(self):
2268
2286
"""
2269
2287
return self ._apply_to_column_groupbys (lambda x : x ._cython_agg_general ('ohlc' ))
2270
2288
2289
+
2271
2290
from pandas .tools .plotting import boxplot_frame_groupby
2272
2291
DataFrameGroupBy .boxplot = boxplot_frame_groupby
2273
2292
@@ -2364,7 +2383,7 @@ class NDArrayGroupBy(GroupBy):
2364
2383
2365
2384
class DataSplitter (object ):
2366
2385
2367
- def __init__ (self , data , labels , ngroups , axis = 0 , keep_internal = False ):
2386
+ def __init__ (self , data , labels , ngroups , axis = 0 ):
2368
2387
self .data = data
2369
2388
self .labels = com ._ensure_int64 (labels )
2370
2389
self .ngroups = ngroups
@@ -2419,10 +2438,8 @@ def _chop(self, sdata, slice_obj):
2419
2438
2420
2439
2421
2440
class FrameSplitter (DataSplitter ):
2422
-
2423
- def __init__ (self , data , labels , ngroups , axis = 0 , keep_internal = False ):
2424
- DataSplitter .__init__ (self , data , labels , ngroups , axis = axis ,
2425
- keep_internal = keep_internal )
2441
+ def __init__ (self , data , labels , ngroups , axis = 0 ):
2442
+ super (FrameSplitter , self ).__init__ (data , labels , ngroups , axis = axis )
2426
2443
2427
2444
def fast_apply (self , f , names ):
2428
2445
# must return keys::list, values::list, mutated::bool
@@ -2445,10 +2462,8 @@ def _chop(self, sdata, slice_obj):
2445
2462
2446
2463
2447
2464
class NDFrameSplitter (DataSplitter ):
2448
-
2449
- def __init__ (self , data , labels , ngroups , axis = 0 , keep_internal = False ):
2450
- DataSplitter .__init__ (self , data , labels , ngroups , axis = axis ,
2451
- keep_internal = keep_internal )
2465
+ def __init__ (self , data , labels , ngroups , axis = 0 ):
2466
+ super (NDFrameSplitter , self ).__init__ (data , labels , ngroups , axis = axis )
2452
2467
2453
2468
self .factory = data ._constructor
2454
2469
0 commit comments