Skip to content

Commit 9fdb8f6

Browse files
PERF: don't sort data twice in groupby apply when not using libreduction fast_apply (#40176)
1 parent 13a97c2 commit 9fdb8f6

File tree

4 files changed

+30
-19
lines changed

4 files changed

+30
-19
lines changed

asv_bench/benchmarks/groupby.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,18 @@ def time_groupby_apply_dict_return(self):
6868

6969

7070
class Apply:
71-
def setup_cache(self):
72-
N = 10 ** 4
73-
labels = np.random.randint(0, 2000, size=N)
71+
72+
param_names = ["factor"]
73+
params = [4, 5]
74+
75+
def setup(self, factor):
76+
N = 10 ** factor
77+
# two cases:
78+
# - small groups: small data (10**4 rows) + many labels (2000) -> average group
79+
# size of 5 (-> larger overhead of slicing method)
80+
# - larger groups: larger data (10**5 rows) + fewer labels (20) -> average group
81+
# size of 5000
82+
labels = np.random.randint(0, 2000 if factor == 4 else 20, size=N)
7483
labels2 = np.random.randint(0, 3, size=N)
7584
df = DataFrame(
7685
{
@@ -80,25 +89,25 @@ def setup_cache(self):
8089
"value2": ["foo", "bar", "baz", "qux"] * (N // 4),
8190
}
8291
)
83-
return df
92+
self.df = df
8493

85-
def time_scalar_function_multi_col(self, df):
86-
df.groupby(["key", "key2"]).apply(lambda x: 1)
94+
def time_scalar_function_multi_col(self, factor):
95+
self.df.groupby(["key", "key2"]).apply(lambda x: 1)
8796

88-
def time_scalar_function_single_col(self, df):
89-
df.groupby("key").apply(lambda x: 1)
97+
def time_scalar_function_single_col(self, factor):
98+
self.df.groupby("key").apply(lambda x: 1)
9099

91100
@staticmethod
92101
def df_copy_function(g):
93102
# ensure that the group name is available (see GH #15062)
94103
g.name
95104
return g.copy()
96105

97-
def time_copy_function_multi_col(self, df):
98-
df.groupby(["key", "key2"]).apply(self.df_copy_function)
106+
def time_copy_function_multi_col(self, factor):
107+
self.df.groupby(["key", "key2"]).apply(self.df_copy_function)
99108

100-
def time_copy_overhead_single_col(self, df):
101-
df.groupby("key").apply(self.df_copy_function)
109+
def time_copy_overhead_single_col(self, factor):
110+
self.df.groupby("key").apply(self.df_copy_function)
102111

103112

104113
class Groups:

doc/source/whatsnew/v1.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ Performance improvements
339339
- Performance improvement in :class:`core.window.rolling.ExpandingGroupby` aggregation methods (:issue:`39664`)
340340
- Performance improvement in :class:`Styler` where render times are more than 50% reduced (:issue:`39972` :issue:`39952`)
341341
- Performance improvement in :meth:`core.window.ewm.ExponentialMovingWindow.mean` with ``times`` (:issue:`39784`)
342+
- Performance improvement in :meth:`.GroupBy.apply` when requiring the python fallback implementation (:issue:`40176`)
342343

343344
.. ---------------------------------------------------------------------------
344345

pandas/core/groupby/ops.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -208,14 +208,13 @@ def apply(self, f: F, data: FrameOrSeries, axis: int = 0):
208208
group_keys = self._get_group_keys()
209209
result_values = None
210210

211-
sdata: FrameOrSeries = splitter._get_sorted_data()
212-
if sdata.ndim == 2 and np.any(sdata.dtypes.apply(is_extension_array_dtype)):
211+
if data.ndim == 2 and np.any(data.dtypes.apply(is_extension_array_dtype)):
213212
# calling splitter.fast_apply will raise TypeError via apply_frame_axis0
214213
# if we pass EA instead of ndarray
215214
# TODO: can we have a workaround for EAs backed by ndarray?
216215
pass
217216

218-
elif isinstance(sdata._mgr, ArrayManager):
217+
elif isinstance(data._mgr, ArrayManager):
219218
# TODO(ArrayManager) don't use fast_apply / libreduction.apply_frame_axis0
220219
# for now -> relies on BlockManager internals
221220
pass
@@ -224,9 +223,10 @@ def apply(self, f: F, data: FrameOrSeries, axis: int = 0):
224223
and isinstance(splitter, FrameSplitter)
225224
and axis == 0
226225
# fast_apply/libreduction doesn't allow non-numpy backed indexes
227-
and not sdata.index._has_complex_internals
226+
and not data.index._has_complex_internals
228227
):
229228
try:
229+
sdata = splitter.sorted_data
230230
result_values, mutated = splitter.fast_apply(f, sdata, group_keys)
231231

232232
except IndexError:
@@ -988,7 +988,7 @@ def sort_idx(self):
988988
return get_group_index_sorter(self.labels, self.ngroups)
989989

990990
def __iter__(self):
991-
sdata = self._get_sorted_data()
991+
sdata = self.sorted_data
992992

993993
if self.ngroups == 0:
994994
# we are inside a generator, rather than raise StopIteration
@@ -1000,7 +1000,8 @@ def __iter__(self):
10001000
for i, (start, end) in enumerate(zip(starts, ends)):
10011001
yield i, self._chop(sdata, slice(start, end))
10021002

1003-
def _get_sorted_data(self) -> FrameOrSeries:
1003+
@cache_readonly
1004+
def sorted_data(self) -> FrameOrSeries:
10041005
return self.data.take(self.sort_idx, axis=self.axis)
10051006

10061007
def _chop(self, sdata, slice_obj: slice) -> NDFrame:

pandas/tests/groupby/test_apply.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def f(g):
113113

114114
splitter = grouper._get_splitter(g._selected_obj, axis=g.axis)
115115
group_keys = grouper._get_group_keys()
116-
sdata = splitter._get_sorted_data()
116+
sdata = splitter.sorted_data
117117

118118
values, mutated = splitter.fast_apply(f, sdata, group_keys)
119119

0 commit comments

Comments
 (0)