Skip to content

Commit 7108a9f

Browse files
committed
Implementation
1 parent b7b2db9 commit 7108a9f

File tree

1 file changed

+133
-3
lines changed

1 file changed

+133
-3
lines changed

pandas/core/groupby/generic.py

+133-3
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
all_indexes_same,
7676
default_index,
7777
)
78+
from pandas.core.reshape.concat import concat
7879
from pandas.core.series import Series
7980
from pandas.core.sorting import get_group_index
8081
from pandas.core.util.numba_ import maybe_use_numba
@@ -1863,15 +1864,144 @@ def _transform_general(self, func, engine, engine_kwargs, *args, **kwargs):
18631864
3 5 9
18641865
4 5 8
18651866
5 5 9
1867+
1868+
Using list-like arguments
1869+
1870+
>>> df = pd.DataFrame({"col": list("aab"), "val": range(3)})
1871+
>>> df.groupby("col").transform(["sum", "min"])
1872+
val
1873+
sum min
1874+
0 1 0
1875+
1 1 0
1876+
2 2 2
1877+
1878+
.. versionchanged:: 3.0.0
1879+
1880+
Named aggregation
1881+
1882+
>>> df = pd.DataFrame({"A": list("aaabbbccc"), "B": range(9), "D": range(9, 18)})
1883+
>>> df.groupby("A").transform(
1884+
... b_min=pd.NamedAgg(column="B", aggfunc="min"),
1885+
... c_sum=pd.NamedAgg(column="D", aggfunc="sum")
1886+
... )
1887+
b_min c_sum
1888+
0 0 30
1889+
1 0 30
1890+
2 0 30
1891+
3 3 39
1892+
4 3 39
1893+
5 3 39
1894+
6 6 48
1895+
7 6 48
1896+
8 6 48
1897+
1898+
.. versionchanged:: 3.0.0
18661899
"""
18671900
)
18681901

18691902
@Substitution(klass="DataFrame", example=__examples_dataframe_doc)
18701903
@Appender(_transform_template)
1871-
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
1872-
return self._transform(
1873-
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
1904+
def transform(
1905+
self,
1906+
func: None
1907+
| (Callable | str | list[Callable | str] | dict[str, NamedAgg]) = None,
1908+
*args,
1909+
engine: str | None = None,
1910+
engine_kwargs: dict | None = None,
1911+
**kwargs,
1912+
) -> DataFrame:
1913+
if func is None:
1914+
# Convert named aggregations to dictionary format
1915+
transformed_func = self._named_agg_to_dict(*kwargs.items())
1916+
kwargs = {}
1917+
if isinstance(transformed_func, dict):
1918+
return self._transform_multiple_funcs(
1919+
transformed_func,
1920+
*args,
1921+
engine=engine,
1922+
engine_kwargs=engine_kwargs,
1923+
**kwargs,
1924+
)
1925+
else:
1926+
if isinstance(func, list):
1927+
func = maybe_mangle_lambdas(func)
1928+
return self._transform_multiple_funcs(
1929+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
1930+
)
1931+
else:
1932+
return self._transform(
1933+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
1934+
)
1935+
1936+
def _transform_multiple_funcs(
1937+
self,
1938+
func: Any,
1939+
*args,
1940+
engine: str | None = None,
1941+
engine_kwargs: dict | None = None,
1942+
**kwargs,
1943+
) -> DataFrame:
1944+
results = []
1945+
1946+
if isinstance(func, dict):
1947+
for name, named_agg in func.items():
1948+
column_name = named_agg.column
1949+
agg_func = named_agg.aggfunc
1950+
result = self._transform_single_column(
1951+
column_name,
1952+
agg_func,
1953+
*args,
1954+
engine=engine,
1955+
engine_kwargs=engine_kwargs,
1956+
**kwargs,
1957+
)
1958+
result.name = name
1959+
results.append(result)
1960+
output = concat(results, axis=1)
1961+
elif isinstance(func, list):
1962+
col_names = []
1963+
columns = (com.get_callable_name(f) or f for f in func)
1964+
func_pairs = zip(columns, func)
1965+
for idx, (name, func_item) in enumerate(func_pairs):
1966+
result = self._transform(
1967+
func_item,
1968+
*args,
1969+
engine=engine,
1970+
engine_kwargs=engine_kwargs,
1971+
**kwargs,
1972+
)
1973+
results.append(result)
1974+
col_names.extend([(col, name) for col in result.columns])
1975+
output = concat(results, axis=1)
1976+
output.columns = MultiIndex.from_tuples(col_names)
1977+
output.sort_index(axis=1, level=[0], sort_remaining=False, inplace=True)
1978+
1979+
return output
1980+
1981+
def _transform_single_column(
1982+
self,
1983+
column_name: Hashable,
1984+
agg_func: Callable | str,
1985+
*args,
1986+
engine: str | None = None,
1987+
engine_kwargs: dict | None = None,
1988+
**kwargs,
1989+
) -> Series:
1990+
data = self._obj_with_exclusions[column_name]
1991+
groupings = self._grouper.groupings
1992+
result = data.groupby(groupings).transform(
1993+
agg_func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
18741994
)
1995+
return result
1996+
1997+
@staticmethod
1998+
def _named_agg_to_dict(*named_aggs: tuple[str, NamedAgg]) -> dict[str, NamedAgg]:
1999+
valid_items = [
2000+
(name, aggfunc)
2001+
for name, aggfunc in named_aggs
2002+
if not isinstance(aggfunc[1], DataFrame)
2003+
]
2004+
return dict(valid_items)
18752005

18762006
def _define_paths(self, func, *args, **kwargs):
18772007
if isinstance(func, str):

0 commit comments

Comments
 (0)