Skip to content

Commit e2337b0

Browse files
Backport PR #31456: BUG: Groupby.apply wasn't allowing for functions which return lists (#31541)
Co-authored-by: Marco Gorelli <[email protected]>
1 parent 34f6c7e commit e2337b0

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

doc/source/whatsnew/v1.0.1.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ including other versions of pandas.
1515

1616
Bug fixes
1717
~~~~~~~~~
18-
18+
- Bug in :meth:`GroupBy.apply` was raising ``TypeError`` if called with function which returned a non-pandas non-scalar object (e.g. a list) (:issue:`31441`)
1919

2020
Categorical
2121
^^^^^^^^^^^

pandas/_libs/reduction.pyx

+2-2
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,9 @@ def apply_frame_axis0(object frame, object f, object names,
501501

502502
if not is_scalar(piece):
503503
# Need to copy data to avoid appending references
504-
if hasattr(piece, "copy"):
504+
try:
505505
piece = piece.copy(deep="all")
506-
else:
506+
except (TypeError, AttributeError):
507507
piece = copy(piece)
508508

509509
results.append(piece)

pandas/tests/groupby/test_apply.py

+24
Original file line numberDiff line numberDiff line change
@@ -785,3 +785,27 @@ def test_apply_index_has_complex_internals(index):
785785
df = DataFrame({"group": [1, 1, 2], "value": [0, 1, 0]}, index=index)
786786
result = df.groupby("group").apply(lambda x: x)
787787
tm.assert_frame_equal(result, df)
788+
789+
790+
@pytest.mark.parametrize(
791+
"function, expected_values",
792+
[
793+
(lambda x: x.index.to_list(), [[0, 1], [2, 3]]),
794+
(lambda x: set(x.index.to_list()), [{0, 1}, {2, 3}]),
795+
(lambda x: tuple(x.index.to_list()), [(0, 1), (2, 3)]),
796+
(
797+
lambda x: {n: i for (n, i) in enumerate(x.index.to_list())},
798+
[{0: 0, 1: 1}, {0: 2, 1: 3}],
799+
),
800+
(
801+
lambda x: [{n: i} for (n, i) in enumerate(x.index.to_list())],
802+
[[{0: 0}, {1: 1}], [{0: 2}, {1: 3}]],
803+
),
804+
],
805+
)
806+
def test_apply_function_returns_non_pandas_non_scalar(function, expected_values):
807+
# GH 31441
808+
df = pd.DataFrame(["A", "A", "B", "B"], columns=["groups"])
809+
result = df.groupby("groups").apply(function)
810+
expected = pd.Series(expected_values, index=pd.Index(["A", "B"], name="groups"))
811+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)