Commit 78b6e90

add dictionary option and fix list issue
1 parent a1b1d3d commit 78b6e90

File tree: 3 files changed (+92 −58 lines)


doc/source/whatsnew/v3.0.0.rst

+1 −1
@@ -42,7 +42,7 @@ Other enhancements
 - :meth:`DataFrame.corrwith` now accepts ``min_periods`` as optional arguments, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`)
 - :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`)
 - :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`)
-- :meth:`GroupBy.transform` now accepts list-like arguments similar to :meth:`GroupBy.agg`, and supports :class:`NamedAgg` (:issue:`58318`)
+- :meth:`GroupBy.transform` now accepts list-like arguments and dictionary arguments similar to :meth:`GroupBy.agg`, and supports :class:`NamedAgg` (:issue:`58318`)
 - :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
 - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
 - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
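
As a user-facing illustration of the updated entry above, the dictionary form maps each column to the transform applied to it within each group. A minimal sketch, assuming a pandas build that includes this commit (the frame and column names are the same ones used in the docstring example added further down):

>>> import pandas as pd
>>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)})
>>> df.groupby("col").transform({"val": "sum", "other_val": "min"})
   val  other_val
0    1          0
1    1          0
2    2          2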

pandas/core/groupby/generic.py

+67 −49
@@ -1865,35 +1865,40 @@ def _transform_general(self, func, engine, engine_kwargs, *args, **kwargs):
         4  5  8
         5  5  9
 
-        Using list-like arguments
+        List-like arguments
 
-        >>> df = pd.DataFrame({"col": list("aab"), "val": range(3)})
+        >>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)})
         >>> df.groupby("col").transform(["sum", "min"])
-           val
-           sum min
-        0    1   0
-        1    1   0
-        2    2   2
+           val      other_val
+           sum min  sum min
+        0    1   0    1   0
+        1    1   0    1   0
+        2    2   2    2   2
+
+        .. versionchanged:: 3.0.0
+
+        Dictionary arguments
+
+        >>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)})
+        >>> df.groupby("col").transform({"val": "sum", "other_val": "min"})
+           val  other_val
+        0    1          0
+        1    1          0
+        2    2          2
 
         .. versionchanged:: 3.0.0
 
         Named aggregation
 
-        >>> df = pd.DataFrame({"A": list("aaabbbccc"), "B": range(9), "D": range(9, 18)})
-        >>> df.groupby("A").transform(
-        ...     b_min=pd.NamedAgg(column="B", aggfunc="min"),
-        ...     c_sum=pd.NamedAgg(column="D", aggfunc="sum")
+        >>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)})
+        >>> df.groupby("col").transform(
+        ...     val_sum=pd.NamedAgg(column="val", aggfunc="sum"),
+        ...     other_min=pd.NamedAgg(column="other_val", aggfunc="min")
         ... )
-           b_min  c_sum
-        0      0     30
-        1      0     30
-        2      0     30
-        3      3     39
-        4      3     39
-        5      3     39
-        6      6     48
-        7      6     48
-        8      6     48
+           val_sum  other_min
+        0        1          0
+        1        1          0
+        2        2          2
 
         .. versionchanged:: 3.0.0
         """
@@ -1915,16 +1920,19 @@ def transform(
             return self._transform_multiple_funcs(
                 transformed_func, *args, engine=engine, engine_kwargs=engine_kwargs
             )
+        elif isinstance(func, dict):
+            return self._transform_multiple_funcs(
+                func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+            )
+        elif isinstance(func, list):
+            func = maybe_mangle_lambdas(func)
+            return self._transform_multiple_funcs(
+                func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+            )
         else:
-            if isinstance(func, list):
-                func = maybe_mangle_lambdas(func)
-                return self._transform_multiple_funcs(
-                    func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
-                )
-            else:
-                return self._transform(
-                    func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
-                )
+            return self._transform(
+                func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+            )
 
     def _transform_multiple_funcs(
         self,
@@ -1934,11 +1942,15 @@ def _transform_multiple_funcs(
         engine_kwargs: dict | None = None,
         **kwargs,
     ) -> DataFrame:
-        results = []
         if isinstance(func, dict):
-            for name, named_agg in func.items():
-                column_name = named_agg.column
-                agg_func = named_agg.aggfunc
+            results = []
+            for name, agg in func.items():
+                if isinstance(agg, NamedAgg):
+                    column_name = agg.column
+                    agg_func = agg.aggfunc
+                else:
+                    column_name = name
+                    agg_func = agg
                 result = self._transform_single_column(
                     column_name,
                     agg_func,
@@ -1951,21 +1963,27 @@ def _transform_multiple_funcs(
                 results.append(result)
             output = concat(results, axis=1)
         elif isinstance(func, list):
-            col_names = []
-            columns = [com.get_callable_name(f) or f for f in func]
-            func_pairs = zip(columns, func)
-            for name, func_item in func_pairs:
-                result = self._transform(
-                    func_item,
-                    *args,
-                    engine=engine,
-                    engine_kwargs=engine_kwargs,
-                    **kwargs,
-                )
-                results.append(result)
-                col_names.extend([(col, name) for col in result.columns])
-            output = concat(results, ignore_index=True, axis=1)
-            arrays = [list(x) for x in zip(*col_names)]
+            results = []
+            col_order = []
+            for column in self.obj.columns:
+                if column in self.keys:
+                    continue
+                column_results = [
+                    self._transform_single_column(
+                        column,
+                        agg_func,
+                        *args,
+                        engine=engine,
+                        engine_kwargs=engine_kwargs,
+                        **kwargs,
+                    ).rename((column, agg_func))
+                    for agg_func in func
+                ]
+                combined_result = concat(column_results, axis=1)
+                results.append(combined_result)
+                col_order.extend([(column, f) for f in func])
+            output = concat(results, axis=1)
+            arrays = [list(x) for x in zip(*col_order)]
             output.columns = MultiIndex.from_arrays(arrays)
 
         return output
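
The rewritten list branch above walks the non-grouping columns of the original frame and applies every requested function to each column, so the (column, function) pairs in the resulting MultiIndex line up with the concatenated values; the previous version transformed the whole frame once per function and stitched the pieces back together with ignore_index=True concatenation. A rough standalone sketch of the intended column layout, using only public pandas API (the helper name build_multi_transform is invented for illustration and is not part of pandas):

import pandas as pd

def build_multi_transform(df, by, funcs):
    # Mimic the new list branch: for every non-key column, apply each
    # transform and label the result with a (column, func) tuple.
    gb = df.groupby(by)
    pieces = {}
    for column in df.columns:
        if column == by:
            continue
        for func in funcs:
            pieces[(column, func)] = gb[column].transform(func)
    out = pd.DataFrame(pieces)
    out.columns = pd.MultiIndex.from_tuples(out.columns)
    return out

df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)})
print(build_multi_transform(df, "col", ["sum", "min"]))

Grouping the output by source column first, rather than by function, is exactly the layout the relocated test below asserts with its four-tuple MultiIndex.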

pandas/tests/groupby/transform/test_transform.py

+24 −8
@@ -85,6 +85,30 @@ def demean(arr):
     tm.assert_frame_equal(result, expected)
 
 
+def test_transform_with_list_like():
+    df = DataFrame({"col": list("aab"), "val": range(3), "another": range(3)})
+    result = df.groupby("col").transform(["sum", "min"])
+    expected = DataFrame(
+        {
+            ("val", "sum"): [1, 1, 2],
+            ("val", "min"): [0, 0, 2],
+            ("another", "sum"): [1, 1, 2],
+            ("another", "min"): [0, 0, 2],
+        }
+    )
+    expected.columns = MultiIndex.from_tuples(
+        [("val", "sum"), ("val", "min"), ("another", "sum"), ("another", "min")]
+    )
+    tm.assert_frame_equal(result, expected)
+
+
+def test_transform_with_dict():
+    df = DataFrame({"col": list("aab"), "val": range(3), "another": range(3)})
+    result = df.groupby("col").transform({"val": "sum", "another": "min"})
+    expected = DataFrame({"val": [1, 1, 2], "another": [0, 0, 2]})
+    tm.assert_frame_equal(result, expected)
+
+
 def test_transform_with_namedagg():
     df = DataFrame({"A": list("aaabbbccc"), "B": range(9), "D": range(9, 18)})
     result = df.groupby("A").transform(
@@ -100,14 +124,6 @@ def test_transform_with_namedagg():
     tm.assert_frame_equal(result, expected)
 
 
-def test_transform_with_list_like():
-    df = DataFrame({"col": list("aab"), "val": range(3)})
-    result = df.groupby("col").transform(["sum", "min"])
-    expected = DataFrame({"val_sum": [1, 1, 2], "val_min": [0, 0, 2]})
-    expected.columns = MultiIndex.from_tuples([("val", "sum"), ("val", "min")])
-    tm.assert_frame_equal(result, expected)
-
-
 def test_transform_with_duplicate_columns():
     df = DataFrame({"A": list("aaabbbccc"), "B": range(9, 18)})
     result = df.groupby("A").transform(
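
Since the list form now returns MultiIndex columns spanning every non-grouping column, as the relocated test_transform_with_list_like above asserts, individual results can be pulled out with (column, function) tuples. A small usage sketch, assuming a build with this commit and reusing the frame from these tests:

>>> import pandas as pd
>>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "another": range(3)})
>>> out = df.groupby("col").transform(["sum", "min"])
>>> out[("val", "sum")]
0    1
1    1
2    2
Name: (val, sum), dtype: int64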
