diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 7f5c879c0d9f5..0bc9a08ee29e8 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -42,6 +42,7 @@ Other enhancements - :meth:`DataFrame.corrwith` now accepts ``min_periods`` as optional arguments, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`) - :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`) - :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`) +- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfunc`` through ``**kwargs`` (:issue:`57884`) - :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`) - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`) - Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index c37dfa225de5a..a6c0e1e372530 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -9275,6 +9275,11 @@ def pivot( .. versionadded:: 1.3.0 + **kwargs : dict + Optional keyword arguments to pass to ``aggfunc``. + + .. versionadded:: 3.0.0 + Returns ------- DataFrame @@ -9382,6 +9387,7 @@ def pivot_table( margins_name: Level = "All", observed: bool = True, sort: bool = True, + **kwargs, ) -> DataFrame: from pandas.core.reshape.pivot import pivot_table @@ -9397,6 +9403,7 @@ def pivot_table( margins_name=margins_name, observed=observed, sort=sort, + **kwargs, ) def stack( diff --git a/pandas/core/reshape/pivot.py b/pandas/core/reshape/pivot.py index 86da19f13bacf..f767a94682ead 100644 --- a/pandas/core/reshape/pivot.py +++ b/pandas/core/reshape/pivot.py @@ -66,6 +66,7 @@ def pivot_table( margins_name: Hashable = "All", observed: bool = True, sort: bool = True, + **kwargs, ) -> DataFrame: """ Create a spreadsheet-style pivot table as a DataFrame. @@ -124,6 +125,11 @@ def pivot_table( .. versionadded:: 1.3.0 + **kwargs : dict + Optional keyword arguments to pass to ``aggfunc``. + + .. versionadded:: 3.0.0 + Returns ------- DataFrame @@ -251,6 +257,7 @@ def pivot_table( margins_name=margins_name, observed=observed, sort=sort, + kwargs=kwargs, ) pieces.append(_table) keys.append(getattr(func, "__name__", func)) @@ -270,6 +277,7 @@ def pivot_table( margins_name, observed, sort, + kwargs, ) return table.__finalize__(data, method="pivot_table") @@ -286,6 +294,7 @@ def __internal_pivot_table( margins_name: Hashable, observed: bool, sort: bool, + kwargs, ) -> DataFrame: """ Helper of :func:`pandas.pivot_table` for any non-list ``aggfunc``. @@ -328,7 +337,7 @@ def __internal_pivot_table( values = list(values) grouped = data.groupby(keys, observed=observed, sort=sort, dropna=dropna) - agged = grouped.agg(aggfunc) + agged = grouped.agg(aggfunc, **kwargs) if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns): agged = agged.dropna(how="all") @@ -383,6 +392,7 @@ def __internal_pivot_table( rows=index, cols=columns, aggfunc=aggfunc, + kwargs=kwargs, observed=dropna, margins_name=margins_name, fill_value=fill_value, @@ -408,6 +418,7 @@ def _add_margins( rows, cols, aggfunc, + kwargs, observed: bool, margins_name: Hashable = "All", fill_value=None, @@ -420,7 +431,7 @@ def _add_margins( if margins_name in table.index.get_level_values(level): raise ValueError(msg) - grand_margin = _compute_grand_margin(data, values, aggfunc, margins_name) + grand_margin = _compute_grand_margin(data, values, aggfunc, kwargs, margins_name) if table.ndim == 2: # i.e. DataFrame @@ -441,7 +452,15 @@ def _add_margins( elif values: marginal_result_set = _generate_marginal_results( - table, data, values, rows, cols, aggfunc, observed, margins_name + table, + data, + values, + rows, + cols, + aggfunc, + kwargs, + observed, + margins_name, ) if not isinstance(marginal_result_set, tuple): return marginal_result_set @@ -450,7 +469,7 @@ def _add_margins( # no values, and table is a DataFrame assert isinstance(table, ABCDataFrame) marginal_result_set = _generate_marginal_results_without_values( - table, data, rows, cols, aggfunc, observed, margins_name + table, data, rows, cols, aggfunc, kwargs, observed, margins_name ) if not isinstance(marginal_result_set, tuple): return marginal_result_set @@ -487,26 +506,26 @@ def _add_margins( def _compute_grand_margin( - data: DataFrame, values, aggfunc, margins_name: Hashable = "All" + data: DataFrame, values, aggfunc, kwargs, margins_name: Hashable = "All" ): if values: grand_margin = {} for k, v in data[values].items(): try: if isinstance(aggfunc, str): - grand_margin[k] = getattr(v, aggfunc)() + grand_margin[k] = getattr(v, aggfunc)(**kwargs) elif isinstance(aggfunc, dict): if isinstance(aggfunc[k], str): - grand_margin[k] = getattr(v, aggfunc[k])() + grand_margin[k] = getattr(v, aggfunc[k])(**kwargs) else: - grand_margin[k] = aggfunc[k](v) + grand_margin[k] = aggfunc[k](v, **kwargs) else: - grand_margin[k] = aggfunc(v) + grand_margin[k] = aggfunc(v, **kwargs) except TypeError: pass return grand_margin else: - return {margins_name: aggfunc(data.index)} + return {margins_name: aggfunc(data.index, **kwargs)} def _generate_marginal_results( @@ -516,6 +535,7 @@ def _generate_marginal_results( rows, cols, aggfunc, + kwargs, observed: bool, margins_name: Hashable = "All", ): @@ -529,7 +549,11 @@ def _all_key(key): return (key, margins_name) + ("",) * (len(cols) - 1) if len(rows) > 0: - margin = data[rows + values].groupby(rows, observed=observed).agg(aggfunc) + margin = ( + data[rows + values] + .groupby(rows, observed=observed) + .agg(aggfunc, **kwargs) + ) cat_axis = 1 for key, piece in table.T.groupby(level=0, observed=observed): @@ -554,7 +578,7 @@ def _all_key(key): table_pieces.append(piece) # GH31016 this is to calculate margin for each group, and assign # corresponded key as index - transformed_piece = DataFrame(piece.apply(aggfunc)).T + transformed_piece = DataFrame(piece.apply(aggfunc, **kwargs)).T if isinstance(piece.index, MultiIndex): # We are adding an empty level transformed_piece.index = MultiIndex.from_tuples( @@ -584,7 +608,9 @@ def _all_key(key): margin_keys = table.columns if len(cols) > 0: - row_margin = data[cols + values].groupby(cols, observed=observed).agg(aggfunc) + row_margin = ( + data[cols + values].groupby(cols, observed=observed).agg(aggfunc, **kwargs) + ) row_margin = row_margin.stack() # GH#26568. Use names instead of indices in case of numeric names @@ -603,6 +629,7 @@ def _generate_marginal_results_without_values( rows, cols, aggfunc, + kwargs, observed: bool, margins_name: Hashable = "All", ): @@ -617,14 +644,16 @@ def _all_key(): return (margins_name,) + ("",) * (len(cols) - 1) if len(rows) > 0: - margin = data.groupby(rows, observed=observed)[rows].apply(aggfunc) + margin = data.groupby(rows, observed=observed)[rows].apply( + aggfunc, **kwargs + ) all_key = _all_key() table[all_key] = margin result = table margin_keys.append(all_key) else: - margin = data.groupby(level=0, observed=observed).apply(aggfunc) + margin = data.groupby(level=0, observed=observed).apply(aggfunc, **kwargs) all_key = _all_key() table[all_key] = margin result = table @@ -635,7 +664,9 @@ def _all_key(): margin_keys = table.columns if len(cols): - row_margin = data.groupby(cols, observed=observed)[cols].apply(aggfunc) + row_margin = data.groupby(cols, observed=observed)[cols].apply( + aggfunc, **kwargs + ) else: row_margin = Series(np.nan, index=result.columns) diff --git a/pandas/tests/reshape/test_pivot.py b/pandas/tests/reshape/test_pivot.py index 4a13c1f5e1167..728becc76b71f 100644 --- a/pandas/tests/reshape/test_pivot.py +++ b/pandas/tests/reshape/test_pivot.py @@ -2058,6 +2058,60 @@ def test_pivot_string_as_func(self): ).rename_axis("A") tm.assert_frame_equal(result, expected) + @pytest.mark.parametrize("kwargs", [{"a": 2}, {"a": 2, "b": 3}, {"b": 3, "a": 2}]) + def test_pivot_table_kwargs(self, kwargs): + # GH#57884 + def f(x, a, b=3): + return x.sum() * a + b + + def g(x): + return f(x, **kwargs) + + df = DataFrame( + { + "A": ["good", "bad", "good", "bad", "good"], + "B": ["one", "two", "one", "three", "two"], + "X": [2, 5, 4, 20, 10], + } + ) + result = pivot_table( + df, index="A", columns="B", values="X", aggfunc=f, **kwargs + ) + expected = pivot_table(df, index="A", columns="B", values="X", aggfunc=g) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "kwargs", [{}, {"b": 10}, {"a": 3}, {"a": 3, "b": 10}, {"b": 10, "a": 3}] + ) + def test_pivot_table_kwargs_margin(self, data, kwargs): + # GH#57884 + def f(x, a=5, b=7): + return (x.sum() + b) * a + + def g(x): + return f(x, **kwargs) + + result = data.pivot_table( + values="D", + index=["A", "B"], + columns="C", + aggfunc=f, + margins=True, + fill_value=0, + **kwargs, + ) + + expected = data.pivot_table( + values="D", + index=["A", "B"], + columns="C", + aggfunc=g, + margins=True, + fill_value=0, + ) + + tm.assert_frame_equal(result, expected) + @pytest.mark.parametrize( "f, f_numpy", [