Skip to content

Commit 2bf68eb

Browse files
authored
API: CategoricalDtype.__eq__ with categories=None stricter (#38516)
1 parent 12c58e1 commit 2bf68eb

File tree

8 files changed

+62
-14
lines changed

8 files changed

+62
-14
lines changed

doc/source/whatsnew/v1.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ See :ref:`install.dependencies` and :ref:`install.optional_dependencies` for mor
133133

134134
Other API changes
135135
^^^^^^^^^^^^^^^^^
136-
136+
- Partially initialized :class:`CategoricalDtype` objects (i.e. those with ``categories=None``) will no longer compare as equal to fully initialized dtype objects.
137137
-
138138
-
139139

pandas/core/arrays/categorical.py

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
is_scalar,
3737
is_timedelta64_dtype,
3838
needs_i8_conversion,
39+
pandas_dtype,
3940
)
4041
from pandas.core.dtypes.dtypes import CategoricalDtype
4142
from pandas.core.dtypes.generic import ABCIndex, ABCSeries
@@ -409,6 +410,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
409410
If copy is set to False and dtype is categorical, the original
410411
object is returned.
411412
"""
413+
dtype = pandas_dtype(dtype)
412414
if self.dtype is dtype:
413415
result = self.copy() if copy else self
414416

pandas/core/dtypes/common.py

+13
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,19 @@ def is_dtype_equal(source, target) -> bool:
639639
>>> is_dtype_equal(DatetimeTZDtype(tz="UTC"), "datetime64")
640640
False
641641
"""
642+
if isinstance(target, str):
643+
if not isinstance(source, str):
644+
# GH#38516 ensure we get the same behavior from
645+
# is_dtype_equal(CDT, "category") and CDT == "category"
646+
try:
647+
src = get_dtype(source)
648+
if isinstance(src, ExtensionDtype):
649+
return src == target
650+
except (TypeError, AttributeError):
651+
return False
652+
elif isinstance(source, str):
653+
return is_dtype_equal(target, source)
654+
642655
try:
643656
source = get_dtype(source)
644657
target = get_dtype(target)

pandas/core/dtypes/dtypes.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -354,12 +354,10 @@ def __eq__(self, other: Any) -> bool:
354354
elif not (hasattr(other, "ordered") and hasattr(other, "categories")):
355355
return False
356356
elif self.categories is None or other.categories is None:
357-
# We're forced into a suboptimal corner thanks to math and
358-
# backwards compatibility. We require that `CDT(...) == 'category'`
359-
# for all CDTs **including** `CDT(None, ...)`. Therefore, *all*
360-
# CDT(., .) = CDT(None, False) and *all*
361-
# CDT(., .) = CDT(None, True).
362-
return True
357+
# For non-fully-initialized dtypes, these are only equal to
358+
# - the string "category" (handled above)
359+
# - other CategoricalDtype with categories=None
360+
return self.categories is other.categories
363361
elif self.ordered or other.ordered:
364362
# At least one has ordered=True; equal if both have ordered=True
365363
# and the same values for categories in the same order.

pandas/tests/arrays/categorical/test_dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_astype(self, ordered):
127127
expected = np.array(cat)
128128
tm.assert_numpy_array_equal(result, expected)
129129

130-
msg = r"Cannot cast object dtype to <class 'float'>"
130+
msg = r"Cannot cast object dtype to float64"
131131
with pytest.raises(ValueError, match=msg):
132132
cat.astype(float)
133133

pandas/tests/dtypes/test_common.py

-1
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,6 @@ def test_is_complex_dtype():
641641
(pd.CategoricalIndex(["a", "b"]).dtype, CategoricalDtype(["a", "b"])),
642642
(pd.CategoricalIndex(["a", "b"]), CategoricalDtype(["a", "b"])),
643643
(CategoricalDtype(), CategoricalDtype()),
644-
(CategoricalDtype(["a", "b"]), CategoricalDtype()),
645644
(pd.DatetimeIndex([1, 2]), np.dtype("=M8[ns]")),
646645
(pd.DatetimeIndex([1, 2]).dtype, np.dtype("=M8[ns]")),
647646
("<M8[ns]", np.dtype("<M8[ns]")),

pandas/tests/dtypes/test_dtypes.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,20 @@ def test_hash_vs_equality(self, dtype):
9090
assert hash(dtype) == hash(dtype2)
9191

9292
def test_equality(self, dtype):
93+
assert dtype == "category"
9394
assert is_dtype_equal(dtype, "category")
95+
assert "category" == dtype
96+
assert is_dtype_equal("category", dtype)
97+
98+
assert dtype == CategoricalDtype()
9499
assert is_dtype_equal(dtype, CategoricalDtype())
100+
assert CategoricalDtype() == dtype
101+
assert is_dtype_equal(CategoricalDtype(), dtype)
102+
103+
assert dtype != "foo"
95104
assert not is_dtype_equal(dtype, "foo")
105+
assert "foo" != dtype
106+
assert not is_dtype_equal("foo", dtype)
96107

97108
def test_construction_from_string(self, dtype):
98109
result = CategoricalDtype.construct_from_string("category")
@@ -834,10 +845,27 @@ def test_categorical_equality(self, ordered1, ordered2):
834845
c1 = CategoricalDtype(list("abc"), ordered1)
835846
c2 = CategoricalDtype(None, ordered2)
836847
c3 = CategoricalDtype(None, ordered1)
837-
assert c1 == c2
838-
assert c2 == c1
848+
assert c1 != c2
849+
assert c2 != c1
839850
assert c2 == c3
840851

852+
def test_categorical_dtype_equality_requires_categories(self):
853+
# CategoricalDtype with categories=None is *not* equal to
854+
# any fully-initialized CategoricalDtype
855+
first = CategoricalDtype(["a", "b"])
856+
second = CategoricalDtype()
857+
third = CategoricalDtype(ordered=True)
858+
859+
assert second == second
860+
assert third == third
861+
862+
assert first != second
863+
assert second != first
864+
assert first != third
865+
assert third != first
866+
assert second == third
867+
assert third == second
868+
841869
@pytest.mark.parametrize("categories", [list("abc"), None])
842870
@pytest.mark.parametrize("other", ["category", "not a category"])
843871
def test_categorical_equality_strings(self, categories, ordered, other):

pandas/tests/reshape/merge/test_merge.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -1622,7 +1622,7 @@ def test_identical(self, left):
16221622
merged = pd.merge(left, left, on="X")
16231623
result = merged.dtypes.sort_index()
16241624
expected = Series(
1625-
[CategoricalDtype(), np.dtype("O"), np.dtype("O")],
1625+
[CategoricalDtype(categories=["foo", "bar"]), np.dtype("O"), np.dtype("O")],
16261626
index=["X", "Y_x", "Y_y"],
16271627
)
16281628
tm.assert_series_equal(result, expected)
@@ -1633,7 +1633,11 @@ def test_basic(self, left, right):
16331633
merged = pd.merge(left, right, on="X")
16341634
result = merged.dtypes.sort_index()
16351635
expected = Series(
1636-
[CategoricalDtype(), np.dtype("O"), np.dtype("int64")],
1636+
[
1637+
CategoricalDtype(categories=["foo", "bar"]),
1638+
np.dtype("O"),
1639+
np.dtype("int64"),
1640+
],
16371641
index=["X", "Y", "Z"],
16381642
)
16391643
tm.assert_series_equal(result, expected)
@@ -1713,7 +1717,11 @@ def test_other_columns(self, left, right):
17131717
merged = pd.merge(left, right, on="X")
17141718
result = merged.dtypes.sort_index()
17151719
expected = Series(
1716-
[CategoricalDtype(), np.dtype("O"), CategoricalDtype()],
1720+
[
1721+
CategoricalDtype(categories=["foo", "bar"]),
1722+
np.dtype("O"),
1723+
CategoricalDtype(categories=[1, 2]),
1724+
],
17171725
index=["X", "Y", "Z"],
17181726
)
17191727
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)