From ec778ce95056dd181311236db8274785df7f26be Mon Sep 17 00:00:00 2001 From: 121238257 Date: Sat, 18 May 2024 17:08:05 +0100 Subject: [PATCH 1/7] Add Enhancement to latest vX.X.X.rst --- doc/source/whatsnew/v3.0.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index a15da861cfbec..fec16231c41db 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -42,6 +42,7 @@ Other enhancements - :meth:`DataFrame.corrwith` now accepts ``min_periods`` as optional arguments, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`) - :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`) - :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`) +- :meth:`GroupBy.transform` now accepts list-like arguments similar to :meth:`GroupBy.agg`, and supports :class:`NamedAgg` (:issue:`58318`) - :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`) - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`) - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`) From e39d8a4bc19af759ed9b45c7153cf6b36b558ba9 Mon Sep 17 00:00:00 2001 From: 121238257 Date: Sat, 18 May 2024 17:15:36 +0100 Subject: [PATCH 2/7] Implementation --- pandas/core/groupby/generic.py | 118 ++++++++++++++++++++++++++++++++- 1 file changed, 115 insertions(+), 3 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index a20577e8d3df9..2750372e26e4d 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -75,6 +75,7 @@ all_indexes_same, default_index, ) +from pandas.core.reshape.concat import concat from pandas.core.series import Series from pandas.core.sorting import get_group_index from pandas.core.util.numba_ import maybe_use_numba @@ -1863,15 +1864,126 @@ def _transform_general(self, func, engine, engine_kwargs, *args, **kwargs): 3 5 9 4 5 8 5 5 9 + + Using list-like arguments + + >>> df = pd.DataFrame({"col": list("aab"), "val": range(3)}) + >>> df.groupby("col").transform(["sum", "min"]) + val + sum min + 0 1 0 + 1 1 0 + 2 2 2 + + .. versionchanged:: 3.0.0 + + Named aggregation + + >>> df = pd.DataFrame({"A": list("aaabbbccc"), "B": range(9), "D": range(9, 18)}) + >>> df.groupby("A").transform( + ... b_min=pd.NamedAgg(column="B", aggfunc="min"), + ... c_sum=pd.NamedAgg(column="D", aggfunc="sum") + ... ) + b_min c_sum + 0 0 30 + 1 0 30 + 2 0 30 + 3 3 39 + 4 3 39 + 5 3 39 + 6 6 48 + 7 6 48 + 8 6 48 + + .. versionchanged:: 3.0.0 """ ) @Substitution(klass="DataFrame", example=__examples_dataframe_doc) @Appender(_transform_template) - def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): - return self._transform( - func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + def transform( + self, + func: None + | (Callable | str | list[Callable | str] | dict[str, NamedAgg]) = None, + *args, + engine: str | None = None, + engine_kwargs: dict | None = None, + **kwargs, + ) -> DataFrame: + if func is None: + transformed_func = dict(kwargs.items()) + return self._transform_multiple_funcs( + transformed_func, *args, engine=engine, engine_kwargs=engine_kwargs + ) + else: + if isinstance(func, list): + func = maybe_mangle_lambdas(func) + return self._transform_multiple_funcs( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) + else: + return self._transform( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) + + def _transform_multiple_funcs( + self, + func: Any, + *args, + engine: str | None = None, + engine_kwargs: dict | None = None, + **kwargs, + ) -> DataFrame: + results = [] + if isinstance(func, dict): + for name, named_agg in func.items(): + column_name = named_agg.column + agg_func = named_agg.aggfunc + result = self._transform_single_column( + column_name, + agg_func, + *args, + engine=engine, + engine_kwargs=engine_kwargs, + **kwargs, + ) + result.name = name + results.append(result) + output = concat(results, axis=1) + elif isinstance(func, list): + col_names = [] + columns = [com.get_callable_name(f) or f for f in func] + func_pairs = zip(columns, func) + for name, func_item in func_pairs: + result = self._transform( + func_item, + *args, + engine=engine, + engine_kwargs=engine_kwargs, + **kwargs, + ) + results.append(result) + col_names.extend([(col, name) for col in result.columns]) + output = concat(results, ignore_index=True, axis=1) + arrays = [list(x) for x in zip(*col_names)] + output.columns = MultiIndex.from_arrays(arrays) + + return output + + def _transform_single_column( + self, + column_name: Hashable, + agg_func: Callable | str, + *args, + engine: str | None = None, + engine_kwargs: dict | None = None, + **kwargs, + ) -> Series: + data = self._gotitem(column_name, ndim=1) + result = data.transform( + agg_func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs ) + return result def _define_paths(self, func, *args, **kwargs): if isinstance(func, str): From a1b1d3d9f3f52f5e2ea9dca94d59f6899a7e1dec Mon Sep 17 00:00:00 2001 From: 121238257 Date: Sat, 18 May 2024 17:19:30 +0100 Subject: [PATCH 3/7] Tests --- .../tests/groupby/transform/test_transform.py | 39 +++++++++++++++++++ try.py | 7 ++++ 2 files changed, 46 insertions(+) create mode 100644 try.py diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index d6d545a8c4834..a50de932d40cf 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -19,6 +19,7 @@ date_range, ) import pandas._testing as tm +from pandas.core.groupby import NamedAgg from pandas.tests.groupby import get_groupby_method_args @@ -84,6 +85,44 @@ def demean(arr): tm.assert_frame_equal(result, expected) +def test_transform_with_namedagg(): + df = DataFrame({"A": list("aaabbbccc"), "B": range(9), "D": range(9, 18)}) + result = df.groupby("A").transform( + b_min=NamedAgg(column="B", aggfunc="min"), + d_sum=NamedAgg(column="D", aggfunc="sum"), + ) + expected = DataFrame( + { + "b_min": [0, 0, 0, 3, 3, 3, 6, 6, 6], + "d_sum": [30, 30, 30, 39, 39, 39, 48, 48, 48], + } + ) + tm.assert_frame_equal(result, expected) + + +def test_transform_with_list_like(): + df = DataFrame({"col": list("aab"), "val": range(3)}) + result = df.groupby("col").transform(["sum", "min"]) + expected = DataFrame({"val_sum": [1, 1, 2], "val_min": [0, 0, 2]}) + expected.columns = MultiIndex.from_tuples([("val", "sum"), ("val", "min")]) + tm.assert_frame_equal(result, expected) + + +def test_transform_with_duplicate_columns(): + df = DataFrame({"A": list("aaabbbccc"), "B": range(9, 18)}) + result = df.groupby("A").transform( + b_min=NamedAgg(column="B", aggfunc="min"), + b_max=NamedAgg(column="B", aggfunc="max"), + ) + expected = DataFrame( + { + "b_min": [9, 9, 9, 12, 12, 12, 15, 15, 15], + "b_max": [11, 11, 11, 14, 14, 14, 17, 17, 17], + } + ) + tm.assert_frame_equal(result, expected) + + def test_transform_fast(): df = DataFrame( { diff --git a/try.py b/try.py new file mode 100644 index 0000000000000..8397fe1d7ff31 --- /dev/null +++ b/try.py @@ -0,0 +1,7 @@ +def process_user(name: str, age: int) -> str: + return f"{name} is {age} years old." + + +print( + process_user("John", "twenty") +) # mypy will not catch this type error, but Python will run it without complaints From 78b6e90363b6723cfae95eef0c882c2203e67dd2 Mon Sep 17 00:00:00 2001 From: 121238257 Date: Thu, 23 May 2024 20:43:49 +0100 Subject: [PATCH 4/7] add dictionary option and fix list issue --- doc/source/whatsnew/v3.0.0.rst | 2 +- pandas/core/groupby/generic.py | 116 ++++++++++-------- .../tests/groupby/transform/test_transform.py | 32 +++-- 3 files changed, 92 insertions(+), 58 deletions(-) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index fec16231c41db..a31e6868f53ac 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -42,7 +42,7 @@ Other enhancements - :meth:`DataFrame.corrwith` now accepts ``min_periods`` as optional arguments, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`) - :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`) - :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`) -- :meth:`GroupBy.transform` now accepts list-like arguments similar to :meth:`GroupBy.agg`, and supports :class:`NamedAgg` (:issue:`58318`) +- :meth:`GroupBy.transform` now accepts list-like arguments and dictionary arguments similar to :meth:`GroupBy.agg`, and supports :class:`NamedAgg` (:issue:`58318`) - :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`) - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`) - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 2750372e26e4d..b8a6fea42ec84 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1865,35 +1865,40 @@ def _transform_general(self, func, engine, engine_kwargs, *args, **kwargs): 4 5 8 5 5 9 - Using list-like arguments + List-like arguments - >>> df = pd.DataFrame({"col": list("aab"), "val": range(3)}) + >>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)}) >>> df.groupby("col").transform(["sum", "min"]) - val - sum min - 0 1 0 - 1 1 0 - 2 2 2 + val other_val + sum min sum min + 0 1 0 1 0 + 1 1 0 1 0 + 2 2 2 2 2 + + .. versionchanged:: 3.0.0 + + Dictionary arguments + + >>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)}) + >>> df.groupby("col").transform({"val": "sum", "other_val": "min"}) + val other_val + 0 1 0 + 1 1 0 + 2 2 2 .. versionchanged:: 3.0.0 Named aggregation - >>> df = pd.DataFrame({"A": list("aaabbbccc"), "B": range(9), "D": range(9, 18)}) - >>> df.groupby("A").transform( - ... b_min=pd.NamedAgg(column="B", aggfunc="min"), - ... c_sum=pd.NamedAgg(column="D", aggfunc="sum") + >>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)}) + >>> df.groupby("col").transform( + ... val_sum=pd.NamedAgg(column="val", aggfunc="sum"), + ... other_min=pd.NamedAgg(column="other_val", aggfunc="min") ... ) - b_min c_sum - 0 0 30 - 1 0 30 - 2 0 30 - 3 3 39 - 4 3 39 - 5 3 39 - 6 6 48 - 7 6 48 - 8 6 48 + val_sum other_min + 0 1 0 + 1 1 0 + 2 2 2 .. versionchanged:: 3.0.0 """ @@ -1915,16 +1920,19 @@ def transform( return self._transform_multiple_funcs( transformed_func, *args, engine=engine, engine_kwargs=engine_kwargs ) + elif isinstance(func, dict): + return self._transform_multiple_funcs( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) + elif isinstance(func, list): + func = maybe_mangle_lambdas(func) + return self._transform_multiple_funcs( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) else: - if isinstance(func, list): - func = maybe_mangle_lambdas(func) - return self._transform_multiple_funcs( - func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs - ) - else: - return self._transform( - func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs - ) + return self._transform( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) def _transform_multiple_funcs( self, @@ -1934,11 +1942,15 @@ def _transform_multiple_funcs( engine_kwargs: dict | None = None, **kwargs, ) -> DataFrame: - results = [] if isinstance(func, dict): - for name, named_agg in func.items(): - column_name = named_agg.column - agg_func = named_agg.aggfunc + results = [] + for name, agg in func.items(): + if isinstance(agg, NamedAgg): + column_name = agg.column + agg_func = agg.aggfunc + else: + column_name = name + agg_func = agg result = self._transform_single_column( column_name, agg_func, @@ -1951,21 +1963,27 @@ def _transform_multiple_funcs( results.append(result) output = concat(results, axis=1) elif isinstance(func, list): - col_names = [] - columns = [com.get_callable_name(f) or f for f in func] - func_pairs = zip(columns, func) - for name, func_item in func_pairs: - result = self._transform( - func_item, - *args, - engine=engine, - engine_kwargs=engine_kwargs, - **kwargs, - ) - results.append(result) - col_names.extend([(col, name) for col in result.columns]) - output = concat(results, ignore_index=True, axis=1) - arrays = [list(x) for x in zip(*col_names)] + results = [] + col_order = [] + for column in self.obj.columns: + if column in self.keys: + continue + column_results = [ + self._transform_single_column( + column, + agg_func, + *args, + engine=engine, + engine_kwargs=engine_kwargs, + **kwargs, + ).rename((column, agg_func)) + for agg_func in func + ] + combined_result = concat(column_results, axis=1) + results.append(combined_result) + col_order.extend([(column, f) for f in func]) + output = concat(results, axis=1) + arrays = [list(x) for x in zip(*col_order)] output.columns = MultiIndex.from_arrays(arrays) return output diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index a50de932d40cf..6dbfc0fd94e2d 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -85,6 +85,30 @@ def demean(arr): tm.assert_frame_equal(result, expected) +def test_transform_with_list_like(): + df = DataFrame({"col": list("aab"), "val": range(3), "another": range(3)}) + result = df.groupby("col").transform(["sum", "min"]) + expected = DataFrame( + { + ("val", "sum"): [1, 1, 2], + ("val", "min"): [0, 0, 2], + ("another", "sum"): [1, 1, 2], + ("another", "min"): [0, 0, 2], + } + ) + expected.columns = MultiIndex.from_tuples( + [("val", "sum"), ("val", "min"), ("another", "sum"), ("another", "min")] + ) + tm.assert_frame_equal(result, expected) + + +def test_transform_with_dict(): + df = DataFrame({"col": list("aab"), "val": range(3), "another": range(3)}) + result = df.groupby("col").transform({"val": "sum", "another": "min"}) + expected = DataFrame({"val": [1, 1, 2], "another": [0, 0, 2]}) + tm.assert_frame_equal(result, expected) + + def test_transform_with_namedagg(): df = DataFrame({"A": list("aaabbbccc"), "B": range(9), "D": range(9, 18)}) result = df.groupby("A").transform( @@ -100,14 +124,6 @@ def test_transform_with_namedagg(): tm.assert_frame_equal(result, expected) -def test_transform_with_list_like(): - df = DataFrame({"col": list("aab"), "val": range(3)}) - result = df.groupby("col").transform(["sum", "min"]) - expected = DataFrame({"val_sum": [1, 1, 2], "val_min": [0, 0, 2]}) - expected.columns = MultiIndex.from_tuples([("val", "sum"), ("val", "min")]) - tm.assert_frame_equal(result, expected) - - def test_transform_with_duplicate_columns(): df = DataFrame({"A": list("aaabbbccc"), "B": range(9, 18)}) result = df.groupby("A").transform( From f90a3e5b0d5727f5640fadc89fbfb5bfafbc7d56 Mon Sep 17 00:00:00 2001 From: 121238257 Date: Thu, 23 May 2024 21:49:29 +0100 Subject: [PATCH 5/7] fix mypy --- pandas/core/groupby/generic.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index b8a6fea42ec84..df5014c880eaa 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1965,8 +1965,9 @@ def _transform_multiple_funcs( elif isinstance(func, list): results = [] col_order = [] + keys_list = list(self.keys) if isinstance(self.keys, list) else [self.keys] for column in self.obj.columns: - if column in self.keys: + if column in keys_list: continue column_results = [ self._transform_single_column( @@ -1979,9 +1980,9 @@ def _transform_multiple_funcs( ).rename((column, agg_func)) for agg_func in func ] - combined_result = concat(column_results, axis=1) - results.append(combined_result) - col_order.extend([(column, f) for f in func]) + for col_result in column_results: + results.append(col_result) + col_order.append(col_result.name) output = concat(results, axis=1) arrays = [list(x) for x in zip(*col_order)] output.columns = MultiIndex.from_arrays(arrays) From 324f38c518b0d2e28d760a97fd903428bbf9292e Mon Sep 17 00:00:00 2001 From: 121238257 Date: Thu, 23 May 2024 22:45:57 +0100 Subject: [PATCH 6/7] ignore_index --- pandas/core/groupby/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index df5014c880eaa..45e3640e7541b 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1983,7 +1983,7 @@ def _transform_multiple_funcs( for col_result in column_results: results.append(col_result) col_order.append(col_result.name) - output = concat(results, axis=1) + output = concat(results, ignore_index=True, axis=1) arrays = [list(x) for x in zip(*col_order)] output.columns = MultiIndex.from_arrays(arrays) From 888e2ea8ddacfb8bb745426fc24d6cb54fc37186 Mon Sep 17 00:00:00 2001 From: 121238257 Date: Fri, 24 May 2024 08:33:05 +0100 Subject: [PATCH 7/7] remove accidental file --- try.py | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 try.py diff --git a/try.py b/try.py deleted file mode 100644 index 8397fe1d7ff31..0000000000000 --- a/try.py +++ /dev/null @@ -1,7 +0,0 @@ -def process_user(name: str, age: int) -> str: - return f"{name} is {age} years old." - - -print( - process_user("John", "twenty") -) # mypy will not catch this type error, but Python will run it without complaints