Skip to content

Commit 7169acf

Browse files
committed
groupy apply: Ensure same index is returned for slow and fast path
1 parent c815ffa commit 7169acf

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

doc/source/whatsnew/v1.0.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Deprecations
2424
Bug fixes
2525
~~~~~~~~~
2626
- Bug in :meth:`GroupBy.apply` was raising ``TypeError`` if called with function which returned a non-pandas non-scalar object (e.g. a list) (:issue:`31441`)
27+
- Bug in :meth:`GroupBy.apply` where the output index type was depending on internals (:issue:`31612`)
2728

2829
Categorical
2930
^^^^^^^^^^^

pandas/_libs/reduction.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def apply_frame_axis0(object frame, object f, object names,
493493
# Need to infer if low level index slider will cause segfaults
494494
require_slow_apply = i == 0 and piece is chunk
495495
try:
496-
if piece.index is not chunk.index:
496+
if not piece.index.equals(chunk.index):
497497
mutated = True
498498
except AttributeError:
499499
# `piece` might not have an index, could be e.g. an int

pandas/tests/groupby/test_apply.py

+17
Original file line numberDiff line numberDiff line change
@@ -851,3 +851,20 @@ def test_apply_function_returns_non_pandas_non_scalar(function, expected_values)
851851
result = df.groupby("groups").apply(function)
852852
expected = pd.Series(expected_values, index=pd.Index(["A", "B"], name="groups"))
853853
tm.assert_series_equal(result, expected)
854+
855+
856+
def test_apply_fast_slow_identical():
857+
858+
df = DataFrame({"A": [0, 0, 1], "b": range(3)})
859+
860+
def slow(group):
861+
# slow apply because of check `result is input`, c.f. https://github.com/pandas-dev/pandas/blob/44782c0809e296a8e57b7f77d963e999c7e0f4a7/pandas/_libs/reduction.pyx#L494
862+
return group
863+
864+
def fast(group):
865+
return group.copy()
866+
867+
fast_df = df.groupby("A").apply(fast)
868+
slow_df = df.groupby("A").apply(slow)
869+
870+
tm.assert_frame_equal(fast_df, slow_df)

0 commit comments

Comments
 (0)