Skip to content

Commit 324bb84

Browse files
fjetterWillAyd
authored andcommitted
ENH: Only apply first group once in fast GroupBy.apply (#24748)
1 parent 51c6a05 commit 324bb84

File tree

6 files changed

+159
-47
lines changed

6 files changed

+159
-47
lines changed

doc/source/user_guide/groupby.rst

-17
Original file line numberDiff line numberDiff line change
@@ -946,23 +946,6 @@ that is itself a series, and possibly upcast the result to a DataFrame:
946946
So depending on the path taken, and exactly what you are grouping. Thus the grouped columns(s) may be included in
947947
the output as well as set the indices.
948948

949-
.. warning::
950-
951-
In the current implementation apply calls func twice on the
952-
first group to decide whether it can take a fast or slow code
953-
path. This can lead to unexpected behavior if func has
954-
side-effects, as they will take effect twice for the first
955-
group.
956-
957-
.. ipython:: python
958-
959-
d = pd.DataFrame({"a": ["x", "y"], "b": [1, 2]})
960-
def identity(df):
961-
print(df)
962-
return df
963-
964-
d.groupby("a").apply(identity)
965-
966949

967950
Other useful features
968951
---------------------

doc/source/whatsnew/v0.25.0.rst

+46-2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,50 @@ is respected in indexing. (:issue:`24076`, :issue:`16785`)
7373
df = pd.DataFrame([0], index=pd.DatetimeIndex(['2019-01-01'], tz='US/Pacific'))
7474
df['2019-01-01 12:00:00+04:00':'2019-01-01 13:00:00+04:00']
7575

76+
.. _whatsnew_0250.api_breaking.groupby_apply_first_group_once:
77+
78+
GroupBy.apply on ``DataFrame`` evaluates first group only once
79+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
80+
81+
The implementation of :meth:`DataFrameGroupBy.apply() <pandas.core.groupby.DataFrameGroupBy.apply>`
82+
previously evaluated the supplied function consistently twice on the first group
83+
to infer if it is safe to use a fast code path. Particularly for functions with
84+
side effects, this was an undesired behavior and may have led to surprises.
85+
86+
(:issue:`2936`, :issue:`2656`, :issue:`7739`, :issue:`10519`, :issue:`12155`,
87+
:issue:`20084`, :issue:`21417`)
88+
89+
Now every group is evaluated only a single time.
90+
91+
.. ipython:: python
92+
93+
df = pd.DataFrame({"a": ["x", "y"], "b": [1, 2]})
94+
df
95+
96+
def func(group):
97+
print(group.name)
98+
return group
99+
100+
*Previous Behaviour*:
101+
102+
.. code-block:: python
103+
104+
In [3]: df.groupby('a').apply(func)
105+
x
106+
x
107+
y
108+
Out[3]:
109+
a b
110+
0 x 1
111+
1 y 2
112+
113+
*New Behaviour*:
114+
115+
.. ipython:: python
116+
117+
df.groupby("a").apply(func)
118+
119+
76120
Concatenating Sparse Values
77121
^^^^^^^^^^^^^^^^^^^^^^^^^^^
78122

@@ -83,14 +127,14 @@ Series or DataFrame with sparse values, rather than a ``SparseDataFrame`` (:issu
83127
84128
df = pd.DataFrame({"A": pd.SparseArray([0, 1])})
85129
86-
*Previous Behavior:*
130+
*Previous Behavior*:
87131

88132
.. code-block:: ipython
89133
90134
In [2]: type(pd.concat([df, df]))
91135
pandas.core.sparse.frame.SparseDataFrame
92136
93-
*New Behavior:*
137+
*New Behavior*:
94138

95139
.. ipython:: python
96140

pandas/_libs/reduction.pyx

+15-15
Original file line numberDiff line numberDiff line change
@@ -509,17 +509,6 @@ def apply_frame_axis0(object frame, object f, object names,
509509

510510
results = []
511511

512-
# Need to infer if our low-level mucking is going to cause a segfault
513-
if n > 0:
514-
chunk = frame.iloc[starts[0]:ends[0]]
515-
object.__setattr__(chunk, 'name', names[0])
516-
try:
517-
result = f(chunk)
518-
if result is chunk:
519-
raise InvalidApply('Function unsafe for fast apply')
520-
except:
521-
raise InvalidApply('Let this error raise above us')
522-
523512
slider = BlockSlider(frame)
524513

525514
mutated = False
@@ -529,20 +518,31 @@ def apply_frame_axis0(object frame, object f, object names,
529518
slider.move(starts[i], ends[i])
530519

531520
item_cache.clear() # ugh
521+
chunk = slider.dummy
522+
object.__setattr__(chunk, 'name', names[i])
532523

533-
object.__setattr__(slider.dummy, 'name', names[i])
534-
piece = f(slider.dummy)
524+
try:
525+
piece = f(chunk)
526+
except:
527+
raise InvalidApply('Let this error raise above us')
535528

536-
# I'm paying the price for index-sharing, ugh
529+
# Need to infer if low level index slider will cause segfaults
530+
require_slow_apply = i == 0 and piece is chunk
537531
try:
538-
if piece.index is slider.dummy.index:
532+
if piece.index is chunk.index:
539533
piece = piece.copy(deep='all')
540534
else:
541535
mutated = True
542536
except AttributeError:
543537
pass
544538

545539
results.append(piece)
540+
541+
# If the data was modified inplace we need to
542+
# take the slow path to not risk segfaults
543+
# we have already computed the first piece
544+
if require_slow_apply:
545+
break
546546
finally:
547547
slider.reset()
548548

pandas/core/groupby/ops.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -165,26 +165,45 @@ def apply(self, f, data, axis=0):
165165
mutated = self.mutated
166166
splitter = self._get_splitter(data, axis=axis)
167167
group_keys = self._get_group_keys()
168+
result_values = None
168169

169170
# oh boy
170171
f_name = com.get_callable_name(f)
171172
if (f_name not in base.plotting_methods and
172173
hasattr(splitter, 'fast_apply') and axis == 0):
173174
try:
174-
values, mutated = splitter.fast_apply(f, group_keys)
175-
return group_keys, values, mutated
175+
result_values, mutated = splitter.fast_apply(f, group_keys)
176+
177+
# If the fast apply path could be used we can return here.
178+
# Otherwise we need to fall back to the slow implementation.
179+
if len(result_values) == len(group_keys):
180+
return group_keys, result_values, mutated
181+
176182
except reduction.InvalidApply:
177-
# we detect a mutation of some kind
178-
# so take slow path
183+
# Cannot fast apply on MultiIndex (_has_complex_internals).
184+
# This Exception is also raised if `f` triggers an exception
185+
# but it is preferable to raise the exception in Python.
179186
pass
180187
except Exception:
181188
# raise this error to the caller
182189
pass
183190

184-
result_values = []
185191
for key, (i, group) in zip(group_keys, splitter):
186192
object.__setattr__(group, 'name', key)
187193

194+
# result_values is None if fast apply path wasn't taken
195+
# or fast apply aborted with an unexpected exception.
196+
# In either case, initialize the result list and perform
197+
# the slow iteration.
198+
if result_values is None:
199+
result_values = []
200+
201+
# If result_values is not None we're in the case that the
202+
# fast apply loop was broken prematurely but we have
203+
# already the result for the first group which we can reuse.
204+
elif i == 0:
205+
continue
206+
188207
# group might be modified
189208
group_axes = _get_axes(group)
190209
res = f(group)
@@ -854,10 +873,7 @@ def fast_apply(self, f, names):
854873
return [], True
855874

856875
sdata = self._get_sorted_data()
857-
results, mutated = reduction.apply_frame_axis0(sdata, f, names,
858-
starts, ends)
859-
860-
return results, mutated
876+
return reduction.apply_frame_axis0(sdata, f, names, starts, ends)
861877

862878
def _chop(self, sdata, slice_obj):
863879
if self.axis == 0:

pandas/tests/groupby/test_apply.py

+71
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,80 @@ def f(g):
102102
group_keys = grouper._get_group_keys()
103103

104104
values, mutated = splitter.fast_apply(f, group_keys)
105+
105106
assert not mutated
106107

107108

109+
@pytest.mark.parametrize(
110+
"df, group_names",
111+
[
112+
(DataFrame({"a": [1, 1, 1, 2, 3],
113+
"b": ["a", "a", "a", "b", "c"]}),
114+
[1, 2, 3]),
115+
(DataFrame({"a": [0, 0, 1, 1],
116+
"b": [0, 1, 0, 1]}),
117+
[0, 1]),
118+
(DataFrame({"a": [1]}),
119+
[1]),
120+
(DataFrame({"a": [1, 1, 1, 2, 2, 1, 1, 2],
121+
"b": range(8)}),
122+
[1, 2]),
123+
(DataFrame({"a": [1, 2, 3, 1, 2, 3],
124+
"two": [4, 5, 6, 7, 8, 9]}),
125+
[1, 2, 3]),
126+
(DataFrame({"a": list("aaabbbcccc"),
127+
"B": [3, 4, 3, 6, 5, 2, 1, 9, 5, 4],
128+
"C": [4, 0, 2, 2, 2, 7, 8, 6, 2, 8]}),
129+
["a", "b", "c"]),
130+
(DataFrame([[1, 2, 3], [2, 2, 3]], columns=["a", "b", "c"]),
131+
[1, 2]),
132+
], ids=['GH2936', 'GH7739 & GH10519', 'GH10519',
133+
'GH2656', 'GH12155', 'GH20084', 'GH21417'])
134+
def test_group_apply_once_per_group(df, group_names):
135+
# GH2936, GH7739, GH10519, GH2656, GH12155, GH20084, GH21417
136+
137+
# This test should ensure that a function is only evaluted
138+
# once per group. Previously the function has been evaluated twice
139+
# on the first group to check if the Cython index slider is safe to use
140+
# This test ensures that the side effect (append to list) is only triggered
141+
# once per group
142+
143+
names = []
144+
# cannot parameterize over the functions since they need external
145+
# `names` to detect side effects
146+
147+
def f_copy(group):
148+
# this takes the fast apply path
149+
names.append(group.name)
150+
return group.copy()
151+
152+
def f_nocopy(group):
153+
# this takes the slow apply path
154+
names.append(group.name)
155+
return group
156+
157+
def f_scalar(group):
158+
# GH7739, GH2656
159+
names.append(group.name)
160+
return 0
161+
162+
def f_none(group):
163+
# GH10519, GH12155, GH21417
164+
names.append(group.name)
165+
return None
166+
167+
def f_constant_df(group):
168+
# GH2936, GH20084
169+
names.append(group.name)
170+
return DataFrame({"a": [1], "b": [1]})
171+
172+
for func in [f_copy, f_nocopy, f_scalar, f_none, f_constant_df]:
173+
del names[:]
174+
175+
df.groupby("a").apply(func)
176+
assert names == group_names
177+
178+
108179
def test_apply_with_mixed_dtype():
109180
# GH3480, apply with mixed dtype on axis=1 breaks in 0.11
110181
df = DataFrame({'foo1': np.random.randn(6),

pandas/tests/groupby/test_groupby.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1381,11 +1381,9 @@ def test_group_name_available_in_inference_pass():
13811381
def f(group):
13821382
names.append(group.name)
13831383
return group.copy()
1384-
13851384
df.groupby('a', sort=False, group_keys=False).apply(f)
1386-
# we expect 2 zeros because we call ``f`` once to see if a faster route
1387-
# can be used.
1388-
expected_names = [0, 0, 1, 2]
1385+
1386+
expected_names = [0, 1, 2]
13891387
assert names == expected_names
13901388

13911389

0 commit comments

Comments
 (0)