From 05ed992dbbbb964c142e4c6bc297fce91c8d6725 Mon Sep 17 00:00:00 2001 From: Richard Shadrach Date: Sun, 30 Jan 2022 11:07:11 -0500 Subject: [PATCH 1/2] BUG: DataFrameGroupby.transform("size") fails --- pandas/core/groupby/generic.py | 9 ++-- pandas/core/groupby/groupby.py | 2 +- .../tests/groupby/transform/test_transform.py | 46 ++++++------------- 3 files changed, 20 insertions(+), 37 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 175067b4b7c20..539519ee04dd0 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -472,7 +472,7 @@ def _transform_general(self, func: Callable, *args, **kwargs) -> Series: result.name = self.obj.name return result - def _can_use_transform_fast(self, result) -> bool: + def _can_use_transform_fast(self, func: str, result) -> bool: return True def filter(self, func, dropna: bool = True, *args, **kwargs): @@ -1185,9 +1185,10 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs ) - def _can_use_transform_fast(self, result) -> bool: - return isinstance(result, DataFrame) and result.columns.equals( - self._obj_with_exclusions.columns + def _can_use_transform_fast(self, func: str, result) -> bool: + return func == "size" or ( + isinstance(result, DataFrame) + and result.columns.equals(self._obj_with_exclusions.columns) ) def _define_paths(self, func, *args, **kwargs): diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index b682723cb10de..4eb907e06adf1 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1650,7 +1650,7 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): with com.temp_setattr(self, "observed", True): result = getattr(self, func)(*args, **kwargs) - if self._can_use_transform_fast(result): + if self._can_use_transform_fast(func, result): return 
self._wrap_transform_fast_result(result) # only reached for DataFrameGroupBy diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index 12a25a1e61211..7f130dea4e0a7 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -803,39 +803,29 @@ def test_transform_with_non_scalar_group(): @pytest.mark.parametrize( - "cols,exp,comp_func", + "cols,expected", [ - ("a", Series([1, 1, 1], name="a"), tm.assert_series_equal), + ("a", Series([1, 1, 1], name="a")), ( ["a", "c"], DataFrame({"a": [1, 1, 1], "c": [1, 1, 1]}), - tm.assert_frame_equal, ), ], ) @pytest.mark.parametrize("agg_func", ["count", "rank", "size"]) -def test_transform_numeric_ret(cols, exp, comp_func, agg_func, request): - if agg_func == "size" and isinstance(cols, list): - # https://github.com/pytest-dev/pytest/issues/6300 - # workaround to xfail fixture/param permutations - reason = "'size' transformation not supported with NDFrameGroupy" - request.node.add_marker(pytest.mark.xfail(reason=reason)) - - # GH 19200 +def test_transform_numeric_ret(cols, expected, agg_func): + # GH#19200 and GH#27469 df = DataFrame( {"a": date_range("2018-01-01", periods=3), "b": range(3), "c": range(7, 10)} ) - - warn = FutureWarning - if isinstance(exp, Series) or agg_func != "size": - warn = None - with tm.assert_produces_warning(warn, match="Dropping invalid columns"): - result = df.groupby("b")[cols].transform(agg_func) + result = df.groupby("b")[cols].transform(agg_func) if agg_func == "rank": - exp = exp.astype("float") - - comp_func(result, exp) + expected = expected.astype("float") + elif agg_func == "size" and cols == ["a", "c"]: + # transform("size") returns a Series + expected = expected["a"].rename(None) + tm.assert_equal(result, expected) def test_transform_ffill(): @@ -1131,27 +1121,19 @@ def test_transform_agg_by_name(request, reduction_func, obj): request.node.add_marker( 
pytest.mark.xfail(reason="TODO: g.transform('ngroup') doesn't work") ) - if func == "size" and obj.ndim == 2: # GH#27469 - request.node.add_marker( - pytest.mark.xfail(reason="TODO: g.transform('size') doesn't work") - ) if func == "corrwith" and isinstance(obj, Series): # GH#32293 request.node.add_marker( pytest.mark.xfail(reason="TODO: implement SeriesGroupBy.corrwith") ) args = {"nth": [0], "quantile": [0.5], "corrwith": [obj]}.get(func, []) - - warn = None - if isinstance(obj, DataFrame) and func == "size": - warn = FutureWarning - - with tm.assert_produces_warning(warn, match="Dropping invalid columns"): - result = g.transform(func, *args) + result = g.transform(func, *args) # this is the *definition* of a transformation tm.assert_index_equal(result.index, obj.index) - if hasattr(obj, "columns"): + + if func != "size" and obj.ndim == 2: + # size returns a Series, unlike other transforms tm.assert_index_equal(result.columns, obj.columns) # verify that values were broadcasted across each group From 3933814f679a5fe15b516ec81501d771621a7732 Mon Sep 17 00:00:00 2001 From: Richard Shadrach Date: Sun, 30 Jan 2022 11:15:02 -0500 Subject: [PATCH 2/2] whatsnew --- doc/source/whatsnew/v1.5.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v1.5.0.rst b/doc/source/whatsnew/v1.5.0.rst index 1d4054d5ea0f1..d558ebb673ad8 100644 --- a/doc/source/whatsnew/v1.5.0.rst +++ b/doc/source/whatsnew/v1.5.0.rst @@ -306,7 +306,7 @@ Plotting Groupby/resample/rolling ^^^^^^^^^^^^^^^^^^^^^^^^ - Bug in :meth:`DataFrame.resample` ignoring ``closed="right"`` on :class:`TimedeltaIndex` (:issue:`45414`) -- +- Bug in :meth:`.DataFrameGroupBy.transform` failing when ``func="size"`` and the input DataFrame has multiple columns (:issue:`27469`) Reshaping ^^^^^^^^^