Skip to content

Commit ed4842b

Browse files
committed
BUG: Correct crosstab for categorical inputs
Change catch types to reflect error changes. Closes #37465
1 parent 927c83c commit ed4842b

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

pandas/core/reshape/pivot.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,7 @@ def _all_key(key):
329329
piece = piece.copy()
330330
try:
331331
piece[all_key] = margin[key]
332-
except TypeError:
333-
332+
except ValueError:
334333
# we cannot reshape, so coerce the axis
335334
piece.set_axis(
336335
piece._get_axis(cat_axis)._to_safe_for_reshape(),

pandas/tests/reshape/test_crosstab.py

+31
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from pandas import CategoricalIndex, DataFrame, Index, MultiIndex, Series, crosstab
55
import pandas._testing as tm
6+
from pandas.core.dtypes.common import is_categorical_dtype
67

78

89
class TestCrosstab:
@@ -743,3 +744,33 @@ def test_margin_normalize_multiple_columns(self):
743744
)
744745
expected.index.name = "C"
745746
tm.assert_frame_equal(result, expected)
747+
748+
749+
@pytest.mark.parametrize("a_dtype", ["category", "int64"])
@pytest.mark.parametrize("b_dtype", ["category", "int64"])
def test_categoricals(a_dtype, b_dtype):
    """Check ``crosstab`` with margins for categorical and int64 inputs.

    Runs once for every combination of ``a_dtype`` and ``b_dtype``. The
    table is compared against fixed expected counts twice: first with all
    three values of ``a`` present, then after every ``1`` in ``a`` has been
    replaced by ``2``.

    Parameters
    ----------
    a_dtype : str
        Dtype the row series is cast to ("category" or "int64").
    b_dtype : str
        Dtype the column series is cast to ("category" or "int64").
    """
    # https://github.com/pandas-dev/pandas/issues/37465
    # Fixed seed so the hard-coded expected counts below are reproducible.
    g = np.random.RandomState(25982704)
    a = Series(g.randint(0, 3, size=100)).astype(a_dtype)
    b = Series(g.randint(0, 2, size=100)).astype(b_dtype)
    result = crosstab(a, b, margins=True, dropna=False)
    columns = Index([0, 1, "All"], dtype="object", name="col_0")
    index = Index([0, 1, 2, "All"], dtype="object", name="row_0")
    values = [[18, 16, 34], [18, 16, 34], [16, 16, 32], [52, 48, 100]]
    expected = DataFrame(values, index, columns)
    tm.assert_frame_equal(result, expected)

    # Verify when categorical does not have all values present
    a.loc[a == 1] = 2
    a_is_cat = is_categorical_dtype(a.dtype)
    # For categorical input, 1 must still be listed with a zero count.
    assert not a_is_cat or a.value_counts().loc[1] == 0
    result = crosstab(a, b, margins=True, dropna=False)
    values = [[18, 16, 34], [0, 0, np.nan], [34, 32, 66], [52, 48, 100]]
    expected = DataFrame(values, index, columns)
    if not a_is_cat:
        # Non-categorical input has no row for the now-absent value 1, and
        # without the NaN the "All" column is cast back to int64.
        expected = expected.loc[[0, 2, "All"]]
        expected["All"] = expected["All"].astype("int64")
    # Build the string representations without writing debug output to
    # stdout (these were bare print() calls).
    repr(result)
    repr(expected)
    repr(expected.loc[[0, 2, "All"]])
    tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)