Skip to content

Commit 884fc10

Browse files
ruimamaralPF2100
andcommitted
feat: Add **kwargs to pivot_table pandas-dev#57884
Add the option of passing keyword arguments to DataFrame.pivot_table and pivot_table's aggfunc through **kwargs. Co-authored-by: Pedro Freitas <[email protected]>
1 parent b0c4194 commit 884fc10

File tree

4 files changed

+123
-16
lines changed

4 files changed

+123
-16
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Other enhancements
4242
- :meth:`DataFrame.corrwith` now accepts ``min_periods`` as optional arguments, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`)
4343
- :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`)
4444
- :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`)
45+
- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfun`` through ``**kwargs`` (:issue:`57884`)
4546
- :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
4647
- :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
4748
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)

pandas/core/frame.py

+5
Original file line numberDiff line numberDiff line change
@@ -9292,6 +9292,9 @@ def pivot(
92929292
92939293
.. versionadded:: 1.3.0
92949294
9295+
**kwargs : dict
9296+
Optional keyword arguments to pass to ``aggfunc``.
9297+
92959298
Returns
92969299
-------
92979300
DataFrame
@@ -9399,6 +9402,7 @@ def pivot_table(
93999402
margins_name: Level = "All",
94009403
observed: bool = True,
94019404
sort: bool = True,
9405+
**kwargs,
94029406
) -> DataFrame:
94039407
from pandas.core.reshape.pivot import pivot_table
94049408

@@ -9414,6 +9418,7 @@ def pivot_table(
94149418
margins_name=margins_name,
94159419
observed=observed,
94169420
sort=sort,
9421+
**kwargs,
94179422
)
94189423

94199424
def stack(

pandas/core/reshape/pivot.py

+42-16
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def pivot_table(
7070
margins_name: Hashable = "All",
7171
observed: bool = True,
7272
sort: bool = True,
73+
**kwargs,
7374
) -> DataFrame:
7475
index = _convert_by(index)
7576
columns = _convert_by(columns)
@@ -90,6 +91,7 @@ def pivot_table(
9091
margins_name=margins_name,
9192
observed=observed,
9293
sort=sort,
94+
kwargs=kwargs,
9395
)
9496
pieces.append(_table)
9597
keys.append(getattr(func, "__name__", func))
@@ -109,6 +111,7 @@ def pivot_table(
109111
margins_name,
110112
observed,
111113
sort,
114+
kwargs,
112115
)
113116
return table.__finalize__(data, method="pivot_table")
114117

@@ -125,6 +128,7 @@ def __internal_pivot_table(
125128
margins_name: Hashable,
126129
observed: bool,
127130
sort: bool,
131+
kwargs,
128132
) -> DataFrame:
129133
"""
130134
Helper of :func:`pandas.pivot_table` for any non-list ``aggfunc``.
@@ -167,7 +171,7 @@ def __internal_pivot_table(
167171
values = list(values)
168172

169173
grouped = data.groupby(keys, observed=observed, sort=sort, dropna=dropna)
170-
agged = grouped.agg(aggfunc)
174+
agged = grouped.agg(aggfunc, **kwargs)
171175

172176
if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
173177
agged = agged.dropna(how="all")
@@ -222,6 +226,7 @@ def __internal_pivot_table(
222226
rows=index,
223227
cols=columns,
224228
aggfunc=aggfunc,
229+
kwargs=kwargs,
225230
observed=dropna,
226231
margins_name=margins_name,
227232
fill_value=fill_value,
@@ -247,6 +252,7 @@ def _add_margins(
247252
rows,
248253
cols,
249254
aggfunc,
255+
kwargs,
250256
observed: bool,
251257
margins_name: Hashable = "All",
252258
fill_value=None,
@@ -259,7 +265,7 @@ def _add_margins(
259265
if margins_name in table.index.get_level_values(level):
260266
raise ValueError(msg)
261267

262-
grand_margin = _compute_grand_margin(data, values, aggfunc, margins_name)
268+
grand_margin = _compute_grand_margin(data, values, aggfunc, kwargs, margins_name)
263269

264270
if table.ndim == 2:
265271
# i.e. DataFrame
@@ -280,7 +286,15 @@ def _add_margins(
280286

281287
elif values:
282288
marginal_result_set = _generate_marginal_results(
283-
table, data, values, rows, cols, aggfunc, observed, margins_name
289+
table,
290+
data,
291+
values,
292+
rows,
293+
cols,
294+
aggfunc,
295+
kwargs,
296+
observed,
297+
margins_name,
284298
)
285299
if not isinstance(marginal_result_set, tuple):
286300
return marginal_result_set
@@ -289,7 +303,7 @@ def _add_margins(
289303
# no values, and table is a DataFrame
290304
assert isinstance(table, ABCDataFrame)
291305
marginal_result_set = _generate_marginal_results_without_values(
292-
table, data, rows, cols, aggfunc, observed, margins_name
306+
table, data, rows, cols, aggfunc, kwargs, observed, margins_name
293307
)
294308
if not isinstance(marginal_result_set, tuple):
295309
return marginal_result_set
@@ -326,26 +340,26 @@ def _add_margins(
326340

327341

328342
def _compute_grand_margin(
329-
data: DataFrame, values, aggfunc, margins_name: Hashable = "All"
343+
data: DataFrame, values, aggfunc, kwargs, margins_name: Hashable = "All"
330344
):
331345
if values:
332346
grand_margin = {}
333347
for k, v in data[values].items():
334348
try:
335349
if isinstance(aggfunc, str):
336-
grand_margin[k] = getattr(v, aggfunc)()
350+
grand_margin[k] = getattr(v, aggfunc)(**kwargs)
337351
elif isinstance(aggfunc, dict):
338352
if isinstance(aggfunc[k], str):
339-
grand_margin[k] = getattr(v, aggfunc[k])()
353+
grand_margin[k] = getattr(v, aggfunc[k])(**kwargs)
340354
else:
341-
grand_margin[k] = aggfunc[k](v)
355+
grand_margin[k] = aggfunc[k](v, **kwargs)
342356
else:
343-
grand_margin[k] = aggfunc(v)
357+
grand_margin[k] = aggfunc(v, **kwargs)
344358
except TypeError:
345359
pass
346360
return grand_margin
347361
else:
348-
return {margins_name: aggfunc(data.index)}
362+
return {margins_name: aggfunc(data.index, **kwargs)}
349363

350364

351365
def _generate_marginal_results(
@@ -355,6 +369,7 @@ def _generate_marginal_results(
355369
rows,
356370
cols,
357371
aggfunc,
372+
kwargs,
358373
observed: bool,
359374
margins_name: Hashable = "All",
360375
):
@@ -368,7 +383,11 @@ def _all_key(key):
368383
return (key, margins_name) + ("",) * (len(cols) - 1)
369384

370385
if len(rows) > 0:
371-
margin = data[rows + values].groupby(rows, observed=observed).agg(aggfunc)
386+
margin = (
387+
data[rows + values]
388+
.groupby(rows, observed=observed)
389+
.agg(aggfunc, **kwargs)
390+
)
372391
cat_axis = 1
373392

374393
for key, piece in table.T.groupby(level=0, observed=observed):
@@ -393,7 +412,7 @@ def _all_key(key):
393412
table_pieces.append(piece)
394413
# GH31016 this is to calculate margin for each group, and assign
395414
# corresponded key as index
396-
transformed_piece = DataFrame(piece.apply(aggfunc)).T
415+
transformed_piece = DataFrame(piece.apply(aggfunc, **kwargs)).T
397416
if isinstance(piece.index, MultiIndex):
398417
# We are adding an empty level
399418
transformed_piece.index = MultiIndex.from_tuples(
@@ -423,7 +442,9 @@ def _all_key(key):
423442
margin_keys = table.columns
424443

425444
if len(cols) > 0:
426-
row_margin = data[cols + values].groupby(cols, observed=observed).agg(aggfunc)
445+
row_margin = (
446+
data[cols + values].groupby(cols, observed=observed).agg(aggfunc, **kwargs)
447+
)
427448
row_margin = row_margin.stack()
428449

429450
# GH#26568. Use names instead of indices in case of numeric names
@@ -442,6 +463,7 @@ def _generate_marginal_results_without_values(
442463
rows,
443464
cols,
444465
aggfunc,
466+
kwargs,
445467
observed: bool,
446468
margins_name: Hashable = "All",
447469
):
@@ -456,14 +478,16 @@ def _all_key():
456478
return (margins_name,) + ("",) * (len(cols) - 1)
457479

458480
if len(rows) > 0:
459-
margin = data.groupby(rows, observed=observed)[rows].apply(aggfunc)
481+
margin = data.groupby(rows, observed=observed)[rows].apply(
482+
aggfunc, **kwargs
483+
)
460484
all_key = _all_key()
461485
table[all_key] = margin
462486
result = table
463487
margin_keys.append(all_key)
464488

465489
else:
466-
margin = data.groupby(level=0, observed=observed).apply(aggfunc)
490+
margin = data.groupby(level=0, observed=observed).apply(aggfunc, **kwargs)
467491
all_key = _all_key()
468492
table[all_key] = margin
469493
result = table
@@ -474,7 +498,9 @@ def _all_key():
474498
margin_keys = table.columns
475499

476500
if len(cols):
477-
row_margin = data.groupby(cols, observed=observed)[cols].apply(aggfunc)
501+
row_margin = data.groupby(cols, observed=observed)[cols].apply(
502+
aggfunc, **kwargs
503+
)
478504
else:
479505
row_margin = Series(np.nan, index=result.columns)
480506

pandas/tests/reshape/test_pivot.py

+75
Original file line numberDiff line numberDiff line change
@@ -2058,6 +2058,81 @@ def test_pivot_string_as_func(self):
20582058
).rename_axis("A")
20592059
tm.assert_frame_equal(result, expected)
20602060

2061+
@pytest.mark.parametrize("kwargs", [{"a": 2}, {"a": 2, "b": 3}, {"b": 3, "a": 2}])
2062+
def test_pivot_table_kwargs(self, kwargs):
2063+
def f(x, a, b=3):
2064+
return x.sum() * a + b
2065+
2066+
def g(x):
2067+
return f(x, **kwargs)
2068+
2069+
df = DataFrame(
2070+
{
2071+
"A": ["good", "bad", "good", "bad", "good"],
2072+
"B": ["one", "two", "one", "three", "two"],
2073+
"X": [2, 5, 4, 20, 10],
2074+
}
2075+
)
2076+
result = pivot_table(
2077+
df, index="A", columns="B", values="X", aggfunc=f, **kwargs
2078+
)
2079+
expected = pivot_table(df, index="A", columns="B", values="X", aggfunc=g)
2080+
tm.assert_frame_equal(result, expected)
2081+
2082+
expected = DataFrame(
2083+
[[np.nan, 43.0, 13.0], [15.0, np.nan, 23.0]],
2084+
columns=Index(["one", "three", "two"], name="B"),
2085+
index=Index(["bad", "good"], name="A"),
2086+
)
2087+
tm.assert_frame_equal(result, expected)
2088+
2089+
@pytest.mark.parametrize(
2090+
"kwargs", [{}, {"b": 10}, {"a": 3}, {"a": 3, "b": 10}, {"b": 10, "a": 3}]
2091+
)
2092+
def test_pivot_table_kwargs_margin(self, data, kwargs):
2093+
def f(x, a=5, b=7):
2094+
return (x.sum() + b) * a
2095+
2096+
def g(x):
2097+
return f(x, **kwargs)
2098+
2099+
result = data.pivot_table(
2100+
values="D",
2101+
index=["A", "B"],
2102+
columns="C",
2103+
aggfunc=f,
2104+
margins=True,
2105+
fill_value=0,
2106+
**kwargs,
2107+
)
2108+
2109+
expected = data.pivot_table(
2110+
values="D",
2111+
index=["A", "B"],
2112+
columns="C",
2113+
aggfunc=g,
2114+
margins=True,
2115+
fill_value=0,
2116+
)
2117+
2118+
tm.assert_frame_equal(result, expected)
2119+
2120+
grand_margin = g(data["D"])
2121+
2122+
margin_col = pivot_table(
2123+
data, values="D", index=["A", "B"], aggfunc=g, fill_value=0
2124+
)
2125+
margin_col.loc[("All", ""), "D"] = grand_margin
2126+
margin_col = margin_col["D"]
2127+
margin_col.name = "All"
2128+
tm.assert_series_equal(result["All"], margin_col)
2129+
2130+
margin_row = pivot_table(data, values="D", columns="C", aggfunc=g, fill_value=0)
2131+
margin_row["All"] = grand_margin
2132+
margin_row.index = [""]
2133+
margin_row.index.name = "B"
2134+
tm.assert_frame_equal(result.loc["All"], margin_row)
2135+
20612136
@pytest.mark.parametrize(
20622137
"f, f_numpy",
20632138
[

0 commit comments

Comments
 (0)