diff --git a/doc/source/whatsnew/v1.3.3.rst b/doc/source/whatsnew/v1.3.3.rst index 1340188c3d609..3dee3aa5e7c7a 100644 --- a/doc/source/whatsnew/v1.3.3.rst +++ b/doc/source/whatsnew/v1.3.3.rst @@ -25,7 +25,7 @@ Fixed regressions Bug fixes ~~~~~~~~~ -- +- Bug in :meth:`.DataFrameGroupBy.agg` and :meth:`.DataFrameGroupBy.transform` with ``engine="numba"`` where ``index`` data was not being correctly passed into ``func`` (:issue:`43133`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 0080791a51a4b..5a70db517ad12 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1143,9 +1143,15 @@ def _numba_prep(self, func, data): sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False) sorted_data = data.take(sorted_index, axis=self.axis).to_numpy() + sorted_index_data = data.index.take(sorted_index).to_numpy() starts, ends = lib.generate_slices(sorted_ids, ngroups) - return starts, ends, sorted_index, sorted_data + return ( + starts, + ends, + sorted_index_data, + sorted_data, + ) @final def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs): diff --git a/pandas/tests/groupby/aggregate/test_numba.py b/pandas/tests/groupby/aggregate/test_numba.py index ba2d6eeb287c0..4b915cd4c29ae 100644 --- a/pandas/tests/groupby/aggregate/test_numba.py +++ b/pandas/tests/groupby/aggregate/test_numba.py @@ -173,3 +173,17 @@ def sum_last(values, index, n): result = grouped_x.agg(sum_last, 2, engine="numba") expected = Series([2.0] * 2, name="x", index=Index([0, 1], name="id")) tm.assert_series_equal(result, expected) + + +@td.skip_if_no("numba", "0.46.0") +def test_index_data_correctly_passed(): + # GH 43133 + def f(values, index): + return np.mean(index) + + df = DataFrame({"group": ["A", "A", "B"], "v": [4, 5, 6]}, index=[-1, -2, -3]) + result = df.groupby("group").aggregate(f, engine="numba") + expected = DataFrame( + [-1.5, -3.0], columns=["v"], index=Index(["A", "B"], name="group") + ) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 8019071be72f3..b2d72aec0527f 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -164,3 +164,15 @@ def sum_last(values, index, n): result = grouped_x.transform(sum_last, 2, engine="numba") expected = Series([2.0] * 4, name="x") tm.assert_series_equal(result, expected) + + +@td.skip_if_no("numba", "0.46.0") +def test_index_data_correctly_passed(): + # GH 43133 + def f(values, index): + return index - 1 + + df = DataFrame({"group": ["A", "A", "B"], "v": [4, 5, 6]}, index=[-1, -2, -3]) + result = df.groupby("group").transform(f, engine="numba") + expected = DataFrame([-4.0, -3.0, -2.0], columns=["v"], index=[-1, -2, -3]) + tm.assert_frame_equal(result, expected)