Skip to content

Commit db24e50

Browse files
author
Diego Torres Quintanilla
committed
BUG: fix groupby.transform rename bug (#23461)
1 parent 391aded commit db24e50

File tree

3 files changed

+30
-4
lines changed

3 files changed

+30
-4
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,7 @@ Groupby/Resample/Rolling
12991299
- :func:`RollingGroupby.agg` and :func:`ExpandingGroupby.agg` now support multiple aggregation functions as parameters (:issue:`15072`)
13001300
- Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` when resampling by a weekly offset (``'W'``) across a DST transition (:issue:`9119`, :issue:`21459`)
13011301
- Bug in :meth:`DataFrame.expanding` in which the ``axis`` argument was not being respected during aggregations (:issue:`23372`)
1302+
- Bug in :meth:`pandas.core.groupby.DataFrameGroupBy.transform` which caused missing values when the input function can accept a :class:`DataFrame` but renames it (:issue:`23455`).
13021303

13031304
Reshaping
13041305
^^^^^^^^^

pandas/core/groupby/generic.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -583,14 +583,18 @@ def _choose_path(self, fast_path, slow_path, group):
583583
try:
584584
res_fast = fast_path(group)
585585

586-
# compare that we get the same results
586+
# verify fast path does not change columns (and names), otherwise
587+
# its results cannot be joined with those of the slow path
588+
if (getattr(res_fast, 'columns', None)
589+
!= getattr(group, 'columns', None)):
590+
return path, res
591+
# verify numerical equality with the slow path
587592
if res.shape == res_fast.shape:
588593
res_r = res.values.ravel()
589594
res_fast_r = res_fast.values.ravel()
590595
mask = notna(res_r)
591-
if (res_r[mask] == res_fast_r[mask]).all():
592-
path = fast_path
593-
596+
if (res_r[mask] == res_fast_r[mask]).all():
597+
path = fast_path
594598
except Exception:
595599
pass
596600
return path, res

pandas/tests/groupby/test_transform.py

+21
Original file line numberDiff line numberDiff line change
@@ -808,3 +808,24 @@ def test_any_all_np_func(func):
808808

809809
res = df.groupby('key')['val'].transform(func)
810810
tm.assert_series_equal(res, exp)
811+
812+
813+
def test_groupby_transform_rename():
814+
# https://github.com/pandas-dev/pandas/issues/23461
815+
def demean_rename(x):
816+
result = x - x.mean()
817+
818+
if isinstance(x, pd.Series):
819+
return result
820+
821+
result = result.rename(
822+
columns={c: '{}_demeaned'.format(c) for c in result.columns})
823+
824+
return result
825+
826+
df = pd.DataFrame({'group': list('ababa'),
827+
'value': [1, 1, 1, 2, 2]})
828+
expected = pd.DataFrame({'value': [-1. / 3, -0.5, -1. / 3, 0.5, 2. / 3]})
829+
830+
result = df.groupby('group').transform(demean_rename)
831+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)