From 884fc10bd8b2eba3b8dc7a6e1e6243bc25488b25 Mon Sep 17 00:00:00 2001 From: Rui Amaral Date: Mon, 27 May 2024 01:55:24 +0100 Subject: [PATCH 1/4] feat: Add **kwargs to pivot_table #57884 Add the option of passing keyword arguments to DataFrame.pivot_table and pivot_table's aggfunc through **kwargs. Co-authored-by: Pedro Freitas --- doc/source/whatsnew/v3.0.0.rst | 1 + pandas/core/frame.py | 5 ++ pandas/core/reshape/pivot.py | 58 ++++++++++++++++------- pandas/tests/reshape/test_pivot.py | 75 ++++++++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 16 deletions(-) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 865996bdf8892..d465243e771e9 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -42,6 +42,7 @@ Other enhancements - :meth:`DataFrame.corrwith` now accepts ``min_periods`` as optional arguments, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`) - :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`) - :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`) +- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfun`` through ``**kwargs`` (:issue:`57884`) - :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`) - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`) - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 97a4e414608b8..862fb5c8f413c 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -9292,6 +9292,9 @@ def pivot( .. versionadded:: 1.3.0 + **kwargs : dict + Optional keyword arguments to pass to ``aggfunc``. + Returns ------- DataFrame @@ -9399,6 +9402,7 @@ def pivot_table( margins_name: Level = "All", observed: bool = True, sort: bool = True, + **kwargs, ) -> DataFrame: from pandas.core.reshape.pivot import pivot_table @@ -9414,6 +9418,7 @@ def pivot_table( margins_name=margins_name, observed=observed, sort=sort, + **kwargs, ) def stack( diff --git a/pandas/core/reshape/pivot.py b/pandas/core/reshape/pivot.py index e0126d439a79c..ca01769add450 100644 --- a/pandas/core/reshape/pivot.py +++ b/pandas/core/reshape/pivot.py @@ -70,6 +70,7 @@ def pivot_table( margins_name: Hashable = "All", observed: bool = True, sort: bool = True, + **kwargs, ) -> DataFrame: index = _convert_by(index) columns = _convert_by(columns) @@ -90,6 +91,7 @@ def pivot_table( margins_name=margins_name, observed=observed, sort=sort, + kwargs=kwargs, ) pieces.append(_table) keys.append(getattr(func, "__name__", func)) @@ -109,6 +111,7 @@ def pivot_table( margins_name, observed, sort, + kwargs, ) return table.__finalize__(data, method="pivot_table") @@ -125,6 +128,7 @@ def __internal_pivot_table( margins_name: Hashable, observed: bool, sort: bool, + kwargs, ) -> DataFrame: """ Helper of :func:`pandas.pivot_table` for any non-list ``aggfunc``. @@ -167,7 +171,7 @@ def __internal_pivot_table( values = list(values) grouped = data.groupby(keys, observed=observed, sort=sort, dropna=dropna) - agged = grouped.agg(aggfunc) + agged = grouped.agg(aggfunc, **kwargs) if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns): agged = agged.dropna(how="all") @@ -222,6 +226,7 @@ def __internal_pivot_table( rows=index, cols=columns, aggfunc=aggfunc, + kwargs=kwargs, observed=dropna, margins_name=margins_name, fill_value=fill_value, @@ -247,6 +252,7 @@ def _add_margins( rows, cols, aggfunc, + kwargs, observed: bool, margins_name: Hashable = "All", fill_value=None, @@ -259,7 +265,7 @@ def _add_margins( if margins_name in table.index.get_level_values(level): raise ValueError(msg) - grand_margin = _compute_grand_margin(data, values, aggfunc, margins_name) + grand_margin = _compute_grand_margin(data, values, aggfunc, kwargs, margins_name) if table.ndim == 2: # i.e. DataFrame @@ -280,7 +286,15 @@ def _add_margins( elif values: marginal_result_set = _generate_marginal_results( - table, data, values, rows, cols, aggfunc, observed, margins_name + table, + data, + values, + rows, + cols, + aggfunc, + kwargs, + observed, + margins_name, ) if not isinstance(marginal_result_set, tuple): return marginal_result_set @@ -289,7 +303,7 @@ def _add_margins( # no values, and table is a DataFrame assert isinstance(table, ABCDataFrame) marginal_result_set = _generate_marginal_results_without_values( - table, data, rows, cols, aggfunc, observed, margins_name + table, data, rows, cols, aggfunc, kwargs, observed, margins_name ) if not isinstance(marginal_result_set, tuple): return marginal_result_set @@ -326,26 +340,26 @@ def _add_margins( def _compute_grand_margin( - data: DataFrame, values, aggfunc, margins_name: Hashable = "All" + data: DataFrame, values, aggfunc, kwargs, margins_name: Hashable = "All" ): if values: grand_margin = {} for k, v in data[values].items(): try: if isinstance(aggfunc, str): - grand_margin[k] = getattr(v, aggfunc)() + grand_margin[k] = getattr(v, aggfunc)(**kwargs) elif isinstance(aggfunc, dict): if isinstance(aggfunc[k], str): - grand_margin[k] = getattr(v, aggfunc[k])() + grand_margin[k] = getattr(v, aggfunc[k])(**kwargs) else: - grand_margin[k] = aggfunc[k](v) + grand_margin[k] = aggfunc[k](v, **kwargs) else: - grand_margin[k] = aggfunc(v) + grand_margin[k] = aggfunc(v, **kwargs) except TypeError: pass return grand_margin else: - return {margins_name: aggfunc(data.index)} + return {margins_name: aggfunc(data.index, **kwargs)} def _generate_marginal_results( @@ -355,6 +369,7 @@ def _generate_marginal_results( rows, cols, aggfunc, + kwargs, observed: bool, margins_name: Hashable = "All", ): @@ -368,7 +383,11 @@ def _all_key(key): return (key, margins_name) + ("",) * (len(cols) - 1) if len(rows) > 0: - margin = data[rows + values].groupby(rows, observed=observed).agg(aggfunc) + margin = ( + data[rows + values] + .groupby(rows, observed=observed) + .agg(aggfunc, **kwargs) + ) cat_axis = 1 for key, piece in table.T.groupby(level=0, observed=observed): @@ -393,7 +412,7 @@ def _all_key(key): table_pieces.append(piece) # GH31016 this is to calculate margin for each group, and assign # corresponded key as index - transformed_piece = DataFrame(piece.apply(aggfunc)).T + transformed_piece = DataFrame(piece.apply(aggfunc, **kwargs)).T if isinstance(piece.index, MultiIndex): # We are adding an empty level transformed_piece.index = MultiIndex.from_tuples( @@ -423,7 +442,9 @@ def _all_key(key): margin_keys = table.columns if len(cols) > 0: - row_margin = data[cols + values].groupby(cols, observed=observed).agg(aggfunc) + row_margin = ( + data[cols + values].groupby(cols, observed=observed).agg(aggfunc, **kwargs) + ) row_margin = row_margin.stack() # GH#26568. Use names instead of indices in case of numeric names @@ -442,6 +463,7 @@ def _generate_marginal_results_without_values( rows, cols, aggfunc, + kwargs, observed: bool, margins_name: Hashable = "All", ): @@ -456,14 +478,16 @@ def _all_key(): return (margins_name,) + ("",) * (len(cols) - 1) if len(rows) > 0: - margin = data.groupby(rows, observed=observed)[rows].apply(aggfunc) + margin = data.groupby(rows, observed=observed)[rows].apply( + aggfunc, **kwargs + ) all_key = _all_key() table[all_key] = margin result = table margin_keys.append(all_key) else: - margin = data.groupby(level=0, observed=observed).apply(aggfunc) + margin = data.groupby(level=0, observed=observed).apply(aggfunc, **kwargs) all_key = _all_key() table[all_key] = margin result = table @@ -474,7 +498,9 @@ def _all_key(): margin_keys = table.columns if len(cols): - row_margin = data.groupby(cols, observed=observed)[cols].apply(aggfunc) + row_margin = data.groupby(cols, observed=observed)[cols].apply( + aggfunc, **kwargs + ) else: row_margin = Series(np.nan, index=result.columns) diff --git a/pandas/tests/reshape/test_pivot.py b/pandas/tests/reshape/test_pivot.py index 4a13c1f5e1167..4dba6de0f998d 100644 --- a/pandas/tests/reshape/test_pivot.py +++ b/pandas/tests/reshape/test_pivot.py @@ -2058,6 +2058,81 @@ def test_pivot_string_as_func(self): ).rename_axis("A") tm.assert_frame_equal(result, expected) + @pytest.mark.parametrize("kwargs", [{"a": 2}, {"a": 2, "b": 3}, {"b": 3, "a": 2}]) + def test_pivot_table_kwargs(self, kwargs): + def f(x, a, b=3): + return x.sum() * a + b + + def g(x): + return f(x, **kwargs) + + df = DataFrame( + { + "A": ["good", "bad", "good", "bad", "good"], + "B": ["one", "two", "one", "three", "two"], + "X": [2, 5, 4, 20, 10], + } + ) + result = pivot_table( + df, index="A", columns="B", values="X", aggfunc=f, **kwargs + ) + expected = pivot_table(df, index="A", columns="B", values="X", aggfunc=g) + tm.assert_frame_equal(result, expected) + + expected = DataFrame( + [[np.nan, 43.0, 13.0], [15.0, np.nan, 23.0]], + columns=Index(["one", "three", "two"], name="B"), + index=Index(["bad", "good"], name="A"), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "kwargs", [{}, {"b": 10}, {"a": 3}, {"a": 3, "b": 10}, {"b": 10, "a": 3}] + ) + def test_pivot_table_kwargs_margin(self, data, kwargs): + def f(x, a=5, b=7): + return (x.sum() + b) * a + + def g(x): + return f(x, **kwargs) + + result = data.pivot_table( + values="D", + index=["A", "B"], + columns="C", + aggfunc=f, + margins=True, + fill_value=0, + **kwargs, + ) + + expected = data.pivot_table( + values="D", + index=["A", "B"], + columns="C", + aggfunc=g, + margins=True, + fill_value=0, + ) + + tm.assert_frame_equal(result, expected) + + grand_margin = g(data["D"]) + + margin_col = pivot_table( + data, values="D", index=["A", "B"], aggfunc=g, fill_value=0 + ) + margin_col.loc[("All", ""), "D"] = grand_margin + margin_col = margin_col["D"] + margin_col.name = "All" + tm.assert_series_equal(result["All"], margin_col) + + margin_row = pivot_table(data, values="D", columns="C", aggfunc=g, fill_value=0) + margin_row["All"] = grand_margin + margin_row.index = [""] + margin_row.index.name = "B" + tm.assert_frame_equal(result.loc["All"], margin_row) + @pytest.mark.parametrize( "f, f_numpy", [ From 7708b14f667993fba64a957238fcdccebd81d270 Mon Sep 17 00:00:00 2001 From: Rui Amaral Date: Sun, 2 Jun 2024 17:32:10 +0100 Subject: [PATCH 2/4] fix: typo --- doc/source/whatsnew/v3.0.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index d465243e771e9..4faedcf23aa3d 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -42,7 +42,7 @@ Other enhancements - :meth:`DataFrame.corrwith` now accepts ``min_periods`` as optional arguments, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`) - :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`) - :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`) -- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfun`` through ``**kwargs`` (:issue:`57884`) +- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfunc`` through ``**kwargs`` (:issue:`57884`) - :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`) - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`) - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`) From f5f62c010858b3d62bbfc360466dd79da4a38e78 Mon Sep 17 00:00:00 2001 From: Rui Amaral Date: Wed, 5 Jun 2024 10:47:02 +0100 Subject: [PATCH 3/4] fix: address review feedback --- pandas/core/frame.py | 2 ++ pandas/tests/reshape/test_pivot.py | 25 ++----------------------- 2 files changed, 4 insertions(+), 23 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 862fb5c8f413c..21ebf6d4b64ee 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -9295,6 +9295,8 @@ def pivot( **kwargs : dict Optional keyword arguments to pass to ``aggfunc``. + .. versionadded:: 3.0.0 + Returns ------- DataFrame diff --git a/pandas/tests/reshape/test_pivot.py b/pandas/tests/reshape/test_pivot.py index 4dba6de0f998d..728becc76b71f 100644 --- a/pandas/tests/reshape/test_pivot.py +++ b/pandas/tests/reshape/test_pivot.py @@ -2060,6 +2060,7 @@ def test_pivot_string_as_func(self): @pytest.mark.parametrize("kwargs", [{"a": 2}, {"a": 2, "b": 3}, {"b": 3, "a": 2}]) def test_pivot_table_kwargs(self, kwargs): + # GH#57884 def f(x, a, b=3): return x.sum() * a + b @@ -2079,17 +2080,11 @@ def g(x): expected = pivot_table(df, index="A", columns="B", values="X", aggfunc=g) tm.assert_frame_equal(result, expected) - expected = DataFrame( - [[np.nan, 43.0, 13.0], [15.0, np.nan, 23.0]], - columns=Index(["one", "three", "two"], name="B"), - index=Index(["bad", "good"], name="A"), - ) - tm.assert_frame_equal(result, expected) - @pytest.mark.parametrize( "kwargs", [{}, {"b": 10}, {"a": 3}, {"a": 3, "b": 10}, {"b": 10, "a": 3}] ) def test_pivot_table_kwargs_margin(self, data, kwargs): + # GH#57884 def f(x, a=5, b=7): return (x.sum() + b) * a @@ -2117,22 +2112,6 @@ def g(x): tm.assert_frame_equal(result, expected) - grand_margin = g(data["D"]) - - margin_col = pivot_table( - data, values="D", index=["A", "B"], aggfunc=g, fill_value=0 - ) - margin_col.loc[("All", ""), "D"] = grand_margin - margin_col = margin_col["D"] - margin_col.name = "All" - tm.assert_series_equal(result["All"], margin_col) - - margin_row = pivot_table(data, values="D", columns="C", aggfunc=g, fill_value=0) - margin_row["All"] = grand_margin - margin_row.index = [""] - margin_row.index.name = "B" - tm.assert_frame_equal(result.loc["All"], margin_row) - @pytest.mark.parametrize( "f, f_numpy", [ From a4d26115b58b13ea1dd881ac18675300dd402b04 Mon Sep 17 00:00:00 2001 From: Rui Amaral Date: Wed, 5 Jun 2024 23:13:08 +0100 Subject: [PATCH 4/4] docs: add documentation to comply with #58896 --- pandas/core/reshape/pivot.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pandas/core/reshape/pivot.py b/pandas/core/reshape/pivot.py index 2e160e7c140ab..f767a94682ead 100644 --- a/pandas/core/reshape/pivot.py +++ b/pandas/core/reshape/pivot.py @@ -125,6 +125,11 @@ def pivot_table( .. versionadded:: 1.3.0 + **kwargs : dict + Optional keyword arguments to pass to ``aggfunc``. + + .. versionadded:: 3.0.0 + Returns ------- DataFrame