diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 0ca6ef043fffb..2d882437805b2 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1380,7 +1380,7 @@ def _define_paths(self, func, *args, **kwargs): ) return fast_path, slow_path - def _choose_path(self, fast_path, slow_path, group): + def _choose_path(self, fast_path: Callable, slow_path: Callable, group: DataFrame): path = slow_path res = slow_path(group) @@ -1390,8 +1390,8 @@ def _choose_path(self, fast_path, slow_path, group): except AssertionError: raise except Exception: - # Hard to know ex-ante what exceptions `fast_path` might raise - # TODO: no test cases get here + # GH#29631 For user-defined functions, we can't predict what may be + # raised; see test_transform.test_transform_fastpath_raises return path, res # verify fast path does not change columns (and names), otherwise diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/test_transform.py index db44a4a57230c..3d9a349d94e10 100644 --- a/pandas/tests/groupby/test_transform.py +++ b/pandas/tests/groupby/test_transform.py @@ -1073,3 +1073,33 @@ def test_transform_lambda_with_datetimetz(): name="time", ) tm.assert_series_equal(result, expected) + + +def test_transform_fastpath_raises(): + # GH#29631 case where fastpath defined in groupby.generic _choose_path + # raises, but slow_path does not + + df = pd.DataFrame({"A": [1, 1, 2, 2], "B": [1, -1, 1, 2]}) + gb = df.groupby("A") + + def func(grp): + # we want a function such that func(frame) fails but frame.apply(func) + # works + if grp.ndim == 2: + # Ensure that fast_path fails + raise NotImplementedError("Don't cross the streams") + return grp * 2 + + # Check that the fastpath raises, see _transform_general + obj = gb._obj_with_exclusions + gen = gb.grouper.get_iterator(obj, axis=gb.axis) + fast_path, slow_path = gb._define_paths(func) + _, group = next(gen) + + with pytest.raises(NotImplementedError, match="Don't cross 
the streams"): + fast_path(group) + + result = gb.transform(func) + + expected = pd.DataFrame([2, -2, 2, 4], columns=["B"]) + tm.assert_frame_equal(result, expected)