Commit 53c6b08

Merge pull request #7421 from jreback/transform_speed
PERF: Series.transform speedups (GH6496)
2 parents: 3a6f34e + dd85fa0

4 files changed (+40 / -16 lines)

doc/source/v0.14.1.txt (+1 / -1)

@@ -137,7 +137,7 @@ Performance
 - Improvements in dtype inference for numeric operations involving yielding performance gains for dtypes: ``int64``, ``timedelta64``, ``datetime64`` (:issue:`7223`)
-
+- Improvements in Series.transform for significant performance gains (:issue:`6496`)
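
For orientation, the entry refers to groupby ``transform`` on a Series: apply a function group by group and get back a result aligned to the original index, where the function may be either a callable or a method name (the ``compat.string_types`` branch visible in the diff below). A small illustrative example, not taken from the PR:

import pandas as pd

s = pd.Series([1.0, 2.0, 3.0, 4.0])
key = ["a", "a", "b", "b"]

# a named method ...
print(s.groupby(key).transform("mean").tolist())                  # [1.5, 1.5, 3.5, 3.5]
# ... or an arbitrary callable; both return a Series aligned with s
print(s.groupby(key).transform(lambda x: x - x.mean()).tolist())  # [-0.5, 0.5, -0.5, 0.5]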

pandas/core/groupby.py (+19 / -15)

@@ -14,7 +14,7 @@
 from pandas.core.categorical import Categorical
 from pandas.core.frame import DataFrame
 from pandas.core.generic import NDFrame
-from pandas.core.index import Index, MultiIndex, _ensure_index
+from pandas.core.index import Index, MultiIndex, _ensure_index, _union_indexes
 from pandas.core.internals import BlockManager, make_block
 from pandas.core.series import Series
 from pandas.core.panel import Panel

@@ -425,7 +425,7 @@ def convert(key, s):
                 return Timestamp(key).asm8
             return key

-        sample = list(self.indices)[0]
+        sample = next(iter(self.indices))
         if isinstance(sample, tuple):
             if not isinstance(name, tuple):
                 raise ValueError("must supply a tuple to get_group with multiple grouping keys")

@@ -2193,33 +2193,37 @@ def transform(self, func, *args, **kwargs):
         -------
         transformed : Series
         """
-        result = self._selected_obj.copy()
-        if hasattr(result, 'values'):
-            result = result.values
-        dtype = result.dtype
+        dtype = self._selected_obj.dtype

         if isinstance(func, compat.string_types):
             wrapper = lambda x: getattr(x, func)(*args, **kwargs)
         else:
             wrapper = lambda x: func(x, *args, **kwargs)

-        for name, group in self:
+        result = self._selected_obj.values.copy()
+        for i, (name, group) in enumerate(self):

             object.__setattr__(group, 'name', name)
             res = wrapper(group)
+
             if hasattr(res, 'values'):
                 res = res.values

-            # need to do a safe put here, as the dtype may be different
-            # this needs to be an ndarray
-            result = Series(result)
-            result.iloc[self._get_index(name)] = res
-            result = result.values
+            # may need to astype
+            try:
+                common_type = np.common_type(np.array(res), result)
+                if common_type != result.dtype:
+                    result = result.astype(common_type)
+            except:
+                pass
+
+            indexer = self._get_index(name)
+            result[indexer] = res

-        # downcast if we can (and need)
         result = _possibly_downcast_to_dtype(result, dtype)
-        return self._selected_obj.__class__(result, index=self._selected_obj.index,
-                                            name=self._selected_obj.name)
+        return self._selected_obj.__class__(result,
+                                            index=self._selected_obj.index,
+                                            name=self._selected_obj.name)

     def filter(self, func, dropna=True, *args, **kwargs):
         """

pandas/tests/test_groupby.py (+2)

@@ -126,8 +126,10 @@ def checkit(dtype):
             assert_series_equal(agged, grouped.mean())
             assert_series_equal(grouped.agg(np.sum), grouped.sum())

+            expected = grouped.apply(lambda x: x * x.sum())
             transformed = grouped.transform(lambda x: x * x.sum())
             self.assertEqual(transformed[7], 12)
+            assert_series_equal(transformed, expected)

             value_grouped = data.groupby(data)
             assert_series_equal(value_grouped.aggregate(np.mean), agged)
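
The new assertion pins down the semantics the rewrite must preserve: ``transform`` broadcasts each group's computation back onto the original index, so it should match the equivalent ``apply``. A tiny example of that property (not the test's fixture data; ``group_keys=False`` is an adaptation for recent pandas so that ``apply`` does not prepend the group labels):

import pandas as pd

s = pd.Series([1.0, 2.0, 3.0, 4.0], index=list("abcd"))
key = ["x", "x", "y", "y"]

transformed = s.groupby(key).transform(lambda x: x * x.sum())
expected = s.groupby(key, group_keys=False).apply(lambda x: x * x.sum())

# each value is scaled by its own group's sum (3.0 for "x", 7.0 for "y")
pd.testing.assert_series_equal(transformed, expected)
print(transformed.tolist())  # [3.0, 6.0, 21.0, 28.0]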

vb_suite/groupby.py (+18)

@@ -376,3 +376,21 @@ def f(g):
 """

 groupby_transform = Benchmark("data.groupby(level='security_id').transform(f_fillna)", setup)
+
+setup = common_setup + """
+np.random.seed(0)
+
+N = 120000
+N_TRANSITIONS = 1400
+
+# generate groups
+transition_points = np.random.permutation(np.arange(N))[:N_TRANSITIONS]
+transition_points.sort()
+transitions = np.zeros((N,), dtype=np.bool)
+transitions[transition_points] = True
+g = transitions.cumsum()
+
+df = DataFrame({ 'signal' : np.random.rand(N)})
+"""
+
+groupby_transform2 = Benchmark("df['signal'].groupby(g).transform(np.mean)", setup)
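
The new benchmark runs under the project's vbench harness, but the workload can be timed on its own. A rough standalone equivalent is sketched below; the ``timeit`` wrapper and repeat count are additions for illustration, and ``np.bool`` is spelled ``bool`` for NumPy versions that have dropped the alias.

import timeit
import numpy as np
import pandas as pd

np.random.seed(0)

N = 120000
N_TRANSITIONS = 1400

# generate ~1400 contiguous groups from sorted transition points
transition_points = np.random.permutation(np.arange(N))[:N_TRANSITIONS]
transition_points.sort()
transitions = np.zeros((N,), dtype=bool)
transitions[transition_points] = True
g = transitions.cumsum()

df = pd.DataFrame({'signal': np.random.rand(N)})

elapsed = timeit.timeit(
    "df['signal'].groupby(g).transform(np.mean)",
    globals={'df': df, 'g': g, 'np': np},
    number=10,
)
print("per call: %.1f ms" % (elapsed / 10 * 1e3))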
