Skip to content

Commit 3440a9c

Browse files
committed
ENH: can pass a list of functions to DataFrame.groupby per #166
1 parent 85b75f9 commit 3440a9c

File tree

9 files changed

+72
-17
lines changed

9 files changed

+72
-17
lines changed

RELEASE.rst

+2
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ pandas 0.7.0
128128
the number of displayed digits (GH #395)
129129
- Use bottleneck if available for performing NaN-friendly statistical
130130
operations that it implemented (GH #91)
131+
- Can pass a list of functions to aggregate with groupby on a DataFrame,
132+
yielding an aggregated result with hierarchical columns (GH #166)
131133

132134
**Bug fixes**
133135

doc/source/groupby.rst

+9-2
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,15 @@ function's name (stored in the function object) will be used.
310310
grouped['D'].agg({'result1' : np.sum,
311311
'result2' : np.mean})
312312
313-
We would like to enable this functionality for DataFrame, too. The result will
314-
likely have a MultiIndex for the columns.
313+
On a grouped DataFrame, you can pass a list of functions to apply to each
314+
column, which produces an aggregated result with a hierarchical index:
315+
316+
.. ipython:: python
317+
318+
grouped.agg([np.sum, np.mean, np.std])
319+
320+
Passing a dict of functions has different behavior by default, see the next
321+
section.
315322

316323
Applying different functions to DataFrame columns
317324
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

pandas/core/groupby.py

+25
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,8 @@ def aggregate(self, arg, *args, **kwargs):
971971
result[col] = colg.agg(func)
972972

973973
result = DataFrame(result)
974+
elif isinstance(arg, list):
975+
return self._aggregate_multiple_funcs(arg)
974976
else:
975977
if len(self.groupings) > 1:
976978
return self._python_agg_general(arg, *args, **kwargs)
@@ -992,6 +994,29 @@ def aggregate(self, arg, *args, **kwargs):
992994

993995
return result
994996

997+
def _aggregate_multiple_funcs(self, arg):
998+
from pandas.tools.merge import concat
999+
1000+
if self.axis != 0:
1001+
raise NotImplementedError
1002+
1003+
obj = self._obj_with_exclusions
1004+
1005+
results = []
1006+
keys = []
1007+
for col in obj:
1008+
try:
1009+
colg = SeriesGroupBy(obj[col], column=col,
1010+
groupings=self.groupings)
1011+
results.append(colg.agg(arg))
1012+
keys.append(col)
1013+
except TypeError:
1014+
pass
1015+
1016+
result = concat(results, keys=keys, axis=1)
1017+
1018+
return result
1019+
9951020
def _aggregate_generic(self, func, *args, **kwargs):
9961021
assert(len(self.groupings) == 1)
9971022

pandas/core/index.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1896,7 +1896,7 @@ def _get_combined_index(indexes, intersect=False):
18961896
index = index.intersection(other)
18971897
return index
18981898
union = _union_indexes(indexes)
1899-
return Index(union)
1899+
return _ensure_index(union)
19001900

19011901
def _get_distinct_indexes(indexes):
19021902
return dict((id(x), x) for x in indexes).values()

pandas/tests/test_groupby.py

+25
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pandas.util.testing import (assert_panel_equal, assert_frame_equal,
1414
assert_series_equal, assert_almost_equal)
1515
from pandas.core.panel import Panel
16+
from pandas.tools.merge import concat
1617
from collections import defaultdict
1718
import pandas._tseries as lib
1819
import pandas.core.datetools as dt
@@ -564,6 +565,30 @@ def test_multi_key_multiple_functions(self):
564565
'std' : grouped.agg(np.std)})
565566
assert_frame_equal(agged, expected)
566567

568+
def test_frame_multi_key_function_list(self):
569+
data = DataFrame({'A' : ['foo', 'foo', 'foo', 'foo',
570+
'bar', 'bar', 'bar', 'bar',
571+
'foo', 'foo', 'foo'],
572+
'B' : ['one', 'one', 'one', 'two',
573+
'one', 'one', 'one', 'two',
574+
'two', 'two', 'one'],
575+
'C' : ['dull', 'dull', 'shiny', 'dull',
576+
'dull', 'shiny', 'shiny', 'dull',
577+
'shiny', 'shiny', 'shiny'],
578+
'D' : np.random.randn(11),
579+
'E' : np.random.randn(11),
580+
'F' : np.random.randn(11)})
581+
582+
grouped = data.groupby(['A', 'B'])
583+
funcs = [np.mean, np.std]
584+
agged = grouped.agg(funcs)
585+
expected = concat([grouped['D'].agg(funcs), grouped['E'].agg(funcs),
586+
grouped['F'].agg(funcs)],
587+
keys=['D', 'E', 'F'], axis=1)
588+
assert(isinstance(agged.index, MultiIndex))
589+
assert(isinstance(expected.index, MultiIndex))
590+
assert_frame_equal(agged, expected)
591+
567592
def test_groupby_multiple_columns(self):
568593
data = self.df
569594
grouped = data.groupby(['A', 'B'])

vb_suite/binary_ops.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,6 @@ def sample(values, k):
2121
ts2 = Series(np.random.randn(sz), idx2)
2222
"""
2323
stmt = "ts1 + ts2"
24-
bm_align1 = Benchmark(stmt, setup,
25-
name="series_align_int64_index",
26-
start_date=datetime(2010, 6, 1), logy=True)
24+
series_align_int64_index = Benchmark(stmt, setup,
25+
start_date=datetime(2010, 6, 1),
26+
logy=True)

vb_suite/frame_ctor.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,11 @@
1717
dict_list = [dict(zip(columns, row)) for row in frame.values]
1818
"""
1919

20-
frame_ctor_nested_dict = \
21-
Benchmark("DataFrame(data)", setup, name='frame_ctor_nested_dict')
20+
frame_ctor_nested_dict = Benchmark("DataFrame(data)", setup)
2221

2322
# From JSON-like stuff
2423

25-
frame_ctor_list_of_dict = \
26-
Benchmark("DataFrame(dict_list)", setup, name='frame_ctor_list_of_dict',
27-
start_date=datetime(2011, 12, 20))
24+
frame_ctor_list_of_dict = Benchmark("DataFrame(dict_list)", setup,
25+
start_date=datetime(2011, 12, 20))
2826

29-
series_ctor_from_dict = \
30-
Benchmark("Series(some_dict)", setup, name='series_ctor_from_dict')
27+
series_ctor_from_dict = Benchmark("Series(some_dict)", setup)

vb_suite/groupby.py

-4
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,17 @@ def f():
3333

3434
stmt1 = "df.groupby(['key1', 'key2'])['data1'].agg(lambda x: x.values.sum())"
3535
groupby_multi_python = Benchmark(stmt1, setup,
36-
name="groupby_multi_python",
3736
start_date=datetime(2011, 7, 1))
3837

3938
stmt3 = "df.groupby(['key1', 'key2']).sum()"
4039
groupby_multi_cython = Benchmark(stmt3, setup,
41-
name="groupby_multi_cython",
4240
start_date=datetime(2011, 7, 1))
4341

4442
stmt = "df.groupby(['key1', 'key2'])['data1'].agg(np.std)"
4543
groupby_multi_series_op = Benchmark(stmt, setup,
46-
name="groupby_multi_series_op",
4744
start_date=datetime(2011, 8, 1))
4845

4946
groupby_series_simple_cython = \
5047
Benchmark('simple_series.groupby(key1).sum()', setup,
51-
name='groupby_series_simple_cython',
5248
start_date=datetime(2011, 3, 1))
5349

vb_suite/suite.py

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
if isinstance(v, Benchmark)]
1717
benchmarks.extend(by_module[modname])
1818

19+
for bm in benchmarks:
20+
assert(bm.name is not None)
21+
1922
REPO_PATH = '/home/wesm/code/pandas'
2023
REPO_URL = '[email protected]:wesm/pandas.git'
2124
DB_PATH = '/home/wesm/code/pandas/vb_suite/benchmarks.db'

0 commit comments

Comments
 (0)