Skip to content

Commit 30e539c

Browse files
committed
BUG: set index name attribute in single-key groupby, GH #358
1 parent 0b201dc commit 30e539c

File tree

3 files changed

+45
-37
lines changed

3 files changed

+45
-37
lines changed

RELEASE.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ pandas 0.5.1
9797
- Added missing bang at top of setup.py (GH #352)
9898
- Change `is_monotonic` on MultiIndex so it properly compares the tuples
9999
- Fix MultiIndex outer join logic (GH #351)
100+
- Set index name attribute with single-key groupby (GH #358)
100101

101102
Thanks
102103
------

pandas/core/groupby.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,8 @@ def aggregate(self, func_or_funcs, *args, **kwargs):
716716
except Exception:
717717
result = self._aggregate_named(func_or_funcs, *args, **kwargs)
718718

719-
ret = Series(result)
719+
index = Index(sorted(result), name=self.groupings[0].name)
720+
ret = Series(result, index=index)
720721

721722
if not self.as_index: # pragma: no cover
722723
print 'Warning, ignoring as_index=True'
@@ -728,12 +729,8 @@ def _wrap_aggregated_output(self, output, mask):
728729
# sort of a kludge
729730
output = output[self.name]
730731

731-
if len(self.groupings) > 1:
732-
index = self._get_multi_index(mask)
733-
return Series(output, index=index)
734-
else:
735-
name_list = self._get_names()
736-
return Series(output, index=name_list[0][1])
732+
index = self._get_multi_index(mask)
733+
return Series(output, index=index)
737734

738735
def _wrap_applied_output(self, keys, values, not_indexed_same=False):
739736
if len(keys) == 0:
@@ -972,11 +969,17 @@ def _aggregate_generic(self, func, *args, **kwargs):
972969
wrapper = lambda x: func(x, *args, **kwargs)
973970
result[name] = data.apply(wrapper, axis=axis)
974971

972+
index_name = (self.groupings[0].name
973+
if len(self.groupings) == 1 else None)
974+
result_index = Index(sorted(result), name=index_name)
975+
975976
if result:
976977
if axis == 0:
977-
result = DataFrame(result, index=obj.columns).T
978+
result = DataFrame(result, index=obj.columns,
979+
columns=result_index).T
978980
else:
979-
result = DataFrame(result, index=obj.index)
981+
result = DataFrame(result, index=obj.index,
982+
columns=result_index)
980983
else:
981984
result = DataFrame(result)
982985

@@ -1022,24 +1025,15 @@ def _wrap_aggregated_output(self, output, mask):
10221025
else:
10231026
output_keys = agg_labels
10241027

1025-
if len(self.groupings) > 1:
1026-
if not self.as_index:
1027-
result = DataFrame(output, columns=output_keys)
1028-
group_levels = self._get_group_levels(mask)
1029-
for i, (name, labels) in enumerate(group_levels):
1030-
result.insert(i, name, labels)
1031-
result = result.consolidate()
1032-
else:
1033-
index = self._get_multi_index(mask)
1034-
result = DataFrame(output, index=index, columns=output_keys)
1028+
if not self.as_index:
1029+
result = DataFrame(output, columns=output_keys)
1030+
group_levels = self._get_group_levels(mask)
1031+
for i, (name, labels) in enumerate(group_levels):
1032+
result.insert(i, name, labels)
1033+
result = result.consolidate()
10351034
else:
1036-
name_list = self._get_names()
1037-
name, labels = name_list[0]
1038-
if not self.as_index:
1039-
result = DataFrame(output, columns=output_keys)
1040-
result.insert(0, name, labels)
1041-
else:
1042-
result = DataFrame(output, index=labels, columns=output_keys)
1035+
index = self._get_multi_index(mask)
1036+
result = DataFrame(output, index=index, columns=output_keys)
10431037

10441038
if self.axis == 1:
10451039
result = result.T

pandas/tests/test_groupby.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -390,17 +390,30 @@ def test_frame_groupby_columns(self):
390390
for k, v in grouped:
391391
self.assertEqual(len(v.columns), 2)
392392

393-
# # tgroupby
394-
# grouping = {
395-
# 'A' : 0,
396-
# 'B' : 1,
397-
# 'C' : 0,
398-
# 'D' : 1
399-
# }
400-
401-
# grouped = self.frame.tgroupby(grouping.get, np.mean)
402-
# self.assertEqual(len(grouped), len(self.frame.index))
403-
# self.assertEqual(len(grouped.columns), 2)
393+
def test_frame_set_name_single(self):
394+
grouped = self.df.groupby('A')
395+
396+
result = grouped.mean()
397+
self.assert_(result.index.name == 'A')
398+
399+
result = self.df.groupby('A', as_index=False).mean()
400+
self.assert_(result.index.name != 'A')
401+
402+
result = grouped.agg(np.mean)
403+
self.assert_(result.index.name == 'A')
404+
405+
result = grouped.agg({'C' : np.mean, 'D' : np.std})
406+
self.assert_(result.index.name == 'A')
407+
408+
result = grouped['C'].mean()
409+
self.assert_(result.index.name == 'A')
410+
result = grouped['C'].agg(np.mean)
411+
self.assert_(result.index.name == 'A')
412+
result = grouped['C'].agg([np.mean, np.std])
413+
self.assert_(result.index.name == 'A')
414+
415+
result = grouped['C'].agg({'foo' : np.mean, 'bar' : np.std})
416+
self.assert_(result.index.name == 'A')
404417

405418
def test_multi_iter(self):
406419
s = Series(np.arange(6))

0 commit comments

Comments
 (0)