1
1
import types
2
+ from functools import wraps
2
3
import numpy as np
3
4
4
5
from pandas .compat import (
45
46
"""
46
47
47
48
49
+ # special case to prevent duplicate plots when catching exceptions when
50
+ # forwarding methods from NDFrames
51
+ _plotting_methods = frozenset (['plot' , 'boxplot' , 'hist' ])
52
+
48
53
_apply_whitelist = frozenset (['last' , 'first' ,
49
54
'mean' , 'sum' , 'min' , 'max' ,
50
55
'head' , 'tail' ,
51
56
'cumsum' , 'cumprod' , 'cummin' , 'cummax' ,
52
57
'resample' ,
53
58
'describe' ,
54
59
'rank' , 'quantile' , 'count' ,
55
- 'fillna' , 'dtype' ])
60
+ 'fillna' , 'dtype' ]) | _plotting_methods
61
+
56
62
57
63
58
64
class GroupByError (Exception ):
@@ -180,7 +186,6 @@ class GroupBy(PandasObject):
180
186
len(grouped) : int
181
187
Number of groups
182
188
"""
183
-
184
189
def __init__ (self , obj , keys = None , axis = 0 , level = None ,
185
190
grouper = None , exclusions = None , selection = None , as_index = True ,
186
191
sort = True , group_keys = True , squeeze = False ):
@@ -244,6 +249,9 @@ def _selection_list(self):
244
249
return [self ._selection ]
245
250
return self ._selection
246
251
252
+ def _local_dir (self ):
253
+ return sorted (set (self .obj ._local_dir () + list (_apply_whitelist )))
254
+
247
255
def __getattr__ (self , attr ):
248
256
if attr in self .obj :
249
257
return self [attr ]
@@ -285,6 +293,15 @@ def curried_with_axis(x):
285
293
def curried (x ):
286
294
return f (x , * args , ** kwargs )
287
295
296
+ # preserve the name so we can detect it when calling plot methods,
297
+ # to avoid duplicates
298
+ curried .__name__ = curried_with_axis .__name__ = name
299
+
300
+ # special case otherwise extra plots are created when catching the
301
+ # exception below
302
+ if name in _plotting_methods :
303
+ return self .apply (curried )
304
+
288
305
try :
289
306
return self .apply (curried_with_axis )
290
307
except Exception :
@@ -348,7 +365,11 @@ def apply(self, func, *args, **kwargs):
348
365
applied : type depending on grouped object and function
349
366
"""
350
367
func = _intercept_function (func )
351
- f = lambda g : func (g , * args , ** kwargs )
368
+
369
+ @wraps (func )
370
+ def f (g ):
371
+ return func (g , * args , ** kwargs )
372
+
352
373
return self ._python_apply_general (f )
353
374
354
375
def _python_apply_general (self , f ):
@@ -598,7 +619,7 @@ def __iter__(self):
598
619
def nkeys (self ):
599
620
return len (self .groupings )
600
621
601
- def get_iterator (self , data , axis = 0 , keep_internal = True ):
622
+ def get_iterator (self , data , axis = 0 ):
602
623
"""
603
624
Groupby iterator
604
625
@@ -607,16 +628,14 @@ def get_iterator(self, data, axis=0, keep_internal=True):
607
628
Generator yielding sequence of (name, subsetted object)
608
629
for each group
609
630
"""
610
- splitter = self ._get_splitter (data , axis = axis ,
611
- keep_internal = keep_internal )
631
+ splitter = self ._get_splitter (data , axis = axis )
612
632
keys = self ._get_group_keys ()
613
633
for key , (i , group ) in zip (keys , splitter ):
614
634
yield key , group
615
635
616
- def _get_splitter (self , data , axis = 0 , keep_internal = True ):
636
+ def _get_splitter (self , data , axis = 0 ):
617
637
comp_ids , _ , ngroups = self .group_info
618
- return get_splitter (data , comp_ids , ngroups , axis = axis ,
619
- keep_internal = keep_internal )
638
+ return get_splitter (data , comp_ids , ngroups , axis = axis )
620
639
621
640
def _get_group_keys (self ):
622
641
if len (self .groupings ) == 1 :
@@ -627,19 +646,19 @@ def _get_group_keys(self):
627
646
mapper = _KeyMapper (comp_ids , ngroups , self .labels , self .levels )
628
647
return [mapper .get_key (i ) for i in range (ngroups )]
629
648
630
- def apply (self , f , data , axis = 0 , keep_internal = False ):
649
+ def apply (self , f , data , axis = 0 ):
631
650
mutated = False
632
- splitter = self ._get_splitter (data , axis = axis ,
633
- keep_internal = keep_internal )
651
+ splitter = self ._get_splitter (data , axis = axis )
634
652
group_keys = self ._get_group_keys ()
635
653
636
654
# oh boy
637
- if hasattr (splitter , 'fast_apply' ) and axis == 0 :
655
+ if (f .__name__ not in _plotting_methods and
656
+ hasattr (splitter , 'fast_apply' ) and axis == 0 ):
638
657
try :
639
658
values , mutated = splitter .fast_apply (f , group_keys )
640
659
return group_keys , values , mutated
641
- except ( Exception ) as detail :
642
- # we detect a mutatation of some kind
660
+ except Exception :
661
+ # we detect a mutation of some kind
643
662
# so take slow path
644
663
pass
645
664
@@ -1043,7 +1062,7 @@ def get_iterator(self, data, axis=0):
1043
1062
inds = lrange (start , n )
1044
1063
yield self .binlabels [- 1 ], data .take (inds , axis = axis )
1045
1064
1046
- def apply (self , f , data , axis = 0 , keep_internal = False ):
1065
+ def apply (self , f , data , axis = 0 ):
1047
1066
result_keys = []
1048
1067
result_values = []
1049
1068
mutated = False
@@ -1617,6 +1636,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
1617
1636
else :
1618
1637
return filtered .reindex (self .obj .index ) # Fill with NaNs.
1619
1638
1639
+
1620
1640
class NDFrameGroupBy (GroupBy ):
1621
1641
1622
1642
def _iterate_slices (self ):
@@ -1939,14 +1959,14 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
1939
1959
index = key_index
1940
1960
else :
1941
1961
stacked_values = np .vstack ([np .asarray (x )
1942
- for x in values ]).T
1962
+ for x in values ]).T
1943
1963
1944
1964
index = values [0 ].index
1945
1965
columns = key_index
1946
1966
1947
- except ValueError :
1948
- #GH1738,, values is list of arrays of unequal lengths
1949
- # fall through to the outer else caluse
1967
+ except ( ValueError , AttributeError ) :
1968
+ # GH1738: values is list of arrays of unequal lengths fall
1969
+ # through to the outer else caluse
1950
1970
return Series (values , index = key_index )
1951
1971
1952
1972
return DataFrame (stacked_values , index = index ,
@@ -2268,6 +2288,7 @@ def ohlc(self):
2268
2288
"""
2269
2289
return self ._apply_to_column_groupbys (lambda x : x ._cython_agg_general ('ohlc' ))
2270
2290
2291
+
2271
2292
from pandas .tools .plotting import boxplot_frame_groupby
2272
2293
DataFrameGroupBy .boxplot = boxplot_frame_groupby
2273
2294
@@ -2364,7 +2385,7 @@ class NDArrayGroupBy(GroupBy):
2364
2385
2365
2386
class DataSplitter (object ):
2366
2387
2367
- def __init__ (self , data , labels , ngroups , axis = 0 , keep_internal = False ):
2388
+ def __init__ (self , data , labels , ngroups , axis = 0 ):
2368
2389
self .data = data
2369
2390
self .labels = com ._ensure_int64 (labels )
2370
2391
self .ngroups = ngroups
@@ -2419,10 +2440,8 @@ def _chop(self, sdata, slice_obj):
2419
2440
2420
2441
2421
2442
class FrameSplitter (DataSplitter ):
2422
-
2423
- def __init__ (self , data , labels , ngroups , axis = 0 , keep_internal = False ):
2424
- DataSplitter .__init__ (self , data , labels , ngroups , axis = axis ,
2425
- keep_internal = keep_internal )
2443
+ def __init__ (self , data , labels , ngroups , axis = 0 ):
2444
+ super (FrameSplitter , self ).__init__ (data , labels , ngroups , axis = axis )
2426
2445
2427
2446
def fast_apply (self , f , names ):
2428
2447
# must return keys::list, values::list, mutated::bool
@@ -2445,10 +2464,8 @@ def _chop(self, sdata, slice_obj):
2445
2464
2446
2465
2447
2466
class NDFrameSplitter (DataSplitter ):
2448
-
2449
- def __init__ (self , data , labels , ngroups , axis = 0 , keep_internal = False ):
2450
- DataSplitter .__init__ (self , data , labels , ngroups , axis = axis ,
2451
- keep_internal = keep_internal )
2467
+ def __init__ (self , data , labels , ngroups , axis = 0 ):
2468
+ super (NDFrameSplitter , self ).__init__ (data , labels , ngroups , axis = axis )
2452
2469
2453
2470
self .factory = data ._constructor
2454
2471
0 commit comments