Skip to content

Commit dbfda31

Browse files
lukemanleyyehoshuadimarsky
authored andcommitted
PERF: use groupby.transform fast path for DataFrame -> Series aggregations (pandas-dev#45387)
1 parent a21860e commit dbfda31

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

asv_bench/benchmarks/groupby.py

+9
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,12 @@ def setup(self):
735735
data = DataFrame(arr, index=index, columns=["col1", "col20", "col3"])
736736
self.df = data
737737

738+
n = 1000
739+
self.df_wide = DataFrame(
740+
np.random.randn(n, n),
741+
index=np.random.choice(range(10), n),
742+
)
743+
738744
n = 20000
739745
self.df1 = DataFrame(
740746
np.random.randint(1, n, (n, 3)), columns=["jim", "joe", "jolie"]
@@ -754,6 +760,9 @@ def time_transform_lambda_max(self):
754760
def time_transform_ufunc_max(self):
755761
self.df.groupby(level="lev1").transform(np.max)
756762

763+
def time_transform_lambda_max_wide(self):
764+
self.df_wide.groupby(level=0).transform(lambda x: np.max(x, axis=0))
765+
757766
def time_transform_multi_key1(self):
758767
self.df1.groupby(["jim", "joe"])["jolie"].transform("max")
759768

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ Other Deprecations
157157

158158
Performance improvements
159159
~~~~~~~~~~~~~~~~~~~~~~~~
160+
- Performance improvement in :meth:`.GroupBy.transform` for some user-defined DataFrame -> Series functions (:issue:`45387`)
160161
- Performance improvement in :meth:`DataFrame.duplicated` when subset consists of only one column (:issue:`45236`)
161162
-
162163

pandas/core/groupby/generic.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -1217,12 +1217,16 @@ def _choose_path(self, fast_path: Callable, slow_path: Callable, group: DataFram
12171217
# raised; see test_transform.test_transform_fastpath_raises
12181218
return path, res
12191219

1220-
# verify fast path does not change columns (and names), otherwise
1221-
# its results cannot be joined with those of the slow path
1222-
if not isinstance(res_fast, DataFrame):
1223-
return path, res
1224-
1225-
if not res_fast.columns.equals(group.columns):
1220+
# verify fast path returns either:
1221+
# a DataFrame with columns equal to group.columns
1222+
# OR a Series with index equal to group.columns
1223+
if isinstance(res_fast, DataFrame):
1224+
if not res_fast.columns.equals(group.columns):
1225+
return path, res
1226+
elif isinstance(res_fast, Series):
1227+
if not res_fast.index.equals(group.columns):
1228+
return path, res
1229+
else:
12261230
return path, res
12271231

12281232
if res_fast.equals(res):

0 commit comments

Comments
 (0)