Skip to content

Commit 2f48549

Browse files
committed
fixup slow transforms
1 parent cc43503 commit 2f48549

File tree

2 files changed

+18
-24
lines changed

2 files changed

+18
-24
lines changed

pandas/core/groupby.py

+17-24
Original file line numberDiff line numberDiff line change
@@ -2890,39 +2890,32 @@ def transform(self, func, *args, **kwargs):
28902890
lambda: getattr(self, func)(*args, **kwargs))
28912891

28922892
# reg transform
2893-
dtype = self._selected_obj.dtype
2894-
result = self._selected_obj.values.copy()
2895-
2893+
klass = self._selected_obj.__class__
2894+
results = []
28962895
wrapper = lambda x: func(x, *args, **kwargs)
2897-
for i, (name, group) in enumerate(self):
2896+
for name, group in self:
28982897
object.__setattr__(group, 'name', name)
28992898
res = wrapper(group)
29002899

29012900
if hasattr(res, 'values'):
29022901
res = res.values
29032902

2904-
# may need to astype
2905-
try:
2906-
common_type = np.common_type(np.array(res), result)
2907-
if common_type != result.dtype:
2908-
result = result.astype(common_type)
2909-
except Exception as exc:
2910-
# date math can cause type of result to change
2911-
if i == 0 and (is_datetime64_dtype(result.dtype) or
2912-
is_timedelta64_dtype(result.dtype)):
2913-
try:
2914-
dtype = res.dtype
2915-
except Exception as exc:
2916-
dtype = type(res)
2917-
result = np.empty_like(result, dtype)
2918-
29192903
indexer = self._get_index(name)
2920-
result[indexer] = res
2904+
s = klass(res, indexer)
2905+
results.append(s)
29212906

2922-
result = _possibly_downcast_to_dtype(result, dtype)
2923-
return self._selected_obj.__class__(result,
2924-
index=self._selected_obj.index,
2925-
name=self._selected_obj.name)
2907+
from pandas.tools.concat import concat
2908+
result = concat(results).sort_index()
2909+
2910+
# we will only try to coerce the result type if
2911+
# we have a numeric dtype
2912+
dtype = self._selected_obj.dtype
2913+
if is_numeric_dtype(dtype):
2914+
result = _possibly_downcast_to_dtype(result, dtype)
2915+
2916+
result.name = self._selected_obj.name
2917+
result.index = self._selected_obj.index
2918+
return result
29262919

29272920
def _transform_fast(self, func):
29282921
"""

pandas/tests/groupby/test_filters.py

+1
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def test_filter_against_workaround(self):
216216
grouper = s.apply(lambda x: np.round(x, -1))
217217
grouped = s.groupby(grouper)
218218
f = lambda x: x.mean() > 10
219+
219220
old_way = s[grouped.transform(f).astype('bool')]
220221
new_way = grouped.filter(f)
221222
assert_series_equal(new_way.sort_values(), old_way.sort_values())

0 commit comments

Comments
 (0)