Skip to content

Commit bd98841

Browse files
DiegoAlbertoTorres authored and jreback committed
BUG: fix groupby.transform rename bug (#23461) (#23463)
1 parent c95bfd1 commit bd98841

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line number | Diff line number | Diff line change
@@ -1308,6 +1308,7 @@ Groupby/Resample/Rolling
13081308
- :func:`RollingGroupby.agg` and :func:`ExpandingGroupby.agg` now support multiple aggregation functions as parameters (:issue:`15072`)
13091309
- Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` when resampling by a weekly offset (``'W'``) across a DST transition (:issue:`9119`, :issue:`21459`)
13101310
- Bug in :meth:`DataFrame.expanding` in which the ``axis`` argument was not being respected during aggregations (:issue:`23372`)
1311+
- Bug in :meth:`pandas.core.groupby.DataFrameGroupBy.transform` which caused missing values when the input function can accept a :class:`DataFrame` but renames it (:issue:`23455`).
13111312

13121313
Reshaping
13131314
^^^^^^^^^

pandas/core/groupby/generic.py

+7 -4
Original file line number | Diff line number | Diff line change
@@ -586,14 +586,17 @@ def _choose_path(self, fast_path, slow_path, group):
586586
try:
587587
res_fast = fast_path(group)
588588

589-
# compare that we get the same results
589+
# verify fast path does not change columns (and names), otherwise
590+
# its results cannot be joined with those of the slow path
591+
if res_fast.columns != group.columns:
592+
return path, res
593+
# verify numerical equality with the slow path
590594
if res.shape == res_fast.shape:
591595
res_r = res.values.ravel()
592596
res_fast_r = res_fast.values.ravel()
593597
mask = notna(res_r)
594-
if (res_r[mask] == res_fast_r[mask]).all():
595-
path = fast_path
596-
598+
if (res_r[mask] == res_fast_r[mask]).all():
599+
path = fast_path
597600
except Exception:
598601
pass
599602
return path, res

pandas/tests/groupby/test_transform.py

+23
Original file line number | Diff line number | Diff line change
@@ -808,3 +808,26 @@ def test_any_all_np_func(func):
808808

809809
res = df.groupby('key')['val'].transform(func)
810810
tm.assert_series_equal(res, exp)
811+
812+
813+
def test_groupby_transform_rename():
814+
# https://github.com/pandas-dev/pandas/issues/23461
815+
def demean_rename(x):
816+
result = x - x.mean()
817+
818+
if isinstance(x, pd.Series):
819+
return result
820+
821+
result = result.rename(
822+
columns={c: '{}_demeaned'.format(c) for c in result.columns})
823+
824+
return result
825+
826+
df = pd.DataFrame({'group': list('ababa'),
827+
'value': [1, 1, 1, 2, 2]})
828+
expected = pd.DataFrame({'value': [-1. / 3, -0.5, -1. / 3, 0.5, 2. / 3]})
829+
830+
result = df.groupby('group').transform(demean_rename)
831+
tm.assert_frame_equal(result, expected)
832+
result_single = df.groupby('group').value.transform(demean_rename)
833+
tm.assert_series_equal(result_single, expected['value'])

0 commit comments

Comments
 (0)