Skip to content

Commit 00cfdf5

Browse files
jbrockmendeljacobaustin123
authored andcommitted
TST: add test case for user-defined function taking correct path in groupby transform (pandas-dev#29631)
1 parent 0939732 commit 00cfdf5

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

pandas/core/groupby/generic.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1382,7 +1382,7 @@ def _define_paths(self, func, *args, **kwargs):
13821382
)
13831383
return fast_path, slow_path
13841384

1385-
def _choose_path(self, fast_path, slow_path, group):
1385+
def _choose_path(self, fast_path: Callable, slow_path: Callable, group: DataFrame):
13861386
path = slow_path
13871387
res = slow_path(group)
13881388

@@ -1392,8 +1392,8 @@ def _choose_path(self, fast_path, slow_path, group):
13921392
except AssertionError:
13931393
raise
13941394
except Exception:
1395-
# Hard to know ex-ante what exceptions `fast_path` might raise
1396-
# TODO: no test cases get here
1395+
# GH#29631 For user-defined function, we cant predict what may be
1396+
# raised; see test_transform.test_transform_fastpath_raises
13971397
return path, res
13981398

13991399
# verify fast path does not change columns (and names), otherwise

pandas/tests/groupby/test_transform.py

+30
Original file line numberDiff line numberDiff line change
@@ -1073,3 +1073,33 @@ def test_transform_lambda_with_datetimetz():
10731073
name="time",
10741074
)
10751075
tm.assert_series_equal(result, expected)
1076+
1077+
1078+
def test_transform_fastpath_raises():
1079+
# GH#29631 case where fastpath defined in groupby.generic _choose_path
1080+
# raises, but slow_path does not
1081+
1082+
df = pd.DataFrame({"A": [1, 1, 2, 2], "B": [1, -1, 1, 2]})
1083+
gb = df.groupby("A")
1084+
1085+
def func(grp):
1086+
# we want a function such that func(frame) fails but func.apply(frame)
1087+
# works
1088+
if grp.ndim == 2:
1089+
# Ensure that fast_path fails
1090+
raise NotImplementedError("Don't cross the streams")
1091+
return grp * 2
1092+
1093+
# Check that the fastpath raises, see _transform_general
1094+
obj = gb._obj_with_exclusions
1095+
gen = gb.grouper.get_iterator(obj, axis=gb.axis)
1096+
fast_path, slow_path = gb._define_paths(func)
1097+
_, group = next(gen)
1098+
1099+
with pytest.raises(NotImplementedError, match="Don't cross the streams"):
1100+
fast_path(group)
1101+
1102+
result = gb.transform(func)
1103+
1104+
expected = pd.DataFrame([2, -2, 2, 4], columns=["B"])
1105+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)