Skip to content

Commit 2974e22

Browse files
authored
BUG: Correct crosstab for categorical inputs (#37468)
1 parent cd0f936 commit 2974e22

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

pandas/core/reshape/pivot.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,7 @@ def _all_key(key):
329329
piece = piece.copy()
330330
try:
331331
piece[all_key] = margin[key]
332-
except TypeError:
333-
332+
except ValueError:
334333
# we cannot reshape, so coerce the axis
335334
piece.set_axis(
336335
piece._get_axis(cat_axis)._to_safe_for_reshape(),

pandas/tests/reshape/test_crosstab.py

+32
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
import pytest
33

4+
from pandas.core.dtypes.common import is_categorical_dtype
5+
46
from pandas import CategoricalIndex, DataFrame, Index, MultiIndex, Series, crosstab
57
import pandas._testing as tm
68

@@ -743,3 +745,33 @@ def test_margin_normalize_multiple_columns(self):
743745
)
744746
expected.index.name = "C"
745747
tm.assert_frame_equal(result, expected)
748+
749+
750+
@pytest.mark.parametrize("a_dtype", ["category", "int64"])
@pytest.mark.parametrize("b_dtype", ["category", "int64"])
def test_categoricals(a_dtype, b_dtype):
    """Check ``crosstab(..., margins=True, dropna=False)`` for categorical/int inputs.

    Two seeded random integer Series are cast to every combination of
    ``"category"`` and ``"int64"`` and cross-tabulated with margins.  The
    check is then repeated after value ``1`` has been removed from ``a``, so
    that a categorical ``a`` carries a category with no observations.

    Parameters
    ----------
    a_dtype : str
        Dtype the row factor is cast to (``"category"`` or ``"int64"``).
    b_dtype : str
        Dtype the column factor is cast to (``"category"`` or ``"int64"``).
    """
    # https://github.com/pandas-dev/pandas/issues/37465
    g = np.random.RandomState(25982704)  # fixed seed -> hard-coded counts below
    a = Series(g.randint(0, 3, size=100)).astype(a_dtype)
    b = Series(g.randint(0, 2, size=100)).astype(b_dtype)

    result = crosstab(a, b, margins=True, dropna=False)
    columns = Index([0, 1, "All"], dtype="object", name="col_0")
    index = Index([0, 1, 2, "All"], dtype="object", name="row_0")
    values = [[18, 16, 34], [18, 16, 34], [16, 16, 32], [52, 48, 100]]
    expected = DataFrame(values, index, columns)
    tm.assert_frame_equal(result, expected)

    # Verify when categorical does not have all values present
    a.loc[a == 1] = 2
    a_is_cat = is_categorical_dtype(a.dtype)
    # For a categorical ``a`` the category 1 must still exist, just unobserved.
    assert not a_is_cat or a.value_counts().loc[1] == 0

    result = crosstab(a, b, margins=True, dropna=False)
    values = [[18, 16, 34], [0, 0, np.nan], [34, 32, 66], [52, 48, 100]]
    expected = DataFrame(values, index, columns)
    if not a_is_cat:
        # With a plain integer ``a`` there is no row for the missing value 1.
        expected = expected.loc[[0, 2, "All"]]
        # The NaN in the dropped row made "All" float; restore the int dtype.
        expected["All"] = expected["All"].astype("int64")
    tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)