@@ -70,6 +70,7 @@ def pivot_table(
70
70
margins_name : Hashable = "All" ,
71
71
observed : bool = True ,
72
72
sort : bool = True ,
73
+ ** kwargs ,
73
74
) -> DataFrame :
74
75
index = _convert_by (index )
75
76
columns = _convert_by (columns )
@@ -90,6 +91,7 @@ def pivot_table(
90
91
margins_name = margins_name ,
91
92
observed = observed ,
92
93
sort = sort ,
94
+ kwargs = kwargs ,
93
95
)
94
96
pieces .append (_table )
95
97
keys .append (getattr (func , "__name__" , func ))
@@ -109,6 +111,7 @@ def pivot_table(
109
111
margins_name ,
110
112
observed ,
111
113
sort ,
114
+ kwargs ,
112
115
)
113
116
return table .__finalize__ (data , method = "pivot_table" )
114
117
@@ -125,6 +128,7 @@ def __internal_pivot_table(
125
128
margins_name : Hashable ,
126
129
observed : bool ,
127
130
sort : bool ,
131
+ kwargs ,
128
132
) -> DataFrame :
129
133
"""
130
134
Helper of :func:`pandas.pivot_table` for any non-list ``aggfunc``.
@@ -167,7 +171,7 @@ def __internal_pivot_table(
167
171
values = list (values )
168
172
169
173
grouped = data .groupby (keys , observed = observed , sort = sort , dropna = dropna )
170
- agged = grouped .agg (aggfunc )
174
+ agged = grouped .agg (aggfunc , ** kwargs )
171
175
172
176
if dropna and isinstance (agged , ABCDataFrame ) and len (agged .columns ):
173
177
agged = agged .dropna (how = "all" )
@@ -222,6 +226,7 @@ def __internal_pivot_table(
222
226
rows = index ,
223
227
cols = columns ,
224
228
aggfunc = aggfunc ,
229
+ kwargs = kwargs ,
225
230
observed = dropna ,
226
231
margins_name = margins_name ,
227
232
fill_value = fill_value ,
@@ -247,6 +252,7 @@ def _add_margins(
247
252
rows ,
248
253
cols ,
249
254
aggfunc ,
255
+ kwargs ,
250
256
observed : bool ,
251
257
margins_name : Hashable = "All" ,
252
258
fill_value = None ,
@@ -259,7 +265,7 @@ def _add_margins(
259
265
if margins_name in table .index .get_level_values (level ):
260
266
raise ValueError (msg )
261
267
262
- grand_margin = _compute_grand_margin (data , values , aggfunc , margins_name )
268
+ grand_margin = _compute_grand_margin (data , values , aggfunc , kwargs , margins_name )
263
269
264
270
if table .ndim == 2 :
265
271
# i.e. DataFrame
@@ -280,7 +286,15 @@ def _add_margins(
280
286
281
287
elif values :
282
288
marginal_result_set = _generate_marginal_results (
283
- table , data , values , rows , cols , aggfunc , observed , margins_name
289
+ table ,
290
+ data ,
291
+ values ,
292
+ rows ,
293
+ cols ,
294
+ aggfunc ,
295
+ kwargs ,
296
+ observed ,
297
+ margins_name ,
284
298
)
285
299
if not isinstance (marginal_result_set , tuple ):
286
300
return marginal_result_set
@@ -289,7 +303,7 @@ def _add_margins(
289
303
# no values, and table is a DataFrame
290
304
assert isinstance (table , ABCDataFrame )
291
305
marginal_result_set = _generate_marginal_results_without_values (
292
- table , data , rows , cols , aggfunc , observed , margins_name
306
+ table , data , rows , cols , aggfunc , kwargs , observed , margins_name
293
307
)
294
308
if not isinstance (marginal_result_set , tuple ):
295
309
return marginal_result_set
@@ -326,26 +340,26 @@ def _add_margins(
326
340
327
341
328
342
def _compute_grand_margin (
329
- data : DataFrame , values , aggfunc , margins_name : Hashable = "All"
343
+ data : DataFrame , values , aggfunc , kwargs , margins_name : Hashable = "All"
330
344
):
331
345
if values :
332
346
grand_margin = {}
333
347
for k , v in data [values ].items ():
334
348
try :
335
349
if isinstance (aggfunc , str ):
336
- grand_margin [k ] = getattr (v , aggfunc )()
350
+ grand_margin [k ] = getattr (v , aggfunc )(** kwargs )
337
351
elif isinstance (aggfunc , dict ):
338
352
if isinstance (aggfunc [k ], str ):
339
- grand_margin [k ] = getattr (v , aggfunc [k ])()
353
+ grand_margin [k ] = getattr (v , aggfunc [k ])(** kwargs )
340
354
else :
341
- grand_margin [k ] = aggfunc [k ](v )
355
+ grand_margin [k ] = aggfunc [k ](v , ** kwargs )
342
356
else :
343
- grand_margin [k ] = aggfunc (v )
357
+ grand_margin [k ] = aggfunc (v , ** kwargs )
344
358
except TypeError :
345
359
pass
346
360
return grand_margin
347
361
else :
348
- return {margins_name : aggfunc (data .index )}
362
+ return {margins_name : aggfunc (data .index , ** kwargs )}
349
363
350
364
351
365
def _generate_marginal_results (
@@ -355,6 +369,7 @@ def _generate_marginal_results(
355
369
rows ,
356
370
cols ,
357
371
aggfunc ,
372
+ kwargs ,
358
373
observed : bool ,
359
374
margins_name : Hashable = "All" ,
360
375
):
@@ -368,7 +383,11 @@ def _all_key(key):
368
383
return (key , margins_name ) + ("" ,) * (len (cols ) - 1 )
369
384
370
385
if len (rows ) > 0 :
371
- margin = data [rows + values ].groupby (rows , observed = observed ).agg (aggfunc )
386
+ margin = (
387
+ data [rows + values ]
388
+ .groupby (rows , observed = observed )
389
+ .agg (aggfunc , ** kwargs )
390
+ )
372
391
cat_axis = 1
373
392
374
393
for key , piece in table .T .groupby (level = 0 , observed = observed ):
@@ -393,7 +412,7 @@ def _all_key(key):
393
412
table_pieces .append (piece )
394
413
# GH31016 this is to calculate margin for each group, and assign
395
414
# corresponded key as index
396
- transformed_piece = DataFrame (piece .apply (aggfunc )).T
415
+ transformed_piece = DataFrame (piece .apply (aggfunc , ** kwargs )).T
397
416
if isinstance (piece .index , MultiIndex ):
398
417
# We are adding an empty level
399
418
transformed_piece .index = MultiIndex .from_tuples (
@@ -423,7 +442,9 @@ def _all_key(key):
423
442
margin_keys = table .columns
424
443
425
444
if len (cols ) > 0 :
426
- row_margin = data [cols + values ].groupby (cols , observed = observed ).agg (aggfunc )
445
+ row_margin = (
446
+ data [cols + values ].groupby (cols , observed = observed ).agg (aggfunc , ** kwargs )
447
+ )
427
448
row_margin = row_margin .stack ()
428
449
429
450
# GH#26568. Use names instead of indices in case of numeric names
@@ -442,6 +463,7 @@ def _generate_marginal_results_without_values(
442
463
rows ,
443
464
cols ,
444
465
aggfunc ,
466
+ kwargs ,
445
467
observed : bool ,
446
468
margins_name : Hashable = "All" ,
447
469
):
@@ -456,14 +478,16 @@ def _all_key():
456
478
return (margins_name ,) + ("" ,) * (len (cols ) - 1 )
457
479
458
480
if len (rows ) > 0 :
459
- margin = data .groupby (rows , observed = observed )[rows ].apply (aggfunc )
481
+ margin = data .groupby (rows , observed = observed )[rows ].apply (
482
+ aggfunc , ** kwargs
483
+ )
460
484
all_key = _all_key ()
461
485
table [all_key ] = margin
462
486
result = table
463
487
margin_keys .append (all_key )
464
488
465
489
else :
466
- margin = data .groupby (level = 0 , observed = observed ).apply (aggfunc )
490
+ margin = data .groupby (level = 0 , observed = observed ).apply (aggfunc , ** kwargs )
467
491
all_key = _all_key ()
468
492
table [all_key ] = margin
469
493
result = table
@@ -474,7 +498,9 @@ def _all_key():
474
498
margin_keys = table .columns
475
499
476
500
if len (cols ):
477
- row_margin = data .groupby (cols , observed = observed )[cols ].apply (aggfunc )
501
+ row_margin = data .groupby (cols , observed = observed )[cols ].apply (
502
+ aggfunc , ** kwargs
503
+ )
478
504
else :
479
505
row_margin = Series (np .nan , index = result .columns )
480
506
0 commit comments