Skip to content

Commit 79a8609

Browse files
committed
ENH: groupby: return a DataFrame when a column is selected and as_index=False. Had to make a bit of a mess to get this to work as desired (GH #308).
1 parent 32c65bd commit 79a8609

File tree

2 files changed: 78 additions (+78) and 21 deletions (-21).

pandas/core/groupby.py

Lines changed: 53 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -86,8 +86,8 @@ class GroupBy(object):
8686
"""
8787

8888
def __init__(self, obj, grouper=None, axis=0, level=None,
89-
groupings=None, exclusions=None, name=None, as_index=True):
90-
self._name = name
89+
groupings=None, exclusions=None, column=None, as_index=True):
90+
self._column = column
9191

9292
if isinstance(obj, NDFrame):
9393
obj._consolidate_inplace()
@@ -105,13 +105,14 @@ def __init__(self, obj, grouper=None, axis=0, level=None,
105105
raise ValueError('as_index=False only valid for axis=0')
106106

107107
self.as_index = as_index
108+
self.grouper = grouper
108109

109110
if groupings is None:
110111
groupings, exclusions = _get_groupings(obj, grouper, axis=axis,
111112
level=level)
112113

113114
self.groupings = groupings
114-
self.exclusions = set(exclusions)
115+
self.exclusions = set(exclusions) if exclusions else set()
115116

116117
def __len__(self):
117118
return len(self.indices)
@@ -138,10 +139,10 @@ def indices(self):
138139

139140
@property
140141
def name(self):
141-
if self._name is None:
142+
if self._column is None:
142143
return 'result'
143144
else:
144-
return self._name
145+
return self._column
145146

146147
@property
147148
def _obj_with_exclusions(self):
@@ -854,12 +855,30 @@ def _agg_stride_shape(self):
854855
return n,
855856

856857
def __getitem__(self, key):
857-
return SeriesGroupBy(self.obj[key], groupings=self.groupings,
858-
exclusions=self.exclusions, name=key)
858+
if self._column is not None:
859+
raise Exception('Column %s already selected' % self._column)
860+
861+
if key not in self.obj: # pragma: no cover
862+
raise KeyError(str(key))
863+
864+
# kind of a kludge
865+
if self.as_index:
866+
return SeriesGroupBy(self.obj[key], column=key,
867+
groupings=self.groupings,
868+
exclusions=self.exclusions)
869+
else:
870+
return DataFrameGroupBy(self.obj, self.grouper, column=key,
871+
groupings=self.groupings,
872+
exclusions=self.exclusions,
873+
as_index=self.as_index)
859874

860875
def _iterate_slices(self):
861876
if self.axis == 0:
862-
slice_axis = self.obj.columns
877+
# kludge
878+
if self._column is None:
879+
slice_axis = self.obj.columns
880+
else:
881+
slice_axis = [self._column]
863882
slicer = lambda x: self.obj[x]
864883
else:
865884
slice_axis = self.obj.index
@@ -873,6 +892,9 @@ def _iterate_slices(self):
873892

874893
@cache_readonly
875894
def _obj_with_exclusions(self):
895+
if self._column is not None:
896+
return self.obj.reindex(columns=[self._column])
897+
876898
if len(self.exclusions) > 0:
877899
return self.obj.drop(self.exclusions, axis=1)
878900
else:
@@ -901,8 +923,11 @@ def aggregate(self, arg, *args, **kwargs):
901923
if self.axis != 0: # pragma: no cover
902924
raise ValueError('Can only pass dict with axis=0')
903925

926+
obj = self._obj_with_exclusions
904927
for col, func in arg.iteritems():
905-
result[col] = self[col].agg(func)
928+
colg = SeriesGroupBy(obj[col], column=col,
929+
groupings=self.groupings)
930+
result[col] = colg.agg(func)
906931

907932
result = DataFrame(result)
908933
else:
@@ -927,23 +952,25 @@ def aggregate(self, arg, *args, **kwargs):
927952
return result
928953

929954
def _aggregate_generic(self, func, *args, **kwargs):
930-
result = {}
931955
axis = self.axis
932956
obj = self._obj_with_exclusions
933957

934-
try:
935-
for name in self.primary:
936-
data = self.get_group(name, obj=obj)
958+
result = {}
959+
if axis == 0:
960+
try:
961+
for name in self.indices:
962+
data = self.get_group(name, obj=obj)
963+
result[name] = func(data, *args, **kwargs)
964+
except Exception:
965+
return self._aggregate_item_by_item(func, *args, **kwargs)
966+
else:
967+
for name in self.indices:
937968
try:
969+
data = self.get_group(name, obj=obj)
938970
result[name] = func(data, *args, **kwargs)
939971
except Exception:
940972
wrapper = lambda x: func(x, *args, **kwargs)
941973
result[name] = data.apply(wrapper, axis=axis)
942-
except Exception, e1:
943-
if axis == 0:
944-
return self._aggregate_item_by_item(func, *args, **kwargs)
945-
else:
946-
raise e1
947974

948975
if result:
949976
if axis == 0:
@@ -963,12 +990,18 @@ def _aggregate_item_by_item(self, func, *args, **kwargs):
963990
cannot_agg = []
964991
for item in obj:
965992
try:
966-
result[item] = self[item].agg(func, *args, **kwargs)
993+
colg = SeriesGroupBy(obj[item], column=item,
994+
groupings=self.groupings)
995+
result[item] = colg.agg(func, *args, **kwargs)
967996
except (ValueError, TypeError):
968997
cannot_agg.append(item)
969998
continue
970999

971-
return DataFrame(result)
1000+
result_columns = obj.columns
1001+
if cannot_agg:
1002+
result_columns = result_columns.drop(cannot_agg)
1003+
1004+
return DataFrame(result, columns=result_columns)
9721005

9731006
def _wrap_aggregated_output(self, output, mask):
9741007
agg_axis = 0 if self.axis == 1 else 1

pandas/tests/test_groupby.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_agg_regression1(self):
112112
expected = grouped.mean()
113113
assert_frame_equal(result, expected)
114114

115-
def test_agg_must_add(self):
115+
def test_agg_must_agg(self):
116116
grouped = self.df.groupby('A')['C']
117117
self.assertRaises(Exception, grouped.agg, lambda x: x.describe())
118118
self.assertRaises(Exception, grouped.agg, lambda x: x.index[:2])
@@ -551,6 +551,30 @@ def test_groupby_as_index_agg(self):
551551
expected2['D'] = grouped.sum()['D']
552552
assert_frame_equal(result2, expected2)
553553

554+
def test_as_index_series_return_frame(self):
555+
grouped = self.df.groupby('A', as_index=False)
556+
grouped2 = self.df.groupby(['A', 'B'], as_index=False)
557+
558+
result = grouped['C'].agg(np.sum)
559+
expected = grouped.agg(np.sum).ix[:, ['A', 'C']]
560+
self.assert_(isinstance(result, DataFrame))
561+
assert_frame_equal(result, expected)
562+
563+
result2 = grouped2['C'].agg(np.sum)
564+
expected2 = grouped2.agg(np.sum).ix[:, ['A', 'B', 'C']]
565+
self.assert_(isinstance(result2, DataFrame))
566+
assert_frame_equal(result2, expected2)
567+
568+
result = grouped['C'].sum()
569+
expected = grouped.sum().ix[:, ['A', 'C']]
570+
self.assert_(isinstance(result, DataFrame))
571+
assert_frame_equal(result, expected)
572+
573+
result2 = grouped2['C'].sum()
574+
expected2 = grouped2.sum().ix[:, ['A', 'B', 'C']]
575+
self.assert_(isinstance(result2, DataFrame))
576+
assert_frame_equal(result2, expected2)
577+
554578
def test_groupby_as_index_cython(self):
555579
data = self.df
556580

0 commit comments

Comments
 (0)