Skip to content

Commit 9afa234

Browse files
committed
Add more tests for apply first row
1 parent 7354f02 commit 9afa234

File tree

4 files changed

+89
-4
lines changed

4 files changed

+89
-4
lines changed

pandas/tests/frame/test_apply.py

+31
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,37 @@ def test_apply_dup_names_multi_agg(self):
568568

569569
tm.assert_frame_equal(result, expected)
570570

571+
@pytest.mark.parametrize("axis, expected", [
    (0, ['a', 'b']),
    (1, [0, 1, 2, 3, 4, 5]),
])
def test_apply_first_row_once(self, axis, expected):
    """Record the label of every Series passed to ``DataFrame.apply``.

    gh-2936: a function that reduces to a scalar is expected to be called
    exactly once per label along ``axis``; a function that returns a copy
    of its input is expected to see the first label twice.
    """
    frame = pd.DataFrame({'a': [0, 0, 1, 1, 2, 2], 'b': np.arange(6)})

    fast_labels = []

    def reduce_to_scalar(ser):
        fast_labels.append(ser.name)
        return 0

    frame.apply(reduce_to_scalar, axis=axis)
    # every label should appear once, i.e. one call per row/column
    assert fast_labels == expected

    slow_labels = []

    def return_copy(ser):
        """
        Returning a copy triggers a `function does not reduce`
        exception, so apply falls back to the slow path.
        """
        slow_labels.append(ser.name)
        return ser.copy()

    frame.apply(return_copy, axis=axis)
    # on the slow path the first label is recorded a second time
    first_label_twice = [expected[0]] + expected
    assert slow_labels == first_label_twice
601+
571602

572603
class TestInferOutputShape(object):
573604
# the user has supplied an opaque UDF where

pandas/tests/groupby/test_apply.py

+21
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,27 @@ def f(g):
105105
assert not mutated
106106

107107

108+
def test_group_apply_once_per_group():
    """Count how often ``GroupBy.apply`` calls the function for each group.

    gh-2936: a function returning a copy of the group is expected to be
    called once per group, while a function returning the group object
    itself is expected to be called twice for the first group.
    """
    df = pd.DataFrame({'a': [0, 0, 1, 1, 2, 2], 'b': np.arange(6)})

    names = []

    def f_copy(group):
        names.append(group.name)
        return group.copy()

    df.groupby("a").apply(f_copy)
    # every group should appear once, i.e. apply is called once per group
    assert names == [0, 1, 2]

    # Use a separate list rather than ``names.clear()``: ``list.clear`` is
    # not available on Python 2, and distinct lists keep the two scenarios
    # independent of each other.
    names_nocopy = []

    def f_nocopy(group):
        names_nocopy.append(group.name)
        return group

    # this takes the slow apply path, i.e. the function is applied to the
    # first group twice
    df.groupby("a").apply(f_nocopy)
    assert names_nocopy == [0, 0, 1, 2]
127+
128+
108129
def test_apply_with_mixed_dtype():
109130
# GH3480, apply with mixed dtype on axis=1 breaks in 0.11
110131
df = DataFrame({'foo1': np.random.randn(6),

pandas/tests/groupby/test_groupby.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -1390,16 +1390,25 @@ def test_group_name_available_in_inference_pass():
13901390

13911391
names = []
13921392

1393-
def f(group):
1393+
def f_fast(group):
13941394
names.append(group.name)
13951395
return group.copy()
13961396

1397-
df.groupby('a', sort=False, group_keys=False).apply(f)
1398-
# we expect 2 zeros because we call ``f`` once to see if a faster route
1399-
# can be used.
1397+
df.groupby('a', sort=False, group_keys=False).apply(f_fast)
1398+
# gh-2936
1399+
# every group should appear once, i.e. apply is called once per group
14001400
expected_names = [0, 1, 2]
14011401
assert names == expected_names
14021402

1403+
names_slow = []
1404+
1405+
def f_slow(group):
1406+
names_slow.append(group.name)
1407+
return group
1408+
1409+
df.groupby('a', sort=False, group_keys=False).apply(f_slow)
1410+
assert names_slow == [0, 0, 1, 2]
1411+
14031412

14041413
def test_no_dummy_key_names(df):
14051414
# see gh-1291

pandas/tests/series/test_apply.py

+24
Original file line numberDiff line numberDiff line change
@@ -665,3 +665,27 @@ def test_map_missing_mixed(self, vals, mapping, exp):
665665
result = s.map(mapping)
666666

667667
tm.assert_series_equal(result, pd.Series(exp))
668+
669+
def test_apply_only_once(self):
    """Check that ``Series.apply`` calls the function once per element."""
    # gh-2936
    ser = pd.Series([0, 0, 1, 1, 2, 2], name="series")
    expected_values = [0, 0, 1, 1, 2, 2]

    seen = []

    def record_scalar(value):
        seen.append(value)
        return value

    ser.apply(record_scalar)
    # every element should appear once, i.e. one call per element
    assert seen == expected_values

    # Elements should also be visited only once when the function's
    # return shape differs from its input.
    seen_pairs = []

    def record_pair(value):
        seen_pairs.append(value)
        return (value, value)

    ser.apply(record_pair)
    assert seen_pairs == expected_values

0 commit comments

Comments
 (0)