Skip to content

ENH: Add **kwargs to pivot_table to allow the specification of aggfunc keyword arguments #57884 #58893

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Other enhancements
- :meth:`DataFrame.corrwith` now accepts ``min_periods`` as optional arguments, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`)
- :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`)
- :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`)
- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfunc`` through ``**kwargs`` (:issue:`57884`)
- :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
- :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
Expand Down
7 changes: 7 additions & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9275,6 +9275,11 @@ def pivot(

.. versionadded:: 1.3.0

**kwargs : dict
Optional keyword arguments to pass to ``aggfunc``.

.. versionadded:: 3.0.0

Returns
-------
DataFrame
Expand Down Expand Up @@ -9382,6 +9387,7 @@ def pivot_table(
margins_name: Level = "All",
observed: bool = True,
sort: bool = True,
**kwargs,
) -> DataFrame:
from pandas.core.reshape.pivot import pivot_table

Expand All @@ -9397,6 +9403,7 @@ def pivot_table(
margins_name=margins_name,
observed=observed,
sort=sort,
**kwargs,
)

def stack(
Expand Down
63 changes: 47 additions & 16 deletions pandas/core/reshape/pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def pivot_table(
margins_name: Hashable = "All",
observed: bool = True,
sort: bool = True,
**kwargs,
) -> DataFrame:
"""
Create a spreadsheet-style pivot table as a DataFrame.
Expand Down Expand Up @@ -124,6 +125,11 @@ def pivot_table(

.. versionadded:: 1.3.0

**kwargs : dict
Optional keyword arguments to pass to ``aggfunc``.

.. versionadded:: 3.0.0

Returns
-------
DataFrame
Expand Down Expand Up @@ -251,6 +257,7 @@ def pivot_table(
margins_name=margins_name,
observed=observed,
sort=sort,
kwargs=kwargs,
)
pieces.append(_table)
keys.append(getattr(func, "__name__", func))
Expand All @@ -270,6 +277,7 @@ def pivot_table(
margins_name,
observed,
sort,
kwargs,
)
return table.__finalize__(data, method="pivot_table")

Expand All @@ -286,6 +294,7 @@ def __internal_pivot_table(
margins_name: Hashable,
observed: bool,
sort: bool,
kwargs,
) -> DataFrame:
"""
Helper of :func:`pandas.pivot_table` for any non-list ``aggfunc``.
Expand Down Expand Up @@ -328,7 +337,7 @@ def __internal_pivot_table(
values = list(values)

grouped = data.groupby(keys, observed=observed, sort=sort, dropna=dropna)
agged = grouped.agg(aggfunc)
agged = grouped.agg(aggfunc, **kwargs)

if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
agged = agged.dropna(how="all")
Expand Down Expand Up @@ -383,6 +392,7 @@ def __internal_pivot_table(
rows=index,
cols=columns,
aggfunc=aggfunc,
kwargs=kwargs,
observed=dropna,
margins_name=margins_name,
fill_value=fill_value,
Expand All @@ -408,6 +418,7 @@ def _add_margins(
rows,
cols,
aggfunc,
kwargs,
observed: bool,
margins_name: Hashable = "All",
fill_value=None,
Expand All @@ -420,7 +431,7 @@ def _add_margins(
if margins_name in table.index.get_level_values(level):
raise ValueError(msg)

grand_margin = _compute_grand_margin(data, values, aggfunc, margins_name)
grand_margin = _compute_grand_margin(data, values, aggfunc, kwargs, margins_name)

if table.ndim == 2:
# i.e. DataFrame
Expand All @@ -441,7 +452,15 @@ def _add_margins(

elif values:
marginal_result_set = _generate_marginal_results(
table, data, values, rows, cols, aggfunc, observed, margins_name
table,
data,
values,
rows,
cols,
aggfunc,
kwargs,
observed,
margins_name,
)
if not isinstance(marginal_result_set, tuple):
return marginal_result_set
Expand All @@ -450,7 +469,7 @@ def _add_margins(
# no values, and table is a DataFrame
assert isinstance(table, ABCDataFrame)
marginal_result_set = _generate_marginal_results_without_values(
table, data, rows, cols, aggfunc, observed, margins_name
table, data, rows, cols, aggfunc, kwargs, observed, margins_name
)
if not isinstance(marginal_result_set, tuple):
return marginal_result_set
Expand Down Expand Up @@ -487,26 +506,26 @@ def _add_margins(


def _compute_grand_margin(
data: DataFrame, values, aggfunc, margins_name: Hashable = "All"
data: DataFrame, values, aggfunc, kwargs, margins_name: Hashable = "All"
):
if values:
grand_margin = {}
for k, v in data[values].items():
try:
if isinstance(aggfunc, str):
grand_margin[k] = getattr(v, aggfunc)()
grand_margin[k] = getattr(v, aggfunc)(**kwargs)
elif isinstance(aggfunc, dict):
if isinstance(aggfunc[k], str):
grand_margin[k] = getattr(v, aggfunc[k])()
grand_margin[k] = getattr(v, aggfunc[k])(**kwargs)
else:
grand_margin[k] = aggfunc[k](v)
grand_margin[k] = aggfunc[k](v, **kwargs)
else:
grand_margin[k] = aggfunc(v)
grand_margin[k] = aggfunc(v, **kwargs)
except TypeError:
pass
return grand_margin
else:
return {margins_name: aggfunc(data.index)}
return {margins_name: aggfunc(data.index, **kwargs)}


def _generate_marginal_results(
Expand All @@ -516,6 +535,7 @@ def _generate_marginal_results(
rows,
cols,
aggfunc,
kwargs,
observed: bool,
margins_name: Hashable = "All",
):
Expand All @@ -529,7 +549,11 @@ def _all_key(key):
return (key, margins_name) + ("",) * (len(cols) - 1)

if len(rows) > 0:
margin = data[rows + values].groupby(rows, observed=observed).agg(aggfunc)
margin = (
data[rows + values]
.groupby(rows, observed=observed)
.agg(aggfunc, **kwargs)
)
cat_axis = 1

for key, piece in table.T.groupby(level=0, observed=observed):
Expand All @@ -554,7 +578,7 @@ def _all_key(key):
table_pieces.append(piece)
# GH31016 this is to calculate margin for each group, and assign
# corresponded key as index
transformed_piece = DataFrame(piece.apply(aggfunc)).T
transformed_piece = DataFrame(piece.apply(aggfunc, **kwargs)).T
if isinstance(piece.index, MultiIndex):
# We are adding an empty level
transformed_piece.index = MultiIndex.from_tuples(
Expand Down Expand Up @@ -584,7 +608,9 @@ def _all_key(key):
margin_keys = table.columns

if len(cols) > 0:
row_margin = data[cols + values].groupby(cols, observed=observed).agg(aggfunc)
row_margin = (
data[cols + values].groupby(cols, observed=observed).agg(aggfunc, **kwargs)
)
row_margin = row_margin.stack()

# GH#26568. Use names instead of indices in case of numeric names
Expand All @@ -603,6 +629,7 @@ def _generate_marginal_results_without_values(
rows,
cols,
aggfunc,
kwargs,
observed: bool,
margins_name: Hashable = "All",
):
Expand All @@ -617,14 +644,16 @@ def _all_key():
return (margins_name,) + ("",) * (len(cols) - 1)

if len(rows) > 0:
margin = data.groupby(rows, observed=observed)[rows].apply(aggfunc)
margin = data.groupby(rows, observed=observed)[rows].apply(
aggfunc, **kwargs
)
all_key = _all_key()
table[all_key] = margin
result = table
margin_keys.append(all_key)

else:
margin = data.groupby(level=0, observed=observed).apply(aggfunc)
margin = data.groupby(level=0, observed=observed).apply(aggfunc, **kwargs)
all_key = _all_key()
table[all_key] = margin
result = table
Expand All @@ -635,7 +664,9 @@ def _all_key():
margin_keys = table.columns

if len(cols):
row_margin = data.groupby(cols, observed=observed)[cols].apply(aggfunc)
row_margin = data.groupby(cols, observed=observed)[cols].apply(
aggfunc, **kwargs
)
else:
row_margin = Series(np.nan, index=result.columns)

Expand Down
54 changes: 54 additions & 0 deletions pandas/tests/reshape/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2058,6 +2058,60 @@ def test_pivot_string_as_func(self):
).rename_axis("A")
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize("kwargs", [{"a": 2}, {"a": 2, "b": 3}, {"b": 3, "a": 2}])
def test_pivot_table_kwargs(self, kwargs):
# GH#57884
def f(x, a, b=3):
return x.sum() * a + b

def g(x):
return f(x, **kwargs)

df = DataFrame(
{
"A": ["good", "bad", "good", "bad", "good"],
"B": ["one", "two", "one", "three", "two"],
"X": [2, 5, 4, 20, 10],
}
)
result = pivot_table(
df, index="A", columns="B", values="X", aggfunc=f, **kwargs
)
expected = pivot_table(df, index="A", columns="B", values="X", aggfunc=g)
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize(
"kwargs", [{}, {"b": 10}, {"a": 3}, {"a": 3, "b": 10}, {"b": 10, "a": 3}]
)
def test_pivot_table_kwargs_margin(self, data, kwargs):
# GH#57884
def f(x, a=5, b=7):
return (x.sum() + b) * a

def g(x):
return f(x, **kwargs)

result = data.pivot_table(
values="D",
index=["A", "B"],
columns="C",
aggfunc=f,
margins=True,
fill_value=0,
**kwargs,
)

expected = data.pivot_table(
values="D",
index=["A", "B"],
columns="C",
aggfunc=g,
margins=True,
fill_value=0,
)

tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize(
"f, f_numpy",
[
Expand Down