Skip to content

Commit c4027ad

Browse files
mroeschke and jreback authored
BUG: groupby.transform/agg with engine='numba' and a MultiIndex (#47057)
Co-authored-by: Jeff Reback <[email protected]>
1 parent 7c01e13 commit c4027ad

File tree

4 files changed

+65
-1
lines changed

4 files changed

+65
-1
lines changed

doc/source/whatsnew/v1.4.3.rst

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Fixed regressions
1616
~~~~~~~~~~~~~~~~~
1717
- Fixed regression in :meth:`DataFrame.nsmallest` that led to wrong results when ``np.nan`` was in the sorting column (:issue:`46589`)
1818
- Fixed regression in :func:`read_fwf` raising ``ValueError`` when ``widths`` was specified with ``usecols`` (:issue:`46580`)
19+
- Fixed regression in :meth:`.GroupBy.transform` and :meth:`.GroupBy.agg` failing with ``engine="numba"`` when the index was a :class:`MultiIndex` (:issue:`46867`)
1920
- Fixed regression in :meth:`.Styler.to_latex` and :meth:`.Styler.to_html` where ``buf`` failed in combination with ``encoding`` (:issue:`47053`)
2021

2122
.. ---------------------------------------------------------------------------

pandas/core/groupby/groupby.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1310,7 +1310,16 @@ def _numba_prep(self, data):
13101310
sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False)
13111311

13121312
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
1313-
sorted_index_data = data.index.take(sorted_index).to_numpy()
1313+
if len(self.grouper.groupings) > 1:
1314+
raise NotImplementedError(
1315+
"More than 1 grouping labels are not supported with engine='numba'"
1316+
)
1317+
# GH 46867
1318+
index_data = data.index
1319+
if isinstance(index_data, MultiIndex):
1320+
group_key = self.grouper.groupings[0].name
1321+
index_data = index_data.get_level_values(group_key)
1322+
sorted_index_data = index_data.take(sorted_index).to_numpy()
13141323

13151324
starts, ends = lib.generate_slices(sorted_ids, ngroups)
13161325
return (

pandas/tests/groupby/aggregate/test_numba.py

+27
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,30 @@ def func_kwargs(values, index):
211211
)
212212
expected = DataFrame({"value": [1.0, 1.0, 1.0]})
213213
tm.assert_frame_equal(result, expected)
214+
215+
216+
@td.skip_if_no("numba")
217+
def test_multiindex_one_key(nogil, parallel, nopython):
218+
def numba_func(values, index):
219+
return 1
220+
221+
df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
222+
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
223+
result = df.groupby("A").agg(
224+
numba_func, engine="numba", engine_kwargs=engine_kwargs
225+
)
226+
expected = DataFrame([1.0], index=Index([1], name="A"), columns=["C"])
227+
tm.assert_frame_equal(result, expected)
228+
229+
230+
@td.skip_if_no("numba")
231+
def test_multiindex_multi_key_not_supported(nogil, parallel, nopython):
232+
def numba_func(values, index):
233+
return 1
234+
235+
df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
236+
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
237+
with pytest.raises(NotImplementedError, match="More than 1 grouping labels"):
238+
df.groupby(["A", "B"]).agg(
239+
numba_func, engine="numba", engine_kwargs=engine_kwargs
240+
)

pandas/tests/groupby/transform/test_numba.py

+27
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,30 @@ def func_kwargs(values, index):
199199
)
200200
expected = DataFrame({"value": [1.0, 1.0, 1.0]})
201201
tm.assert_frame_equal(result, expected)
202+
203+
204+
@td.skip_if_no("numba")
205+
def test_multiindex_one_key(nogil, parallel, nopython):
206+
def numba_func(values, index):
207+
return 1
208+
209+
df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
210+
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
211+
result = df.groupby("A").transform(
212+
numba_func, engine="numba", engine_kwargs=engine_kwargs
213+
)
214+
expected = DataFrame([{"A": 1, "B": 2, "C": 1.0}]).set_index(["A", "B"])
215+
tm.assert_frame_equal(result, expected)
216+
217+
218+
@td.skip_if_no("numba")
219+
def test_multiindex_multi_key_not_supported(nogil, parallel, nopython):
220+
def numba_func(values, index):
221+
return 1
222+
223+
df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
224+
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
225+
with pytest.raises(NotImplementedError, match="More than 1 grouping labels"):
226+
df.groupby(["A", "B"]).transform(
227+
numba_func, engine="numba", engine_kwargs=engine_kwargs
228+
)

0 commit comments

Comments
 (0)