Skip to content

Commit b66a1c8

Browse files
committed
PERF: DataFrame transform
1 parent 2de2884 commit b66a1c8

File tree

4 files changed

+57
-27
lines changed

4 files changed

+57
-27
lines changed

asv_bench/benchmarks/groupby.py

+15
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,21 @@ def setup(self):
773773
def time_groupby_transform_series2(self):
774774
self.df.groupby('id')['val'].transform(np.mean)
775775

776+
777+
class groupby_transform_dataframe(object):
778+
# GH 12737
779+
goal_time = 0.2
780+
781+
def setup(self):
782+
self.df = pd.DataFrame({'group': np.repeat(np.arange(1000), 10),
783+
'B': np.nan,
784+
'C': np.nan})
785+
self.df.ix[4::10, 'B':'C'] = 5
786+
787+
def time_groupby_transform_dataframe(self):
788+
self.df.groupby('group').transform('first')
789+
790+
776791
class groupby_transform_cythonized(object):
777792
goal_time = 0.2
778793

doc/source/whatsnew/v0.18.2.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ Performance Improvements
104104
- increased performance of ``DataFrame.quantile()`` as it now operates per-block (:issue:`11623`)
105105

106106

107-
107+
- Improved performance of ``DataFrameGroupBy.transform`` (:issue:`12737`)
108108

109109

110110
.. _whatsnew_0182.bug_fixes:
@@ -123,7 +123,7 @@ Bug Fixes
123123

124124
- Regression in ``Series.quantile`` with nans (also shows up in ``.median()`` and ``.describe()``); furthermore now names the ``Series`` with the quantile (:issue:`13098`, :issue:`13146`)
125125

126-
126+
- Bug in ``SeriesGroupBy.transform`` with datetime values and missing groups (:issue:`13191`)
127127

128128
- Bug in ``Series.str.extractall()`` with ``str`` index raises ``ValueError`` (:issue:`13156`)
129129

pandas/core/groupby.py

+8-23
Original file line numberDiff line numberDiff line change
@@ -2776,18 +2776,9 @@ def _transform_fast(self, func):
27762776
func = getattr(self, func)
27772777

27782778
ids, _, ngroup = self.grouper.group_info
2779-
mask = ids != -1
2780-
2781-
out = func().values[ids]
2782-
if not mask.all():
2783-
out = np.where(mask, out, np.nan)
2784-
2785-
obs = np.zeros(ngroup, dtype='bool')
2786-
obs[ids[mask]] = True
2787-
if not obs.all():
2788-
out = self._try_cast(out, self._selected_obj)
27892779

2790-
return Series(out, index=self.obj.index)
2780+
out = algos.take_1d(func().values, ids)
2781+
return Series(out, index=self.obj.index, name=self.obj.name)
27912782

27922783
def filter(self, func, dropna=True, *args, **kwargs): # noqa
27932784
"""
@@ -3465,19 +3456,13 @@ def transform(self, func, *args, **kwargs):
34653456
if not result.columns.equals(obj.columns):
34663457
return self._transform_general(func, *args, **kwargs)
34673458

3468-
results = np.empty_like(obj.values, result.values.dtype)
3469-
for (name, group), (i, row) in zip(self, result.iterrows()):
3470-
indexer = self._get_index(name)
3471-
if len(indexer) > 0:
3472-
results[indexer] = np.tile(row.values, len(
3473-
indexer)).reshape(len(indexer), -1)
3474-
3475-
counts = self.size().fillna(0).values
3476-
if any(counts == 0):
3477-
results = self._try_cast(results, obj[result.columns])
3459+
# Fast transform
3460+
ids, _, ngroup = self.grouper.group_info
3461+
out = {}
3462+
for col in result:
3463+
out[col] = algos.take_nd(result[col].values, ids)
34783464

3479-
return (DataFrame(results, columns=result.columns, index=obj.index)
3480-
._convert(datetime=True))
3465+
return DataFrame(out, columns=result.columns, index=obj.index)
34813466

34823467
def _define_paths(self, func, *args, **kwargs):
34833468
if isinstance(func, compat.string_types):

pandas/tests/test_groupby.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -1051,13 +1051,33 @@ def test_transform_fast(self):
10511051

10521052
values = np.repeat(grp.mean().values,
10531053
com._ensure_platform_int(grp.count().values))
1054-
expected = pd.Series(values, index=df.index)
1054+
expected = pd.Series(values, index=df.index, name='val')
10551055
result = grp.transform(np.mean)
10561056
assert_series_equal(result, expected)
10571057

10581058
result = grp.transform('mean')
10591059
assert_series_equal(result, expected)
10601060

1061+
# GH 12737
1062+
df = pd.DataFrame({'grouping': [0, 1, 1, 3], 'f': [1.1, 2.1, 3.1, 4.5],
1063+
'd': pd.date_range('2014-1-1', '2014-1-4'),
1064+
'i': [1, 2, 3, 4]},
1065+
columns=['grouping', 'f', 'i', 'd'])
1066+
result = df.groupby('grouping').transform('first')
1067+
1068+
dates = [pd.Timestamp('2014-1-1'), pd.Timestamp('2014-1-2'),
1069+
pd.Timestamp('2014-1-2'), pd.Timestamp('2014-1-4')]
1070+
expected = pd.DataFrame({'f': [1.1, 2.1, 2.1, 4.5],
1071+
'd': dates,
1072+
'i': [1, 2, 2, 4]},
1073+
columns=['f', 'i', 'd'])
1074+
assert_frame_equal(result, expected)
1075+
1076+
# selection
1077+
result = df.groupby('grouping')[['f', 'i']].transform('first')
1078+
expected = expected[['f', 'i']]
1079+
assert_frame_equal(result, expected)
1080+
10611081
def test_transform_broadcast(self):
10621082
grouped = self.ts.groupby(lambda x: x.month)
10631083
result = grouped.transform(np.mean)
@@ -1191,6 +1211,16 @@ def test_transform_function_aliases(self):
11911211
expected = self.df.groupby('A')['C'].transform(np.mean)
11921212
assert_series_equal(result, expected)
11931213

1214+
def test_series_fast_transform_date(self):
1215+
# GH 13191
1216+
df = pd.DataFrame({'grouping': [np.nan, 1, 1, 3],
1217+
'd': pd.date_range('2014-1-1', '2014-1-4')})
1218+
result = df.groupby('grouping')['d'].transform('first')
1219+
dates = [pd.NaT, pd.Timestamp('2014-1-2'), pd.Timestamp('2014-1-2'),
1220+
pd.Timestamp('2014-1-4')]
1221+
expected = pd.Series(dates, name='d')
1222+
assert_series_equal(result, expected)
1223+
11941224
def test_transform_length(self):
11951225
# GH 9697
11961226
df = pd.DataFrame({'col1': [1, 1, 2, 2], 'col2': [1, 2, 3, np.nan]})
@@ -4406,7 +4436,7 @@ def test_groupby_datetime64_32_bit(self):
44064436

44074437
df = DataFrame({"A": range(2), "B": [pd.Timestamp('2000-01-1')] * 2})
44084438
result = df.groupby("A")["B"].transform(min)
4409-
expected = Series([pd.Timestamp('2000-01-1')] * 2)
4439+
expected = Series([pd.Timestamp('2000-01-1')] * 2, name='B')
44104440
assert_series_equal(result, expected)
44114441

44124442
def test_groupby_categorical_unequal_len(self):

0 commit comments

Comments
 (0)