Skip to content

Commit 806faba

Browse files
authored
Backport PR #47057 on branch 1.4.x (BUG: groupby.transform/agg with engine='numba' and a MultiIndex) (#47062)
BUG: groupby.transform/agg with engine='numba' and a MultiIndex (#47057) Co-authored-by: Jeff Reback <[email protected]> (cherry picked from commit c4027ad)
1 parent 4591849 commit 806faba

File tree

4 files changed

+66
-2
lines changed

4 files changed

+66
-2
lines changed

doc/source/whatsnew/v1.4.3.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ Fixed regressions
1616
~~~~~~~~~~~~~~~~~
1717
- Fixed regression in :meth:`DataFrame.nsmallest` that led to wrong results when ``np.nan`` was in the sorting column (:issue:`46589`)
1818
- Fixed regression in :func:`read_fwf` raising ``ValueError`` when ``widths`` was specified with ``usecols`` (:issue:`46580`)
19-
-
19+
- Fixed regression in :meth:`.GroupBy.transform` and :meth:`.GroupBy.agg` failing with ``engine="numba"`` when the index was a :class:`MultiIndex` (:issue:`46867`)
20+
- Fixed regression in :meth:`.Styler.to_latex` and :meth:`.Styler.to_html` where ``buf`` failed in combination with ``encoding`` (:issue:`47053`)
2021

2122
.. ---------------------------------------------------------------------------
2223

pandas/core/groupby/groupby.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1257,7 +1257,16 @@ def _numba_prep(self, func, data):
12571257
sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False)
12581258

12591259
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
1260-
sorted_index_data = data.index.take(sorted_index).to_numpy()
1260+
if len(self.grouper.groupings) > 1:
1261+
raise NotImplementedError(
1262+
"More than 1 grouping labels are not supported with engine='numba'"
1263+
)
1264+
# GH 46867
1265+
index_data = data.index
1266+
if isinstance(index_data, MultiIndex):
1267+
group_key = self.grouper.groupings[0].name
1268+
index_data = index_data.get_level_values(group_key)
1269+
sorted_index_data = index_data.take(sorted_index).to_numpy()
12611270

12621271
starts, ends = lib.generate_slices(sorted_ids, ngroups)
12631272
return (

pandas/tests/groupby/aggregate/test_numba.py

+27
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,30 @@ def f(values, index):
187187
[-1.5, -3.0], columns=["v"], index=Index(["A", "B"], name="group")
188188
)
189189
tm.assert_frame_equal(result, expected)
190+
191+
192+
@td.skip_if_no("numba")
193+
def test_multiindex_one_key(nogil, parallel, nopython):
194+
def numba_func(values, index):
195+
return 1
196+
197+
df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
198+
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
199+
result = df.groupby("A").agg(
200+
numba_func, engine="numba", engine_kwargs=engine_kwargs
201+
)
202+
expected = DataFrame([1.0], index=Index([1], name="A"), columns=["C"])
203+
tm.assert_frame_equal(result, expected)
204+
205+
206+
@td.skip_if_no("numba")
207+
def test_multiindex_multi_key_not_supported(nogil, parallel, nopython):
208+
def numba_func(values, index):
209+
return 1
210+
211+
df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
212+
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
213+
with pytest.raises(NotImplementedError, match="More than 1 grouping labels"):
214+
df.groupby(["A", "B"]).agg(
215+
numba_func, engine="numba", engine_kwargs=engine_kwargs
216+
)

pandas/tests/groupby/transform/test_numba.py

+27
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,30 @@ def f(values, index):
176176
result = df.groupby("group").transform(f, engine="numba")
177177
expected = DataFrame([-4.0, -3.0, -2.0], columns=["v"], index=[-1, -2, -3])
178178
tm.assert_frame_equal(result, expected)
179+
180+
181+
@td.skip_if_no("numba")
182+
def test_multiindex_one_key(nogil, parallel, nopython):
183+
def numba_func(values, index):
184+
return 1
185+
186+
df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
187+
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
188+
result = df.groupby("A").transform(
189+
numba_func, engine="numba", engine_kwargs=engine_kwargs
190+
)
191+
expected = DataFrame([{"A": 1, "B": 2, "C": 1.0}]).set_index(["A", "B"])
192+
tm.assert_frame_equal(result, expected)
193+
194+
195+
@td.skip_if_no("numba")
196+
def test_multiindex_multi_key_not_supported(nogil, parallel, nopython):
197+
def numba_func(values, index):
198+
return 1
199+
200+
df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
201+
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
202+
with pytest.raises(NotImplementedError, match="More than 1 grouping labels"):
203+
df.groupby(["A", "B"]).transform(
204+
numba_func, engine="numba", engine_kwargs=engine_kwargs
205+
)

0 commit comments

Comments
 (0)