
Commit 7898ec2

PERF: groupby transform
1 parent e6140e9 commit 7898ec2

File tree

2 files changed, +107 -13 lines changed


bench/bench_transform.py

+66
@@ -0,0 +1,66 @@
+import numpy as np
+import pandas as pd
+from pandas import Index, MultiIndex, DataFrame
+from pandas.core.groupby import SeriesGroupBy, DataFrameGroupBy
+
+def apply_by_group(grouped, f):
+    """
+    Applies a function to each Series or DataFrame in a GroupBy object, concatenates the results
+    and returns the resulting Series or DataFrame.
+
+    Parameters
+    ----------
+    grouped: SeriesGroupBy or DataFrameGroupBy
+    f: callable
+        Function to apply to each Series or DataFrame in the grouped object.
+
+    Returns
+    -------
+    Series or DataFrame that results from applying the function to each Series or DataFrame in the
+    GroupBy object and concatenating the results.
+
+    """
+    assert isinstance(grouped, (SeriesGroupBy, DataFrameGroupBy))
+    assert hasattr(f, '__call__')
+
+    groups = []
+    for key, group in grouped:
+        groups.append(f(group))
+    c = pd.concat(groups)
+    c.sort_index(inplace=True)
+    return c
+
+n_dates = 1000
+n_securities = 2000
+n_columns = 3
+share_na = 0.1
+
+dates = pd.date_range('1997-12-31', periods=n_dates, freq='B')
+dates = Index(map(lambda x: x.year * 10000 + x.month * 100 + x.day, dates))
+
+secid_min = int('10000000', 16)
+secid_max = int('F0000000', 16)
+step = (secid_max - secid_min) // (n_securities - 1)
+security_ids = map(lambda x: hex(x)[2:10].upper(), range(secid_min, secid_max + 1, step))
+
+data_index = MultiIndex(levels=[dates.values, security_ids],
+                        labels=[[i for i in xrange(n_dates) for _ in xrange(n_securities)], range(n_securities) * n_dates],
+                        names=['date', 'security_id'])
+n_data = len(data_index)
+
+columns = Index(['factor{}'.format(i) for i in xrange(1, n_columns + 1)])
+
+data = DataFrame(np.random.randn(n_data, n_columns), index=data_index, columns=columns)
+
+step = int(n_data * share_na)
+for column_index in xrange(n_columns):
+    index = column_index
+    while index < n_data:
+        data.set_value(data_index[index], columns[column_index], np.nan)
+        index += step
+
+grouped = data.groupby(level='security_id')
+f_fillna = lambda x: x.fillna(method='pad')
+
+#%timeit grouped.transform(f_fillna)
+#%timeit apply_by_group(grouped, f_fillna)
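
The two commented-out %timeit lines at the bottom are meant to be run from an IPython session. As a rough script-friendly equivalent, a minimal sketch using only the standard timeit module and the names defined above might look like this:

    import timeit

    # time the built-in GroupBy.transform path, which exercises the new fast/slow-path logic
    t_transform = timeit.timeit(lambda: grouped.transform(f_fillna), number=1)

    # time the manual per-group apply-and-concat baseline defined in apply_by_group above
    t_apply = timeit.timeit(lambda: apply_by_group(grouped, f_fillna), number=1)

    print('transform: {:.3f}s  apply_by_group: {:.3f}s'.format(t_transform, t_apply))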

pandas/core/groupby.py

+41 -13
@@ -13,7 +13,7 @@
 from pandas.util.compat import OrderedDict
 import pandas.core.algorithms as algos
 import pandas.core.common as com
-from pandas.core.common import _possibly_downcast_to_dtype
+from pandas.core.common import _possibly_downcast_to_dtype, notnull

 import pandas.lib as lib
 import pandas.algos as _algos
@@ -75,7 +75,7 @@ def f(self):
 def _first_compat(x, axis=0):
     def _first(x):
         x = np.asarray(x)
-        x = x[com.notnull(x)]
+        x = x[notnull(x)]
         if len(x) == 0:
             return np.nan
         return x[0]
@@ -89,7 +89,7 @@ def _first(x):
 def _last_compat(x, axis=0):
     def _last(x):
         x = np.asarray(x)
-        x = x[com.notnull(x)]
+        x = x[notnull(x)]
         if len(x) == 0:
             return np.nan
         return x[-1]
@@ -421,7 +421,7 @@ def ohlc(self):

     def nth(self, n):
         def picker(arr):
-            arr = arr[com.notnull(arr)]
+            arr = arr[notnull(arr)]
             if len(arr) >= n + 1:
                 return arr.iget(n)
             else:
@@ -1897,19 +1897,46 @@ def transform(self, func, *args, **kwargs):
         gen = self.grouper.get_iterator(obj, axis=self.axis)

         if isinstance(func, basestring):
-            wrapper = lambda x: getattr(x, func)(*args, **kwargs)
+            fast_path = lambda group: getattr(group, func)(*args, **kwargs)
+            slow_path = lambda group: group.apply(lambda x: getattr(x, func)(*args, **kwargs), axis=self.axis)
         else:
-            wrapper = lambda x: func(x, *args, **kwargs)
+            fast_path = lambda group: func(group, *args, **kwargs)
+            slow_path = lambda group: group.apply(lambda x: func(x, *args, **kwargs), axis=self.axis)

+        path = None
         for name, group in gen:
             object.__setattr__(group, 'name', name)

-            try:
-                res = group.apply(wrapper, axis=self.axis)
-            except TypeError:
-                return self._transform_item_by_item(obj, wrapper)
-            except Exception:  # pragma: no cover
-                res = wrapper(group)
+            # decide on a fast path
+            if path is None:
+
+                path = slow_path
+                try:
+                    res = slow_path(group)
+
+                    # if we make it here, test if we can use the fast path
+                    try:
+                        res_fast = fast_path(group)
+
+                        # compare that we get the same results
+                        if res.shape == res_fast.shape:
+                            res_r = res.values.ravel()
+                            res_fast_r = res_fast.values.ravel()
+                            mask = notnull(res_r)
+                            if (res_r[mask] == res_fast_r[mask]).all():
+                                path = fast_path
+
+                    except:
+                        pass
+                except TypeError:
+                    return self._transform_item_by_item(obj, fast_path)
+                except Exception:  # pragma: no cover
+                    res = fast_path(group)
+                    path = fast_path
+
+            else:
+
+                res = path(group)

             # broadcasting
             if isinstance(res, Series):
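
In words: transform() now builds two candidate callables per group, a fast path that applies the function to the whole group at once and a slow path that goes through group.apply. On the first group it runs both, and it only switches to the fast path for the remaining groups if the two results have the same shape and agree on every non-null value; if the fast path raises, the slow path is kept. A simplified, standalone sketch of that decision (illustrative names only, not the pandas internals themselves) could look like:

    import numpy as np
    import pandas as pd

    def choose_path(sample_group, fast_path, slow_path):
        # the slow-path result is always computed and treated as the reference answer
        res = slow_path(sample_group)
        try:
            res_fast = fast_path(sample_group)
        except Exception:
            return slow_path, res  # fast path does not even run on this data
        if res.shape == res_fast.shape:
            res_r = np.asarray(res).ravel()
            fast_r = np.asarray(res_fast).ravel()
            mask = pd.notnull(res_r)
            # switch to the fast path only if both agree on all non-null values
            if (res_r[mask] == fast_r[mask]).all():
                return fast_path, res
        return slow_path, res

The chosen callable is then reused for every subsequent group, so the overhead of the apply-based slow path is only paid beyond the first group when the fast path turns out not to be valid.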
@@ -1925,7 +1952,8 @@ def transform(self, func, *args, **kwargs):
         concat_index = obj.columns if self.axis == 0 else obj.index
         concatenated = concat(applied, join_axes=[concat_index],
                               axis=self.axis, verify_integrity=False)
-        return concatenated.reindex_like(obj)
+        concatenated.sort_index(inplace=True)
+        return concatenated

     def _transform_item_by_item(self, obj, wrapper):
         # iterate through columns
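
The last hunk also changes how the concatenated result is returned: instead of reindexing it back to the original frame's row order with reindex_like, the result is now sorted by index, matching what the benchmark's apply_by_group helper does. The call sites themselves are unchanged; with the benchmark's data, the two invocation forms handled by the string/callable branches above would be (a small illustration using the fillna example, not new API):

    # callable form: candidate fast path is f(group), slow path is group.apply(f)
    padded = grouped.transform(lambda x: x.fillna(method='pad'))

    # string form: candidate fast path is getattr(group, 'fillna')(method='pad')
    padded_via_name = grouped.transform('fillna', method='pad')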
