Skip to content

Commit 72c87f4

Browse files
authored
BUG: DataFrameGroupby.transform("size") fails (#45716)
1 parent e38ffb6 commit 72c87f4

File tree

4 files changed

+21
-38
lines changed

4 files changed

+21
-38
lines changed

doc/source/whatsnew/v1.5.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ Plotting
308308
Groupby/resample/rolling
309309
^^^^^^^^^^^^^^^^^^^^^^^^
310310
- Bug in :meth:`DataFrame.resample` ignoring ``closed="right"`` on :class:`TimedeltaIndex` (:issue:`45414`)
311-
-
311+
- Bug in :meth:`.DataFrameGroupBy.transform` failing when ``func="size"`` and the input DataFrame has multiple columns (:issue:`27469`)
312312

313313
Reshaping
314314
^^^^^^^^^

pandas/core/groupby/generic.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
472472
result.name = self.obj.name
473473
return result
474474

475-
def _can_use_transform_fast(self, result) -> bool:
475+
def _can_use_transform_fast(self, func: str, result) -> bool:
476476
return True
477477

478478
def filter(self, func, dropna: bool = True, *args, **kwargs):
@@ -1185,9 +1185,10 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
11851185
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
11861186
)
11871187

1188-
def _can_use_transform_fast(self, result) -> bool:
1189-
return isinstance(result, DataFrame) and result.columns.equals(
1190-
self._obj_with_exclusions.columns
1188+
def _can_use_transform_fast(self, func: str, result) -> bool:
1189+
return func == "size" or (
1190+
isinstance(result, DataFrame)
1191+
and result.columns.equals(self._obj_with_exclusions.columns)
11911192
)
11921193

11931194
def _define_paths(self, func, *args, **kwargs):

pandas/core/groupby/groupby.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1650,7 +1650,7 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
16501650
with com.temp_setattr(self, "observed", True):
16511651
result = getattr(self, func)(*args, **kwargs)
16521652

1653-
if self._can_use_transform_fast(result):
1653+
if self._can_use_transform_fast(func, result):
16541654
return self._wrap_transform_fast_result(result)
16551655

16561656
# only reached for DataFrameGroupBy

pandas/tests/groupby/transform/test_transform.py

+14-32
Original file line numberDiff line numberDiff line change
@@ -803,39 +803,29 @@ def test_transform_with_non_scalar_group():
803803

804804

805805
@pytest.mark.parametrize(
806-
"cols,exp,comp_func",
806+
"cols,expected",
807807
[
808-
("a", Series([1, 1, 1], name="a"), tm.assert_series_equal),
808+
("a", Series([1, 1, 1], name="a")),
809809
(
810810
["a", "c"],
811811
DataFrame({"a": [1, 1, 1], "c": [1, 1, 1]}),
812-
tm.assert_frame_equal,
813812
),
814813
],
815814
)
816815
@pytest.mark.parametrize("agg_func", ["count", "rank", "size"])
817-
def test_transform_numeric_ret(cols, exp, comp_func, agg_func, request):
818-
if agg_func == "size" and isinstance(cols, list):
819-
# https://github.com/pytest-dev/pytest/issues/6300
820-
# workaround to xfail fixture/param permutations
821-
reason = "'size' transformation not supported with NDFrameGroupy"
822-
request.node.add_marker(pytest.mark.xfail(reason=reason))
823-
824-
# GH 19200
816+
def test_transform_numeric_ret(cols, expected, agg_func):
817+
# GH#19200 and GH#27469
825818
df = DataFrame(
826819
{"a": date_range("2018-01-01", periods=3), "b": range(3), "c": range(7, 10)}
827820
)
828-
829-
warn = FutureWarning
830-
if isinstance(exp, Series) or agg_func != "size":
831-
warn = None
832-
with tm.assert_produces_warning(warn, match="Dropping invalid columns"):
833-
result = df.groupby("b")[cols].transform(agg_func)
821+
result = df.groupby("b")[cols].transform(agg_func)
834822

835823
if agg_func == "rank":
836-
exp = exp.astype("float")
837-
838-
comp_func(result, exp)
824+
expected = expected.astype("float")
825+
elif agg_func == "size" and cols == ["a", "c"]:
826+
# transform("size") returns a Series
827+
expected = expected["a"].rename(None)
828+
tm.assert_equal(result, expected)
839829

840830

841831
def test_transform_ffill():
@@ -1131,27 +1121,19 @@ def test_transform_agg_by_name(request, reduction_func, obj):
11311121
request.node.add_marker(
11321122
pytest.mark.xfail(reason="TODO: g.transform('ngroup') doesn't work")
11331123
)
1134-
if func == "size" and obj.ndim == 2: # GH#27469
1135-
request.node.add_marker(
1136-
pytest.mark.xfail(reason="TODO: g.transform('size') doesn't work")
1137-
)
11381124
if func == "corrwith" and isinstance(obj, Series): # GH#32293
11391125
request.node.add_marker(
11401126
pytest.mark.xfail(reason="TODO: implement SeriesGroupBy.corrwith")
11411127
)
11421128

11431129
args = {"nth": [0], "quantile": [0.5], "corrwith": [obj]}.get(func, [])
1144-
1145-
warn = None
1146-
if isinstance(obj, DataFrame) and func == "size":
1147-
warn = FutureWarning
1148-
1149-
with tm.assert_produces_warning(warn, match="Dropping invalid columns"):
1150-
result = g.transform(func, *args)
1130+
result = g.transform(func, *args)
11511131

11521132
# this is the *definition* of a transformation
11531133
tm.assert_index_equal(result.index, obj.index)
1154-
if hasattr(obj, "columns"):
1134+
1135+
if func != "size" and obj.ndim == 2:
1136+
# size returns a Series, unlike other transforms
11551137
tm.assert_index_equal(result.columns, obj.columns)
11561138

11571139
# verify that values were broadcasted across each group

0 commit comments

Comments
 (0)