Skip to content

Commit ab2c65d

Browse files
committed
ENH: Only apply first group once in fast GroupBy.apply
1 parent 2626215 commit ab2c65d

File tree

6 files changed

+106
-48
lines changed

6 files changed

+106
-48
lines changed

doc/source/user_guide/groupby.rst

-17
Original file line numberDiff line numberDiff line change
@@ -946,23 +946,6 @@ that is itself a series, and possibly upcast the result to a DataFrame:
946946
So depending on the path taken, and exactly what you are grouping, the grouped column(s) may be included in
947947
the output as well as set the indices.
948948

949-
.. warning::
950-
951-
In the current implementation apply calls func twice on the
952-
first group to decide whether it can take a fast or slow code
953-
path. This can lead to unexpected behavior if func has
954-
side-effects, as they will take effect twice for the first
955-
group.
956-
957-
.. ipython:: python
958-
959-
d = pd.DataFrame({"a": ["x", "y"], "b": [1, 2]})
960-
def identity(df):
961-
print(df)
962-
return df
963-
964-
d.groupby("a").apply(identity)
965-
966949

967950
Other useful features
968951
---------------------

doc/source/whatsnew/v0.25.0.rst

+45
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,51 @@ Other Enhancements
2626
Backwards incompatible API changes
2727
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2828

29+
GroupBy.apply on ``DataFrame`` evaluates first group only once
30+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
31+
32+
(:issue:`2936`, :issue:`2656`, :issue:`7739`, :issue:`10519`, :issue:`12155`,
33+
:issue:`20084`, :issue:`21417`)
34+
35+
The implementation of ``DataFrame.groupby.apply`` previously evaluated ``func``
36+
twice on the first group to infer whether it is safe to use a fast
37+
code path. Particularly for functions with side effects, this was an undesired
38+
behavior and may have led to surprises.
39+
40+
Now every group is evaluated only a single time.
41+
42+
Previous behavior:
43+
44+
.. code-block:: ipython
45+
46+
In [2]: df = pd.DataFrame({"a": ["x", "y"], "b": [1, 2]})
47+
48+
In [3]: side_effects = []
49+
50+
In [4]: def func_fast_apply(group):
51+
...: side_effects.append(group.name)
52+
...: return len(group)
53+
...:
54+
55+
In [5]: df.groupby("a").apply(func_fast_apply)
56+
57+
In [6]: assert side_effects == ["x", "x", "y"]
58+
59+
New behavior:
60+
61+
.. ipython:: python
62+
63+
df = pd.DataFrame({"a": ["x", "y"], "b": [1, 2]})
64+
65+
side_effects = []
66+
def func(group):
67+
side_effects.append(group.name)
68+
return group
69+
70+
df.groupby("a").apply(func)
71+
assert side_effects == ["x", "y"]
72+
73+
2974
.. _whatsnew_0250.api.other:
3075

3176
Other API Changes

pandas/_libs/reduction.pyx

+14-17
Original file line numberDiff line numberDiff line change
@@ -507,44 +507,41 @@ def apply_frame_axis0(object frame, object f, object names,
507507

508508
results = []
509509

510-
# Need to infer if our low-level mucking is going to cause a segfault
511-
if n > 0:
512-
chunk = frame.iloc[starts[0]:ends[0]]
513-
object.__setattr__(chunk, 'name', names[0])
514-
try:
515-
result = f(chunk)
516-
if result is chunk:
517-
raise InvalidApply('Function unsafe for fast apply')
518-
except:
519-
raise InvalidApply('Let this error raise above us')
520-
521510
slider = BlockSlider(frame)
522511

523512
mutated = False
513+
status = 0
524514
item_cache = slider.dummy._item_cache
525515
try:
526516
for i in range(n):
527517
slider.move(starts[i], ends[i])
528518

529519
item_cache.clear() # ugh
520+
chunk = slider.dummy
521+
object.__setattr__(chunk, 'name', names[i])
530522

531-
object.__setattr__(slider.dummy, 'name', names[i])
532-
piece = f(slider.dummy)
533-
534-
# I'm paying the price for index-sharing, ugh
535523
try:
536-
if piece.index is slider.dummy.index:
524+
piece = f(chunk)
525+
except:
526+
raise InvalidApply('Let this error raise above us')
527+
# Need to infer if low level index slider will cause segfaults
528+
if i == 0 and piece is chunk:
529+
status = 1
530+
try:
531+
if piece.index is chunk.index:
537532
piece = piece.copy(deep='all')
538533
else:
539534
mutated = True
540535
except AttributeError:
541536
pass
542537

543538
results.append(piece)
539+
if status > 0:
540+
break
544541
finally:
545542
slider.reset()
546543

547-
return results, mutated
544+
return results, mutated, status
548545

549546

550547
cdef class BlockSlider:

pandas/core/groupby/ops.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,17 @@ def apply(self, f, data, axis=0):
165165
mutated = self.mutated
166166
splitter = self._get_splitter(data, axis=axis)
167167
group_keys = self._get_group_keys()
168-
168+
status = 0
169+
result_values = []
169170
# oh boy
170171
f_name = com.get_callable_name(f)
171172
if (f_name not in base.plotting_methods and
172173
hasattr(splitter, 'fast_apply') and axis == 0):
173174
try:
174-
values, mutated = splitter.fast_apply(f, group_keys)
175-
return group_keys, values, mutated
175+
result = splitter.fast_apply(f, group_keys)
176+
result_values, mutated, status = result
177+
if status == 0:
178+
return group_keys, result_values, mutated
176179
except reduction.InvalidApply:
177180
# we detect a mutation of some kind
178181
# so take slow path
@@ -181,9 +184,10 @@ def apply(self, f, data, axis=0):
181184
# raise this error to the caller
182185
pass
183186

184-
result_values = []
185187
for key, (i, group) in zip(group_keys, splitter):
186188
object.__setattr__(group, 'name', key)
189+
if status > 0 and i == 0:
190+
continue
187191

188192
# group might be modified
189193
group_axes = _get_axes(group)
@@ -854,10 +858,7 @@ def fast_apply(self, f, names):
854858
return [], True
855859

856860
sdata = self._get_sorted_data()
857-
results, mutated = reduction.apply_frame_axis0(sdata, f, names,
858-
starts, ends)
859-
860-
return results, mutated
861+
return reduction.apply_frame_axis0(sdata, f, names, starts, ends)
861862

862863
def _chop(self, sdata, slice_obj):
863864
if self.axis == 0:

pandas/tests/groupby/test_apply.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,32 @@ def f(g):
101101
splitter = grouper._get_splitter(g._selected_obj, axis=g.axis)
102102
group_keys = grouper._get_group_keys()
103103

104-
values, mutated = splitter.fast_apply(f, group_keys)
104+
values, mutated, status = splitter.fast_apply(f, group_keys)
105+
assert status == 0
105106
assert not mutated
106107

107108

109+
def test_group_apply_once_per_group():
110+
# GH24748, GH2936, GH2656, GH7739, GH10519, GH12155, GH20084, GH21417
111+
df = pd.DataFrame({'a': [0, 0, 1, 1, 2, 2], 'b': np.arange(6)})
112+
113+
names = []
114+
115+
def f_copy(group):
116+
names.append(group.name)
117+
return group.copy()
118+
df.groupby("a").apply(f_copy)
119+
assert names == [0, 1, 2]
120+
121+
def f_nocopy(group):
122+
names.append(group.name)
123+
return group
124+
names = []
125+
# this takes the slow apply path
126+
df.groupby("a").apply(f_nocopy)
127+
assert names == [0, 1, 2]
128+
129+
108130
def test_apply_with_mixed_dtype():
109131
# GH3480, apply with mixed dtype on axis=1 breaks in 0.11
110132
df = DataFrame({'foo1': np.random.randn(6),

pandas/tests/groupby/test_groupby.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -1420,20 +1420,30 @@ def foo(x):
14201420

14211421
def test_group_name_available_in_inference_pass():
14221422
# gh-15062
1423+
# GH24748, GH2936, GH2656, GH7739, GH10519, GH12155, GH20084, GH21417
14231424
df = pd.DataFrame({'a': [0, 0, 1, 1, 2, 2], 'b': np.arange(6)})
14241425

14251426
names = []
14261427

1427-
def f(group):
1428+
def f_fast(group):
14281429
names.append(group.name)
14291430
return group.copy()
14301431

1431-
df.groupby('a', sort=False, group_keys=False).apply(f)
1432-
# we expect 2 zeros because we call ``f`` once to see if a faster route
1433-
# can be used.
1434-
expected_names = [0, 0, 1, 2]
1432+
df.groupby('a', sort=False, group_keys=False).apply(f_fast)
1433+
1434+
# every group should appear once, i.e. apply is called once per group
1435+
expected_names = [0, 1, 2]
14351436
assert names == expected_names
14361437

1438+
names_slow = []
1439+
1440+
def f_slow(group):
1441+
names_slow.append(group.name)
1442+
return group
1443+
1444+
df.groupby('a', sort=False, group_keys=False).apply(f_slow)
1445+
assert names_slow == [0, 1, 2]
1446+
14371447

14381448
def test_no_dummy_key_names(df):
14391449
# see gh-1291

0 commit comments

Comments
 (0)