Skip to content

Commit 7123154

Browse files
committed
BUG: pivot_table raising for nullable dtype and margins (pandas-dev#48714)
1 parent a746068 commit 7123154

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

doc/source/whatsnew/v1.6.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ Groupby/resample/rolling
239239

240240
Reshaping
241241
^^^^^^^^^
242+
- Bug in :meth:`DataFrame.pivot_table` raising ``TypeError`` for nullable dtype and ``margins=True`` (:issue:`48681`)
242243
- Bug in :meth:`DataFrame.pivot` not respecting ``None`` as column name (:issue:`48293`)
243244
- Bug in :func:`join` when ``left_on`` or ``right_on`` is or includes a :class:`CategoricalIndex` incorrectly raising ``AttributeError`` (:issue:`48464`)
244245
-

pandas/core/reshape/pivot.py

+5
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from pandas.core.dtypes.cast import maybe_downcast_to_dtype
2727
from pandas.core.dtypes.common import (
28+
is_extension_array_dtype,
2829
is_integer_dtype,
2930
is_list_like,
3031
is_nested_list_like,
@@ -324,6 +325,10 @@ def _add_margins(
324325
row_names = result.index.names
325326
# check the result column and leave floats
326327
for dtype in set(result.dtypes):
328+
if is_extension_array_dtype(dtype):
329+
# Can hold NA already
330+
continue
331+
327332
cols = result.select_dtypes([dtype]).columns
328333
margin_dummy[cols] = margin_dummy[cols].apply(
329334
maybe_downcast_to_dtype, args=(dtype,)

pandas/tests/reshape/test_pivot.py

+17
Original file line numberDiff line numberDiff line change
@@ -2183,6 +2183,23 @@ def test_pivot_table_sort_false(self):
21832183
)
21842184
tm.assert_frame_equal(result, expected)
21852185

2186+
def test_pivot_table_nullable_margins(self):
2187+
# GH#48681
2188+
df = DataFrame(
2189+
{"a": "A", "b": [1, 2], "sales": Series([10, 11], dtype="Int64")}
2190+
)
2191+
2192+
result = df.pivot_table(index="b", columns="a", margins=True, aggfunc="sum")
2193+
expected = DataFrame(
2194+
[[10, 10], [11, 11], [21, 21]],
2195+
index=Index([1, 2, "All"], name="b"),
2196+
columns=MultiIndex.from_tuples(
2197+
[("sales", "A"), ("sales", "All")], names=[None, "a"]
2198+
),
2199+
dtype="Int64",
2200+
)
2201+
tm.assert_frame_equal(result, expected)
2202+
21862203
def test_pivot_table_sort_false_with_multiple_values(self):
21872204
df = DataFrame(
21882205
{

0 commit comments

Comments
 (0)