Skip to content

BUG: Prevent erroring out when comparing unordered categories with different permutations #51678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,7 @@ Categorical
- Bug in :meth:`DataFrame.groupby` and :meth:`Series.groupby` would reorder categories when used as a grouper (:issue:`48749`)
- Bug in :class:`Categorical` constructor when constructing from a :class:`Categorical` object and ``dtype="category"`` losing ordered-ness (:issue:`49309`)
- Bug in :meth:`.SeriesGroupBy.min`, :meth:`.SeriesGroupBy.max`, :meth:`.DataFrameGroupBy.min`, and :meth:`.DataFrameGroupBy.max` with unordered :class:`CategoricalDtype` with no groups failing to raise ``TypeError`` (:issue:`51034`)
- Bug in :class:`Categorical` where comparing unordered categoricals whose categories are the same but listed in a different order would raise an error (:issue:`51543`)

Datetimelike
^^^^^^^^^^^^
Expand Down
34 changes: 24 additions & 10 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,11 @@ def func(self, other):
# Two Categoricals can only be compared if the categories are
# the same (maybe up to ordering, depending on ordered)

msg = "Categoricals can only be compared if 'categories' are the same."
if not self._categories_match_up_to_permutation(other):
msg = (
"Categoricals can only be compared if 'categories' and 'ordered' "
"are the same."
)
if not self._categories_match(other):
raise TypeError(msg)

if not self.ordered and not self.categories.equals(other.categories):
Expand Down Expand Up @@ -1961,12 +1964,11 @@ def _validate_listlike(self, value):

# require identical categories set
if isinstance(value, Categorical):
if not is_dtype_equal(self.dtype, value.dtype):
if not self._categories_match(value):
raise TypeError(
"Cannot set a Categorical with another, "
"without identical categories"
)
# is_dtype_equal implies categories_match_up_to_permutation
value = self._encode_with_my_categories(value)
return value._codes

Expand Down Expand Up @@ -2154,7 +2156,7 @@ def equals(self, other: object) -> bool:
"""
if not isinstance(other, Categorical):
return False
elif self._categories_match_up_to_permutation(other):
elif self._categories_match(other):
other = self._encode_with_my_categories(other)
return np.array_equal(self._codes, other._codes)
return False
Expand Down Expand Up @@ -2196,7 +2198,7 @@ def _encode_with_my_categories(self, other: Categorical) -> Categorical:
Notes
-----
This assumes we have already checked
self._categories_match_up_to_permutation(other).
self._categories_match(other).
"""
# Indexing on codes is more efficient if categories are the same,
# so we can apply some optimizations based on the degree of
Expand All @@ -2206,10 +2208,11 @@ def _encode_with_my_categories(self, other: Categorical) -> Categorical:
)
return self._from_backing_data(codes)

def _categories_match_up_to_permutation(self, other: Categorical) -> bool:
def _categories_match(self, other: Categorical) -> bool:
"""
Returns True if categoricals are the same dtype
same categories, and same ordered
Return True if both categoricals have the same ``ordered`` flag and
the same set of categories. The order of the categories is ignored
when unordered and must match exactly when ordered.

Parameters
----------
Expand All @@ -2219,7 +2222,18 @@ def _categories_match_up_to_permutation(self, other: Categorical) -> bool:
-------
bool
"""
return hash(self.dtype) == hash(other.dtype)
try:
if not hasattr(other, "categories"):
other = other.dtype # type: ignore[assignment]
if set(self.categories) != set(other.categories):
return False
if self.ordered != other.ordered:
return False
if self.ordered and not (self.categories == other.categories).all():
return False
except AttributeError:
return False
return True

def describe(self) -> DataFrame:
"""
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/dtypes/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _maybe_unwrap(x):
raise TypeError("dtype of categories must be the same")

ordered = False
if all(first._categories_match_up_to_permutation(other) for other in to_union[1:]):
if all(first._categories_match(other) for other in to_union[1:]):
# identical categories - fastpath
categories = first.categories
ordered = first.ordered
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _is_dtype_compat(self, other) -> Categorical:
"""
if is_categorical_dtype(other):
other = extract_array(other)
if not other._categories_match_up_to_permutation(self):
if not other._categories_match(self):
raise TypeError(
"categories must match existing categories when appending"
)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ def _maybe_coerce_merge_keys(self) -> None:
if lk_is_cat and rk_is_cat:
lk = cast(Categorical, lk)
rk = cast(Categorical, rk)
if lk._categories_match_up_to_permutation(rk):
if lk._categories_match(rk):
continue

elif lk_is_cat or rk_is_cat:
Expand Down
38 changes: 17 additions & 21 deletions pandas/tests/arrays/categorical/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,33 @@


class TestCategoricalDtypes:
def test_categories_match_up_to_permutation(self):
def test_categories_match(self):
# test dtype comparisons between cats

c1 = Categorical(list("aabca"), categories=list("abc"), ordered=False)
c2 = Categorical(list("aabca"), categories=list("cab"), ordered=False)
c3 = Categorical(list("aabca"), categories=list("cab"), ordered=True)
assert c1._categories_match_up_to_permutation(c1)
assert c2._categories_match_up_to_permutation(c2)
assert c3._categories_match_up_to_permutation(c3)
assert c1._categories_match_up_to_permutation(c2)
assert not c1._categories_match_up_to_permutation(c3)
assert not c1._categories_match_up_to_permutation(Index(list("aabca")))
assert not c1._categories_match_up_to_permutation(c1.astype(object))
assert c1._categories_match_up_to_permutation(CategoricalIndex(c1))
assert c1._categories_match_up_to_permutation(
CategoricalIndex(c1, categories=list("cab"))
)
assert not c1._categories_match_up_to_permutation(
CategoricalIndex(c1, ordered=True)
)
assert c1._categories_match(c1)
assert c2._categories_match(c2)
assert c3._categories_match(c3)
assert c1._categories_match(c2)
assert not c1._categories_match(c3)
assert not c1._categories_match(Index(list("aabca")))
assert not c1._categories_match(c1.astype(object))
assert c1._categories_match(CategoricalIndex(c1))
assert c1._categories_match(CategoricalIndex(c1, categories=list("cab")))
assert not c1._categories_match(CategoricalIndex(c1, ordered=True))

# GH 16659
s1 = Series(c1)
s2 = Series(c2)
s3 = Series(c3)
assert c1._categories_match_up_to_permutation(s1)
assert c2._categories_match_up_to_permutation(s2)
assert c3._categories_match_up_to_permutation(s3)
assert c1._categories_match_up_to_permutation(s2)
assert not c1._categories_match_up_to_permutation(s3)
assert not c1._categories_match_up_to_permutation(s1.astype(object))
assert c1._categories_match(s1)
assert c2._categories_match(s2)
assert c3._categories_match(s3)
assert c1._categories_match(s2)
assert not c1._categories_match(s3)
assert not c1._categories_match(s1.astype(object))

def test_set_dtype_same(self):
c = Categorical(["a", "b", "c"])
Expand Down
15 changes: 12 additions & 3 deletions pandas/tests/arrays/categorical/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ def test_comparisons(self, factor):
tm.assert_numpy_array_equal(res, exp)

# Only categoricals with the same categories and orderedness can be compared
msg = "Categoricals can only be compared if 'categories' are the same"
msg = (
"Categoricals can only be compared if 'categories' and 'ordered' "
"are the same"
)
with pytest.raises(TypeError, match=msg):
cat > cat_rev

Expand Down Expand Up @@ -267,7 +270,10 @@ def test_comparisons(self, data, reverse, base):
tm.assert_numpy_array_equal(res_rev.values, exp_rev2)

# Only categoricals with the same categories and orderedness can be compared
msg = "Categoricals can only be compared if 'categories' are the same"
msg = (
"Categoricals can only be compared if 'categories' and 'ordered' "
"are the same"
)
with pytest.raises(TypeError, match=msg):
cat > cat_rev

Expand Down Expand Up @@ -333,7 +339,10 @@ def test_compare_different_lengths(self):
c1 = Categorical([], categories=["a", "b"])
c2 = Categorical([], categories=["a"])

msg = "Categoricals can only be compared if 'categories' are the same."
msg = (
"Categoricals can only be compared if 'categories' and 'ordered' "
"are the same."
)
with pytest.raises(TypeError, match=msg):
c1 == c2

Expand Down
11 changes: 10 additions & 1 deletion pandas/tests/indexes/categorical/test_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,23 @@ def test_equals_categorical(self):
with pytest.raises(ValueError, match="Lengths must match"):
ci1 == Index(["a", "b", "c"])

msg = "Categoricals can only be compared if 'categories' are the same"
msg = (
"Categoricals can only be compared if 'categories' and 'ordered' "
"are the same"
)
with pytest.raises(TypeError, match=msg):
ci1 == ci2
with pytest.raises(TypeError, match=msg):
ci1 == Categorical(ci1.values, ordered=False)
with pytest.raises(TypeError, match=msg):
ci1 == Categorical(ci1.values, categories=list("abc"))

ci1 = CategoricalIndex(["a", "b", 3], categories=["a", "b", 3])
ci2 = CategoricalIndex(["a", "b", 3], categories=["b", "a", 3])

assert ci1.equals(ci2)
assert ci1.astype(object).equals(ci2)

# tests
# make sure that we are testing for category inclusion properly
ci = CategoricalIndex(list("aabca"), categories=["c", "a", "b"])
Expand Down
6 changes: 3 additions & 3 deletions pandas/tests/reshape/merge/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1965,8 +1965,8 @@ def test_other_columns(self, left, right):
tm.assert_series_equal(result, expected)

# categories are preserved
assert left.X.values._categories_match_up_to_permutation(merged.X.values)
assert right.Z.values._categories_match_up_to_permutation(merged.Z.values)
assert left.X.values._categories_match(merged.X.values)
assert right.Z.values._categories_match(merged.Z.values)

@pytest.mark.parametrize(
"change",
Expand All @@ -1983,7 +1983,7 @@ def test_dtype_on_merged_different(self, change, join_type, left, right):
X = change(right.X.astype("object"))
right = right.assign(X=X)
assert is_categorical_dtype(left.X.values.dtype)
# assert not left.X.values._categories_match_up_to_permutation(right.X.values)
assert not left.X.values._categories_match(right.X.values)

merged = merge(left, right, on="X", how=join_type)

Expand Down