-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
Enh/group by.transform should accept similar arguments to group by.agg #58773
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ec778ce
e39d8a4
a1b1d3d
78b6e90
f90a3e5
324f38c
888e2ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,6 +75,7 @@ | |
all_indexes_same, | ||
default_index, | ||
) | ||
from pandas.core.reshape.concat import concat | ||
from pandas.core.series import Series | ||
from pandas.core.sorting import get_group_index | ||
from pandas.core.util.numba_ import maybe_use_numba | ||
|
@@ -1863,15 +1864,145 @@ def _transform_general(self, func, engine, engine_kwargs, *args, **kwargs): | |
3 5 9 | ||
4 5 8 | ||
5 5 9 | ||
|
||
List-like arguments | ||
|
||
>>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)}) | ||
>>> df.groupby("col").transform(["sum", "min"]) | ||
val other_val | ||
sum min sum min | ||
0 1 0 1 0 | ||
1 1 0 1 0 | ||
2 2 2 2 2 | ||
|
||
.. versionchanged:: 3.0.0 | ||
|
||
Dictionary arguments | ||
|
||
>>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)}) | ||
>>> df.groupby("col").transform({"val": "sum", "other_val": "min"}) | ||
val other_val | ||
0 1 0 | ||
1 1 0 | ||
2 2 2 | ||
|
||
.. versionchanged:: 3.0.0 | ||
|
||
Named aggregation | ||
|
||
>>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)}) | ||
>>> df.groupby("col").transform( | ||
... val_sum=pd.NamedAgg(column="val", aggfunc="sum"), | ||
... other_min=pd.NamedAgg(column="other_val", aggfunc="min") | ||
... ) | ||
val_sum other_min | ||
0 1 0 | ||
1 1 0 | ||
2 2 2 | ||
|
||
.. versionchanged:: 3.0.0 | ||
""" | ||
) | ||
|
||
@Substitution(klass="DataFrame", example=__examples_dataframe_doc) | ||
@Appender(_transform_template) | ||
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): | ||
return self._transform( | ||
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs | ||
def transform( | ||
self, | ||
func: None | ||
| (Callable | str | list[Callable | str] | dict[str, NamedAgg]) = None, | ||
*args, | ||
engine: str | None = None, | ||
engine_kwargs: dict | None = None, | ||
**kwargs, | ||
) -> DataFrame: | ||
if func is None: | ||
transformed_func = dict(kwargs.items()) | ||
return self._transform_multiple_funcs( | ||
transformed_func, *args, engine=engine, engine_kwargs=engine_kwargs | ||
) | ||
elif isinstance(func, dict): | ||
return self._transform_multiple_funcs( | ||
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs | ||
) | ||
elif isinstance(func, list): | ||
func = maybe_mangle_lambdas(func) | ||
return self._transform_multiple_funcs( | ||
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs | ||
) | ||
else: | ||
return self._transform( | ||
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs | ||
) | ||
|
||
def _transform_multiple_funcs( | ||
self, | ||
func: Any, | ||
*args, | ||
engine: str | None = None, | ||
engine_kwargs: dict | None = None, | ||
**kwargs, | ||
) -> DataFrame: | ||
if isinstance(func, dict): | ||
results = [] | ||
for name, agg in func.items(): | ||
if isinstance(agg, NamedAgg): | ||
column_name = agg.column | ||
agg_func = agg.aggfunc | ||
else: | ||
column_name = name | ||
agg_func = agg | ||
result = self._transform_single_column( | ||
column_name, | ||
agg_func, | ||
*args, | ||
engine=engine, | ||
engine_kwargs=engine_kwargs, | ||
**kwargs, | ||
) | ||
result.name = name | ||
results.append(result) | ||
output = concat(results, axis=1) | ||
elif isinstance(func, list): | ||
results = [] | ||
col_order = [] | ||
keys_list = list(self.keys) if isinstance(self.keys, list) else [self.keys] | ||
for column in self.obj.columns: | ||
if column in keys_list: | ||
continue | ||
column_results = [ | ||
self._transform_single_column( | ||
column, | ||
agg_func, | ||
*args, | ||
engine=engine, | ||
engine_kwargs=engine_kwargs, | ||
**kwargs, | ||
).rename((column, agg_func)) | ||
for agg_func in func | ||
] | ||
for col_result in column_results: | ||
results.append(col_result) | ||
col_order.append(col_result.name) | ||
output = concat(results, ignore_index=True, axis=1) | ||
arrays = [list(x) for x in zip(*col_order)] | ||
output.columns = MultiIndex.from_arrays(arrays) | ||
|
||
return output | ||
|
||
def _transform_single_column( | ||
self, | ||
column_name: Hashable, | ||
agg_func: Callable | str, | ||
*args, | ||
engine: str | None = None, | ||
engine_kwargs: dict | None = None, | ||
**kwargs, | ||
) -> Series: | ||
data = self._gotitem(column_name, ndim=1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will fail if the input has duplicate column names. I would be okay with not supporting duplicate columns here (and raising a clear error message), but this would need further support from other maintainers. The other option is to make this work on duplicate column names. |
||
result = data.transform( | ||
agg_func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs | ||
) | ||
return result | ||
|
||
def _define_paths(self, func, *args, **kwargs): | ||
if isinstance(func, str): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
date_range, | ||
) | ||
import pandas._testing as tm | ||
from pandas.core.groupby import NamedAgg | ||
from pandas.tests.groupby import get_groupby_method_args | ||
|
||
|
||
|
@@ -84,6 +85,60 @@ def demean(arr): | |
tm.assert_frame_equal(result, expected) | ||
|
||
|
||
def test_transform_with_list_like(): | ||
df = DataFrame({"col": list("aab"), "val": range(3), "another": range(3)}) | ||
Comment on lines
+88
to
+89
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add references to each test, e.g. def test_transform_with_list_like():
# GH#58318 |
||
result = df.groupby("col").transform(["sum", "min"]) | ||
expected = DataFrame( | ||
{ | ||
("val", "sum"): [1, 1, 2], | ||
("val", "min"): [0, 0, 2], | ||
("another", "sum"): [1, 1, 2], | ||
("another", "min"): [0, 0, 2], | ||
} | ||
) | ||
expected.columns = MultiIndex.from_tuples( | ||
[("val", "sum"), ("val", "min"), ("another", "sum"), ("another", "min")] | ||
) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
|
||
def test_transform_with_dict(): | ||
df = DataFrame({"col": list("aab"), "val": range(3), "another": range(3)}) | ||
result = df.groupby("col").transform({"val": "sum", "another": "min"}) | ||
expected = DataFrame({"val": [1, 1, 2], "another": [0, 0, 2]}) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
|
||
def test_transform_with_namedagg(): | ||
df = DataFrame({"A": list("aaabbbccc"), "B": range(9), "D": range(9, 18)}) | ||
result = df.groupby("A").transform( | ||
b_min=NamedAgg(column="B", aggfunc="min"), | ||
d_sum=NamedAgg(column="D", aggfunc="sum"), | ||
) | ||
expected = DataFrame( | ||
{ | ||
"b_min": [0, 0, 0, 3, 3, 3, 6, 6, 6], | ||
"d_sum": [30, 30, 30, 39, 39, 39, 48, 48, 48], | ||
} | ||
) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
|
||
def test_transform_with_duplicate_columns(): | ||
df = DataFrame({"A": list("aaabbbccc"), "B": range(9, 18)}) | ||
result = df.groupby("A").transform( | ||
b_min=NamedAgg(column="B", aggfunc="min"), | ||
b_max=NamedAgg(column="B", aggfunc="max"), | ||
) | ||
expected = DataFrame( | ||
{ | ||
"b_min": [9, 9, 9, 12, 12, 12, 15, 15, 15], | ||
"b_max": [11, 11, 11, 14, 14, 14, 17, 17, 17], | ||
} | ||
) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
|
||
def test_transform_fast(): | ||
df = DataFrame( | ||
{ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Iterate over
self._obj_with_exclusions
instead. Then you don't needkeys_list
.