diff --git a/pandas/core/reshape/pivot.py b/pandas/core/reshape/pivot.py
index 842a42f80e1b7..8fae01cb30d3d 100644
--- a/pandas/core/reshape/pivot.py
+++ b/pandas/core/reshape/pivot.py
@@ -329,8 +329,7 @@ def _all_key(key):
                 piece = piece.copy()
                 try:
                     piece[all_key] = margin[key]
-                except TypeError:
-
+                except ValueError:
                     # we cannot reshape, so coerce the axis
                     piece.set_axis(
                         piece._get_axis(cat_axis)._to_safe_for_reshape(),
diff --git a/pandas/tests/reshape/test_crosstab.py b/pandas/tests/reshape/test_crosstab.py
index 1aadcfdc30f1b..5f6037276b31c 100644
--- a/pandas/tests/reshape/test_crosstab.py
+++ b/pandas/tests/reshape/test_crosstab.py
@@ -1,6 +1,8 @@
 import numpy as np
 import pytest
 
+from pandas.core.dtypes.common import is_categorical_dtype
+
 from pandas import CategoricalIndex, DataFrame, Index, MultiIndex, Series, crosstab
 import pandas._testing as tm
 
@@ -743,3 +745,33 @@ def test_margin_normalize_multiple_columns(self):
         )
         expected.index.name = "C"
         tm.assert_frame_equal(result, expected)
+
+
+@pytest.mark.parametrize("a_dtype", ["category", "int64"])
+@pytest.mark.parametrize("b_dtype", ["category", "int64"])
+def test_categoricals(a_dtype, b_dtype):
+    # https://github.com/pandas-dev/pandas/issues/37465
+    # crosstab(..., margins=True) over every category/int64 input pairing
+    g = np.random.RandomState(25982704)
+    a = Series(g.randint(0, 3, size=100)).astype(a_dtype)
+    b = Series(g.randint(0, 2, size=100)).astype(b_dtype)
+    result = crosstab(a, b, margins=True, dropna=False)
+    columns = Index([0, 1, "All"], dtype="object", name="col_0")
+    index = Index([0, 1, 2, "All"], dtype="object", name="row_0")
+    values = [[18, 16, 34], [18, 16, 34], [16, 16, 32], [52, 48, 100]]
+    expected = DataFrame(values, index, columns)
+    tm.assert_frame_equal(result, expected)
+
+    # Verify when categorical does not have all values present
+    a.loc[a == 1] = 2
+    a_is_cat = is_categorical_dtype(a.dtype)
+    assert not a_is_cat or a.value_counts().loc[1] == 0
+    result = crosstab(a, b, margins=True, dropna=False)
+    values = [[18, 16, 34], [0, 0, np.nan], [34, 32, 66], [52, 48, 100]]
+    expected = DataFrame(values, index, columns)
+    if not a_is_cat:
+        # Non-categorical input has no row for the now-absent value 1;
+        # without that row's NaN the "All" column is expected as int64.
+        expected = expected.loc[[0, 2, "All"]]
+        expected["All"] = expected["All"].astype("int64")
+    tm.assert_frame_equal(result, expected)