Skip to content

Enh/group by.transform should accept similar arguments to group by.agg #58773

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Other enhancements
- :meth:`DataFrame.corrwith` now accepts ``min_periods`` as an optional argument, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`)
- :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`)
- :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`)
- :meth:`GroupBy.transform` now accepts list-like arguments and dictionary arguments similar to :meth:`GroupBy.agg`, and supports :class:`NamedAgg` (:issue:`58318`)
- :meth:`Series.cummin` and :meth:`Series.cummax` now support :class:`CategoricalDtype` (:issue:`52335`)
- :meth:`Series.plot` now correctly handles the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
Expand Down
137 changes: 134 additions & 3 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
all_indexes_same,
default_index,
)
from pandas.core.reshape.concat import concat
from pandas.core.series import Series
from pandas.core.sorting import get_group_index
from pandas.core.util.numba_ import maybe_use_numba
Expand Down Expand Up @@ -1863,15 +1864,145 @@ def _transform_general(self, func, engine, engine_kwargs, *args, **kwargs):
3 5 9
4 5 8
5 5 9

List-like arguments

>>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)})
>>> df.groupby("col").transform(["sum", "min"])
val other_val
sum min sum min
0 1 0 1 0
1 1 0 1 0
2 2 2 2 2

.. versionchanged:: 3.0.0

Dictionary arguments

>>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)})
>>> df.groupby("col").transform({"val": "sum", "other_val": "min"})
val other_val
0 1 0
1 1 0
2 2 2

.. versionchanged:: 3.0.0

Named aggregation

>>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)})
>>> df.groupby("col").transform(
... val_sum=pd.NamedAgg(column="val", aggfunc="sum"),
... other_min=pd.NamedAgg(column="other_val", aggfunc="min")
... )
val_sum other_min
0 1 0
1 1 0
2 2 2

.. versionchanged:: 3.0.0
"""
)

@Substitution(klass="DataFrame", example=__examples_dataframe_doc)
@Appender(_transform_template)
def transform(
    self,
    func: None
    | (Callable | str | list[Callable | str] | dict[str, NamedAgg]) = None,
    *args,
    engine: str | None = None,
    engine_kwargs: dict | None = None,
    **kwargs,
) -> DataFrame:
    # NOTE: deliberately no docstring in the body -- the Substitution/Appender
    # decorators above are expected to assemble it from ``_transform_template``;
    # a literal docstring here would presumably be merged into that text.
    #
    # Dispatch on the kind of ``func``:
    #   * None  -> keyword arguments are the named specs
    #              (``out_label=NamedAgg(column=..., aggfunc=...)``)
    #   * dict  -> one spec per entry
    #   * list  -> every listed function is applied to every column
    #   * other -> the single-function path (``self._transform``)
    if func is None:
        if not kwargs:
            # Without this guard an empty spec dict reaches ``concat`` and
            # fails there with a message unrelated to the actual mistake.
            raise TypeError("Must provide 'func' or tuples of 'kwargs'.")
        # The keywords are consumed as specs here, so they are not forwarded
        # a second time as ``**kwargs``.
        return self._transform_multiple_funcs(
            dict(kwargs), *args, engine=engine, engine_kwargs=engine_kwargs
        )

    if isinstance(func, list):
        func = maybe_mangle_lambdas(func)

    if isinstance(func, (dict, list)):
        return self._transform_multiple_funcs(
            func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
        )

    return self._transform(
        func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
    )

def _transform_multiple_funcs(
    self,
    func: Any,
    *args,
    engine: str | None = None,
    engine_kwargs: dict | None = None,
    **kwargs,
) -> DataFrame:
    """
    Apply several transformations and combine the results column-wise.

    Parameters
    ----------
    func : dict or list
        * dict: maps an output label to either a ``NamedAgg`` (whose
          ``column`` and ``aggfunc`` are used) or directly to a function,
          in which case the key doubles as the source column.
        * list: every function in the list is applied to every column of
          ``self._obj_with_exclusions``; the result has two-level columns
          ``(column, function label)``.
    *args, **kwargs
        Forwarded unchanged to ``self._transform_single_column``.
    engine : str or None, default None
        Forwarded unchanged to ``self._transform_single_column``.
    engine_kwargs : dict or None, default None
        Forwarded unchanged to ``self._transform_single_column``.

    Returns
    -------
    DataFrame
        The per-column results concatenated along ``axis=1``.

    Raises
    ------
    TypeError
        If ``func`` is neither a dict nor a list.
    """
    if isinstance(func, dict):
        results = []
        for name, spec in func.items():
            if isinstance(spec, NamedAgg):
                column_name, agg_func = spec.column, spec.aggfunc
            else:
                column_name, agg_func = name, spec
            result = self._transform_single_column(
                column_name,
                agg_func,
                *args,
                engine=engine,
                engine_kwargs=engine_kwargs,
                **kwargs,
            )
            # The dict key is always the output label, even when the same
            # source column is used by several entries.
            result.name = name
            results.append(result)
        return concat(results, axis=1)

    if isinstance(func, list):
        results = []
        labels = []
        # Iterating ``_obj_with_exclusions`` (as requested in review) yields
        # only the non-grouping columns, so no manual filtering against
        # ``self.keys`` is needed.
        for column in self._obj_with_exclusions:
            for agg_func in func:
                # Label callables by name rather than by the function
                # object itself; string aliases are used as-is.
                func_label = (
                    getattr(agg_func, "__name__", agg_func)
                    if callable(agg_func)
                    else agg_func
                )
                results.append(
                    self._transform_single_column(
                        column,
                        agg_func,
                        *args,
                        engine=engine,
                        engine_kwargs=engine_kwargs,
                        **kwargs,
                    )
                )
                labels.append((column, func_label))
        # Column names are discarded by ``ignore_index=True`` and replaced
        # by the two-level labels collected above.
        output = concat(results, ignore_index=True, axis=1)
        output.columns = MultiIndex.from_arrays(
            [list(level) for level in zip(*labels)]
        )
        return output

    raise TypeError(
        f"func must be a dict or a list, got {type(func).__name__}"
    )

def _transform_single_column(
    self,
    column_name: Hashable,
    agg_func: Callable | str,
    *args,
    engine: str | None = None,
    engine_kwargs: dict | None = None,
    **kwargs,
) -> Series:
    """
    Transform a single column of the grouped object.

    The column is selected with ``self._gotitem(column_name, ndim=1)`` and
    ``agg_func`` plus all remaining arguments are handed to that selection's
    ``transform``.

    Parameters
    ----------
    column_name : Hashable
        Label of the column to select.
    agg_func : callable or str
        Function (or function name) passed to ``transform``.
    *args, **kwargs
        Passed through to ``transform``.
    engine : str or None, default None
        Passed through to ``transform``.
    engine_kwargs : dict or None, default None
        Passed through to ``transform``.

    Returns
    -------
    Series
        Whatever the selection's ``transform`` returns.
    """
    # NOTE(review): a reviewer flagged that this selection fails when the
    # input has duplicate column names. Whether to support duplicates or to
    # raise a clear error was left open -- TODO confirm with maintainers.
    return self._gotitem(column_name, ndim=1).transform(
        agg_func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
    )

def _define_paths(self, func, *args, **kwargs):
if isinstance(func, str):
Expand Down
55 changes: 55 additions & 0 deletions pandas/tests/groupby/transform/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
date_range,
)
import pandas._testing as tm
from pandas.core.groupby import NamedAgg
from pandas.tests.groupby import get_groupby_method_args


Expand Down Expand Up @@ -84,6 +85,60 @@ def demean(arr):
tm.assert_frame_equal(result, expected)


def test_transform_with_list_like():
    # GH#58318
    # A list of functions is applied to every non-grouping column and the
    # result carries (column, function) two-level column labels.
    frame = DataFrame({"col": list("aab"), "val": range(3), "another": range(3)})
    funcs = ["sum", "min"]

    expected_columns = MultiIndex.from_tuples(
        [("val", "sum"), ("val", "min"), ("another", "sum"), ("another", "min")]
    )
    expected = DataFrame(
        [
            [1, 0, 1, 0],
            [1, 0, 1, 0],
            [2, 2, 2, 2],
        ],
        columns=expected_columns,
    )

    result = frame.groupby("col").transform(funcs)
    tm.assert_frame_equal(result, expected)


def test_transform_with_dict():
    # GH#58318
    # A dict maps each column to the function used to transform it.
    frame = DataFrame({"col": list("aab"), "val": range(3), "another": range(3)})
    func_by_column = {"val": "sum", "another": "min"}
    expected = DataFrame({"val": [1, 1, 2], "another": [0, 0, 2]})

    result = frame.groupby("col").transform(func_by_column)
    tm.assert_frame_equal(result, expected)


def test_transform_with_namedagg():
    # GH#58318
    # Keyword arguments of NamedAgg select the source column and function;
    # the keyword itself becomes the output column label.
    frame = DataFrame({"A": list("aaabbbccc"), "B": range(9), "D": range(9, 18)})
    named_specs = {
        "b_min": NamedAgg(column="B", aggfunc="min"),
        "d_sum": NamedAgg(column="D", aggfunc="sum"),
    }
    expected = DataFrame(
        {
            "b_min": [0, 0, 0, 3, 3, 3, 6, 6, 6],
            "d_sum": [30, 30, 30, 39, 39, 39, 48, 48, 48],
        }
    )

    result = frame.groupby("A").transform(**named_specs)
    tm.assert_frame_equal(result, expected)


def test_transform_with_duplicate_columns():
    # GH#58318
    # The same source column ("B") is referenced by two NamedAgg specs; each
    # spec must still produce its own, separately labelled output column.
    frame = DataFrame({"A": list("aaabbbccc"), "B": range(9, 18)})
    named_specs = {
        "b_min": NamedAgg(column="B", aggfunc="min"),
        "b_max": NamedAgg(column="B", aggfunc="max"),
    }
    expected = DataFrame(
        {
            "b_min": [9, 9, 9, 12, 12, 12, 15, 15, 15],
            "b_max": [11, 11, 11, 14, 14, 14, 17, 17, 17],
        }
    )

    result = frame.groupby("A").transform(**named_specs)
    tm.assert_frame_equal(result, expected)


def test_transform_fast():
df = DataFrame(
{
Expand Down
Loading