Skip to content

Commit f1aaf62

Browse files
charlesdong1991 authored and WillAyd committed
BUG: df.pivot_table fails when margin is True and only columns is defined (#31088)
1 parent cd20c95 commit f1aaf62

File tree

4 files changed

+77
-21
lines changed

4 files changed

+77
-21
lines changed

asv_bench/benchmarks/reshape.py

+3
Original file line number | Diff line number | Diff line change
@@ -161,6 +161,9 @@ def time_pivot_table_categorical_observed(self):
161161
observed=True,
162162
)
163163

164+
def time_pivot_table_margins_only_column(self):
165+
self.df.pivot_table(columns=["key2", "key3"], margins=True)
166+
164167

165168
class Crosstab:
166169
def setup(self):

doc/source/whatsnew/v1.1.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -141,6 +141,7 @@ Reshaping
141141

142142
-
143143
- Bug in :meth:`DataFrame.pivot_table` when only MultiIndexed columns is set (:issue:`17038`)
144+
- Bug in :meth:`DataFrame.pivot_table` when ``margins`` is ``True`` and only ``columns`` is defined (:issue:`31016`)
144145
- Fix incorrect error message in :meth:`DataFrame.pivot` when ``columns`` is set to ``None``. (:issue:`30924`)
145146
- Bug in :func:`crosstab` when inputs are two Series and have tuple names, the output will keep dummy MultiIndex as columns. (:issue:`18321`)
146147

pandas/core/reshape/pivot.py

+15-21
Original file line number | Diff line number | Diff line change
@@ -226,15 +226,7 @@ def _add_margins(
226226

227227
elif values:
228228
marginal_result_set = _generate_marginal_results(
229-
table,
230-
data,
231-
values,
232-
rows,
233-
cols,
234-
aggfunc,
235-
observed,
236-
grand_margin,
237-
margins_name,
229+
table, data, values, rows, cols, aggfunc, observed, margins_name,
238230
)
239231
if not isinstance(marginal_result_set, tuple):
240232
return marginal_result_set
@@ -303,15 +295,7 @@ def _compute_grand_margin(data, values, aggfunc, margins_name: str = "All"):
303295

304296

305297
def _generate_marginal_results(
306-
table,
307-
data,
308-
values,
309-
rows,
310-
cols,
311-
aggfunc,
312-
observed,
313-
grand_margin,
314-
margins_name: str = "All",
298+
table, data, values, rows, cols, aggfunc, observed, margins_name: str = "All",
315299
):
316300
if len(cols) > 0:
317301
# need to "interleave" the margins
@@ -345,12 +329,22 @@ def _all_key(key):
345329
table_pieces.append(piece)
346330
margin_keys.append(all_key)
347331
else:
348-
margin = grand_margin
332+
from pandas import DataFrame
333+
349334
cat_axis = 0
350335
for key, piece in table.groupby(level=0, axis=cat_axis, observed=observed):
351-
all_key = _all_key(key)
336+
if len(cols) > 1:
337+
all_key = _all_key(key)
338+
else:
339+
all_key = margins_name
352340
table_pieces.append(piece)
353-
table_pieces.append(Series(margin[key], index=[all_key]))
341+
# GH31016 this is to calculate margin for each group, and assign
342+
# corresponding key as index
343+
transformed_piece = DataFrame(piece.apply(aggfunc)).T
344+
transformed_piece.index = Index([all_key], name=piece.index.name)
345+
346+
# append the margin piece to table_pieces
347+
table_pieces.append(transformed_piece)
354348
margin_keys.append(all_key)
355349

356350
result = concat(table_pieces, axis=cat_axis)

pandas/tests/reshape/test_pivot.py

+58
Original file line number | Diff line number | Diff line change
@@ -910,6 +910,64 @@ def _check_output(
910910
totals = table.loc[("All", ""), item]
911911
assert totals == self.data[item].mean()
912912

913+
@pytest.mark.parametrize(
914+
"columns, aggfunc, values, expected_columns",
915+
[
916+
(
917+
"A",
918+
np.mean,
919+
[[5.5, 5.5, 2.2, 2.2], [8.0, 8.0, 4.4, 4.4]],
920+
Index(["bar", "All", "foo", "All"], name="A"),
921+
),
922+
(
923+
["A", "B"],
924+
"sum",
925+
[[9, 13, 22, 5, 6, 11], [14, 18, 32, 11, 11, 22]],
926+
MultiIndex.from_tuples(
927+
[
928+
("bar", "one"),
929+
("bar", "two"),
930+
("bar", "All"),
931+
("foo", "one"),
932+
("foo", "two"),
933+
("foo", "All"),
934+
],
935+
names=["A", "B"],
936+
),
937+
),
938+
],
939+
)
940+
def test_margin_with_only_columns_defined(
941+
self, columns, aggfunc, values, expected_columns
942+
):
943+
# GH 31016
944+
df = pd.DataFrame(
945+
{
946+
"A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"],
947+
"B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"],
948+
"C": [
949+
"small",
950+
"large",
951+
"large",
952+
"small",
953+
"small",
954+
"large",
955+
"small",
956+
"small",
957+
"large",
958+
],
959+
"D": [1, 2, 2, 3, 3, 4, 5, 6, 7],
960+
"E": [2, 4, 5, 5, 6, 6, 8, 9, 9],
961+
}
962+
)
963+
964+
result = df.pivot_table(columns=columns, margins=True, aggfunc=aggfunc)
965+
expected = pd.DataFrame(
966+
values, index=Index(["D", "E"]), columns=expected_columns
967+
)
968+
969+
tm.assert_frame_equal(result, expected)
970+
913971
def test_margins_dtype(self):
914972
# GH 17013
915973

0 commit comments

Comments
 (0)