@@ -86,8 +86,8 @@ class GroupBy(object):
86
86
"""
87
87
88
88
def __init__ (self , obj , grouper = None , axis = 0 , level = None ,
89
- groupings = None , exclusions = None , name = None , as_index = True ):
90
- self ._name = name
89
+ groupings = None , exclusions = None , column = None , as_index = True ):
90
+ self ._column = column
91
91
92
92
if isinstance (obj , NDFrame ):
93
93
obj ._consolidate_inplace ()
@@ -105,13 +105,14 @@ def __init__(self, obj, grouper=None, axis=0, level=None,
105
105
raise ValueError ('as_index=False only valid for axis=0' )
106
106
107
107
self .as_index = as_index
108
+ self .grouper = grouper
108
109
109
110
if groupings is None :
110
111
groupings , exclusions = _get_groupings (obj , grouper , axis = axis ,
111
112
level = level )
112
113
113
114
self .groupings = groupings
114
- self .exclusions = set (exclusions )
115
+ self .exclusions = set (exclusions ) if exclusions else set ()
115
116
116
117
def __len__ (self ):
117
118
return len (self .indices )
@@ -138,10 +139,10 @@ def indices(self):
138
139
139
140
@property
140
141
def name (self ):
141
- if self ._name is None :
142
+ if self ._column is None :
142
143
return 'result'
143
144
else :
144
- return self ._name
145
+ return self ._column
145
146
146
147
@property
147
148
def _obj_with_exclusions (self ):
@@ -854,12 +855,30 @@ def _agg_stride_shape(self):
854
855
return n ,
855
856
856
857
def __getitem__ (self , key ):
857
- return SeriesGroupBy (self .obj [key ], groupings = self .groupings ,
858
- exclusions = self .exclusions , name = key )
858
+ if self ._column is not None :
859
+ raise Exception ('Column %s already selected' % self ._column )
860
+
861
+ if key not in self .obj : # pragma: no cover
862
+ raise KeyError (str (key ))
863
+
864
+ # kind of a kludge
865
+ if self .as_index :
866
+ return SeriesGroupBy (self .obj [key ], column = key ,
867
+ groupings = self .groupings ,
868
+ exclusions = self .exclusions )
869
+ else :
870
+ return DataFrameGroupBy (self .obj , self .grouper , column = key ,
871
+ groupings = self .groupings ,
872
+ exclusions = self .exclusions ,
873
+ as_index = self .as_index )
859
874
860
875
def _iterate_slices (self ):
861
876
if self .axis == 0 :
862
- slice_axis = self .obj .columns
877
+ # kludge
878
+ if self ._column is None :
879
+ slice_axis = self .obj .columns
880
+ else :
881
+ slice_axis = [self ._column ]
863
882
slicer = lambda x : self .obj [x ]
864
883
else :
865
884
slice_axis = self .obj .index
@@ -873,6 +892,9 @@ def _iterate_slices(self):
873
892
874
893
@cache_readonly
875
894
def _obj_with_exclusions (self ):
895
+ if self ._column is not None :
896
+ return self .obj .reindex (columns = [self ._column ])
897
+
876
898
if len (self .exclusions ) > 0 :
877
899
return self .obj .drop (self .exclusions , axis = 1 )
878
900
else :
@@ -901,8 +923,11 @@ def aggregate(self, arg, *args, **kwargs):
901
923
if self .axis != 0 : # pragma: no cover
902
924
raise ValueError ('Can only pass dict with axis=0' )
903
925
926
+ obj = self ._obj_with_exclusions
904
927
for col , func in arg .iteritems ():
905
- result [col ] = self [col ].agg (func )
928
+ colg = SeriesGroupBy (obj [col ], column = col ,
929
+ groupings = self .groupings )
930
+ result [col ] = colg .agg (func )
906
931
907
932
result = DataFrame (result )
908
933
else :
@@ -927,23 +952,25 @@ def aggregate(self, arg, *args, **kwargs):
927
952
return result
928
953
929
954
def _aggregate_generic (self , func , * args , ** kwargs ):
930
- result = {}
931
955
axis = self .axis
932
956
obj = self ._obj_with_exclusions
933
957
934
- try :
935
- for name in self .primary :
936
- data = self .get_group (name , obj = obj )
958
+ result = {}
959
+ if axis == 0 :
960
+ try :
961
+ for name in self .indices :
962
+ data = self .get_group (name , obj = obj )
963
+ result [name ] = func (data , * args , ** kwargs )
964
+ except Exception :
965
+ return self ._aggregate_item_by_item (func , * args , ** kwargs )
966
+ else :
967
+ for name in self .indices :
937
968
try :
969
+ data = self .get_group (name , obj = obj )
938
970
result [name ] = func (data , * args , ** kwargs )
939
971
except Exception :
940
972
wrapper = lambda x : func (x , * args , ** kwargs )
941
973
result [name ] = data .apply (wrapper , axis = axis )
942
- except Exception , e1 :
943
- if axis == 0 :
944
- return self ._aggregate_item_by_item (func , * args , ** kwargs )
945
- else :
946
- raise e1
947
974
948
975
if result :
949
976
if axis == 0 :
@@ -963,12 +990,18 @@ def _aggregate_item_by_item(self, func, *args, **kwargs):
963
990
cannot_agg = []
964
991
for item in obj :
965
992
try :
966
- result [item ] = self [item ].agg (func , * args , ** kwargs )
993
+ colg = SeriesGroupBy (obj [item ], column = item ,
994
+ groupings = self .groupings )
995
+ result [item ] = colg .agg (func , * args , ** kwargs )
967
996
except (ValueError , TypeError ):
968
997
cannot_agg .append (item )
969
998
continue
970
999
971
- return DataFrame (result )
1000
+ result_columns = obj .columns
1001
+ if cannot_agg :
1002
+ result_columns = result_columns .drop (cannot_agg )
1003
+
1004
+ return DataFrame (result , columns = result_columns )
972
1005
973
1006
def _wrap_aggregated_output (self , output , mask ):
974
1007
agg_axis = 0 if self .axis == 1 else 1
0 commit comments