Skip to content

Commit bfb80fa

Browse files
authored
BUG: Groupby.apply wasn't allowing for functions which return lists (#31456)
1 parent 01623f8 commit bfb80fa

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

doc/source/whatsnew/v1.0.1.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ including other versions of pandas.
1515

1616
Bug fixes
1717
~~~~~~~~~
18-
18+
- Bug in :meth:`GroupBy.apply` was raising ``TypeError`` if called with function which returned a non-pandas non-scalar object (e.g. a list) (:issue:`31441`)
1919

2020
Categorical
2121
^^^^^^^^^^^

pandas/_libs/reduction.pyx

+2-2
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,9 @@ def apply_frame_axis0(object frame, object f, object names,
501501

502502
if not is_scalar(piece):
503503
# Need to copy data to avoid appending references
504-
if hasattr(piece, "copy"):
504+
try:
505505
piece = piece.copy(deep="all")
506-
else:
506+
except (TypeError, AttributeError):
507507
piece = copy(piece)
508508

509509
results.append(piece)

pandas/tests/groupby/test_apply.py

+24
Original file line numberDiff line numberDiff line change
@@ -827,3 +827,27 @@ def test_apply_index_has_complex_internals(index):
827827
df = DataFrame({"group": [1, 1, 2], "value": [0, 1, 0]}, index=index)
828828
result = df.groupby("group").apply(lambda x: x)
829829
tm.assert_frame_equal(result, df)
830+
831+
832+
@pytest.mark.parametrize(
833+
"function, expected_values",
834+
[
835+
(lambda x: x.index.to_list(), [[0, 1], [2, 3]]),
836+
(lambda x: set(x.index.to_list()), [{0, 1}, {2, 3}]),
837+
(lambda x: tuple(x.index.to_list()), [(0, 1), (2, 3)]),
838+
(
839+
lambda x: {n: i for (n, i) in enumerate(x.index.to_list())},
840+
[{0: 0, 1: 1}, {0: 2, 1: 3}],
841+
),
842+
(
843+
lambda x: [{n: i} for (n, i) in enumerate(x.index.to_list())],
844+
[[{0: 0}, {1: 1}], [{0: 2}, {1: 3}]],
845+
),
846+
],
847+
)
848+
def test_apply_function_returns_non_pandas_non_scalar(function, expected_values):
849+
# GH 31441
850+
df = pd.DataFrame(["A", "A", "B", "B"], columns=["groups"])
851+
result = df.groupby("groups").apply(function)
852+
expected = pd.Series(expected_values, index=pd.Index(["A", "B"], name="groups"))
853+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)