Skip to content

Commit 28cba66

Browse files
committed
Test for is_dtype_equal matching dtype.__eq__
1 parent 3e4a3ed commit 28cba66

File tree

6 files changed

+34
-3
lines changed

6 files changed

+34
-3
lines changed

pandas/core/arrays/categorical.py

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
is_scalar,
3737
is_timedelta64_dtype,
3838
needs_i8_conversion,
39+
pandas_dtype,
3940
)
4041
from pandas.core.dtypes.dtypes import CategoricalDtype
4142
from pandas.core.dtypes.generic import ABCIndex, ABCSeries
@@ -403,6 +404,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
403404
If copy is set to False and dtype is categorical, the original
404405
object is returned.
405406
"""
407+
dtype = pandas_dtype(dtype)
406408
if self.dtype is dtype:
407409
result = self.copy() if copy else self
408410

pandas/core/dtypes/common.py

+13
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,19 @@ def is_dtype_equal(source, target) -> bool:
639639
>>> is_dtype_equal(DatetimeTZDtype(tz="UTC"), "datetime64")
640640
False
641641
"""
642+
if isinstance(target, str):
643+
if not isinstance(source, str):
644+
# GH#38516 ensure we get the same behavior from
645+
# is_dtype_equal(CDT, "category") and CDT == "category"
646+
try:
647+
src = get_dtype(source)
648+
if isinstance(src, ExtensionDtype):
649+
return src == target
650+
except (TypeError, AttributeError):
651+
return False
652+
elif isinstance(source, str):
653+
return is_dtype_equal(target, source)
654+
642655
try:
643656
source = get_dtype(source)
644657
target = get_dtype(target)

pandas/core/dtypes/dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def __eq__(self, other: Any) -> bool:
355355
return False
356356
elif self.categories is None or other.categories is None:
357357
# For non-fully-initialized dtypes, these are only equal to
358-
# - the string "categorical" (handled above)
358+
# - the string "category" (handled above)
359359
# - other CategoricalDtype with categories=None
360360
return self.categories is other.categories
361361
elif self.ordered or other.ordered:

pandas/core/indexes/category.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,13 @@ def __new__(
201201

202202
if not isinstance(data, Categorical):
203203
data = Categorical(data, dtype=dtype)
204-
elif isinstance(dtype, CategoricalDtype) and dtype != data.dtype:
204+
elif (
205+
isinstance(dtype, CategoricalDtype)
206+
and dtype != data.dtype
207+
and dtype.categories is not None
208+
):
205209
# we want to silently ignore dtype='category'
210+
# TODO: what if dtype.ordered is not None but dtype.categories is?
206211
data = data._set_dtype(dtype)
207212

208213
data = data.copy() if copy else data

pandas/tests/arrays/categorical/test_dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_astype(self, ordered):
127127
expected = np.array(cat)
128128
tm.assert_numpy_array_equal(result, expected)
129129

130-
msg = r"Cannot cast object dtype to <class 'float'>"
130+
msg = r"Cannot cast object dtype to float64"
131131
with pytest.raises(ValueError, match=msg):
132132
cat.astype(float)
133133

pandas/tests/dtypes/test_dtypes.py

+11
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,20 @@ def test_hash_vs_equality(self, dtype):
9090
assert hash(dtype) == hash(dtype2)
9191

9292
def test_equality(self, dtype):
93+
assert dtype == "category"
9394
assert is_dtype_equal(dtype, "category")
95+
assert "category" == dtype
96+
assert is_dtype_equal("category", dtype)
97+
98+
assert dtype == CategoricalDtype()
9499
assert is_dtype_equal(dtype, CategoricalDtype())
100+
assert CategoricalDtype() == dtype
101+
assert is_dtype_equal(CategoricalDtype(), dtype)
102+
103+
assert dtype != "foo"
95104
assert not is_dtype_equal(dtype, "foo")
105+
assert "foo" != dtype
106+
assert not is_dtype_equal("foo", dtype)
96107

97108
def test_construction_from_string(self, dtype):
98109
result = CategoricalDtype.construct_from_string("category")

0 commit comments

Comments
 (0)