From df64bcd57e014ea0bda5be18278ad887c8763f2e Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 15 Dec 2020 18:33:23 -0800 Subject: [PATCH 1/4] API: CategoricalDtype.__eq__ with categories=None stricter --- pandas/core/dtypes/dtypes.py | 10 ++++------ pandas/tests/dtypes/cast/test_infer_dtype.py | 5 +++-- pandas/tests/dtypes/test_common.py | 1 - pandas/tests/dtypes/test_dtypes.py | 21 ++++++++++++++++++-- pandas/tests/reshape/merge/test_merge.py | 14 ++++++++++--- 5 files changed, 37 insertions(+), 14 deletions(-) diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 0de8a07abbec3..1a4a836e96daa 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -354,12 +354,10 @@ def __eq__(self, other: Any) -> bool: elif not (hasattr(other, "ordered") and hasattr(other, "categories")): return False elif self.categories is None or other.categories is None: - # We're forced into a suboptimal corner thanks to math and - # backwards compatibility. We require that `CDT(...) == 'category'` - # for all CDTs **including** `CDT(None, ...)`. Therefore, *all* - # CDT(., .) = CDT(None, False) and *all* - # CDT(., .) = CDT(None, True). - return True + # For non-fully-initialized dtypes, these are only equal to + # - the string "categorical" (handled above) + # - other CategoricalDtype with categories=None + return self.categories is other.categories elif self.ordered or other.ordered: # At least one has ordered=True; equal if both have ordered=True # and the same values for categories in the same order. diff --git a/pandas/tests/dtypes/cast/test_infer_dtype.py b/pandas/tests/dtypes/cast/test_infer_dtype.py index 65da8985843f9..c21dd90f7c72b 100644 --- a/pandas/tests/dtypes/cast/test_infer_dtype.py +++ b/pandas/tests/dtypes/cast/test_infer_dtype.py @@ -8,6 +8,7 @@ from pandas import ( Categorical, + CategoricalDtype, Interval, Period, Series, @@ -149,8 +150,8 @@ def test_infer_dtype_from_scalar_errors(): (np.array([[1.0, 2.0]]), np.float_, False), (Categorical(list("aabc")), np.object_, False), (Categorical([1, 2, 3]), np.int64, False), - (Categorical(list("aabc")), "category", True), - (Categorical([1, 2, 3]), "category", True), + (Categorical(list("aabc")), CategoricalDtype(categories=["a", "b", "c"]), True), + (Categorical([1, 2, 3]), CategoricalDtype(categories=[1, 2, 3]), True), (Timestamp("20160101"), np.object_, False), (np.datetime64("2016-01-01"), np.dtype("=M8[D]"), False), (date_range("20160101", periods=3), np.dtype("=M8[ns]"), False), diff --git a/pandas/tests/dtypes/test_common.py b/pandas/tests/dtypes/test_common.py index 0d0601aa542b4..9e75ba0864e76 100644 --- a/pandas/tests/dtypes/test_common.py +++ b/pandas/tests/dtypes/test_common.py @@ -641,7 +641,6 @@ def test_is_complex_dtype(): (pd.CategoricalIndex(["a", "b"]).dtype, CategoricalDtype(["a", "b"])), (pd.CategoricalIndex(["a", "b"]), CategoricalDtype(["a", "b"])), (CategoricalDtype(), CategoricalDtype()), - (CategoricalDtype(["a", "b"]), CategoricalDtype()), (pd.DatetimeIndex([1, 2]), np.dtype("=M8[ns]")), (pd.DatetimeIndex([1, 2]).dtype, np.dtype("=M8[ns]")), (" Date: Wed, 16 Dec 2020 17:24:55 -0800 Subject: [PATCH 2/4] Test for is_dtype_equal matching dtype.__eq__ --- pandas/core/arrays/categorical.py | 2 ++ pandas/core/dtypes/common.py | 13 +++++++++++++ pandas/core/dtypes/dtypes.py | 2 +- pandas/core/indexes/category.py | 7 ++++++- pandas/tests/arrays/categorical/test_dtypes.py | 2 +- pandas/tests/dtypes/test_dtypes.py | 11 +++++++++++ 6 files changed, 34 insertions(+), 3 deletions(-) diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 27110fe1f8439..5a418ec908d12 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -36,6 +36,7 @@ is_scalar, is_timedelta64_dtype, needs_i8_conversion, + pandas_dtype, ) from pandas.core.dtypes.dtypes import CategoricalDtype from pandas.core.dtypes.generic import ABCIndex, ABCSeries @@ -403,6 +404,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike: If copy is set to False and dtype is categorical, the original object is returned. """ + dtype = pandas_dtype(dtype) if self.dtype is dtype: result = self.copy() if copy else self diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 081339583e3fd..5869b2cf22516 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -639,6 +639,19 @@ def is_dtype_equal(source, target) -> bool: >>> is_dtype_equal(DatetimeTZDtype(tz="UTC"), "datetime64") False """ + if isinstance(target, str): + if not isinstance(source, str): + # GH#38516 ensure we get the same behavior from + # is_dtype_equal(CDT, "category") and CDT == "category" + try: + src = get_dtype(source) + if isinstance(src, ExtensionDtype): + return src == target + except (TypeError, AttributeError): + return False + elif isinstance(source, str): + return is_dtype_equal(target, source) + try: source = get_dtype(source) target = get_dtype(target) diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 1a4a836e96daa..75f3b511bc57d 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -355,7 +355,7 @@ def __eq__(self, other: Any) -> bool: return False elif self.categories is None or other.categories is None: # For non-fully-initialized dtypes, these are only equal to - # - the string "categorical" (handled above) + # - the string "category" (handled above) # - other CategoricalDtype with categories=None return self.categories is other.categories elif self.ordered or other.ordered: diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index e2a7752cf3f0d..35fd8af9cd36e 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -201,8 +201,13 @@ def __new__( if not isinstance(data, Categorical): data = Categorical(data, dtype=dtype) - elif isinstance(dtype, CategoricalDtype) and dtype != data.dtype: + elif ( + isinstance(dtype, CategoricalDtype) + and dtype != data.dtype + and dtype.categories is not None + ): # we want to silently ignore dtype='category' + # TODO: what if dtype.ordered is not None but dtype.categories is? data = data._set_dtype(dtype) data = data.copy() if copy else data diff --git a/pandas/tests/arrays/categorical/test_dtypes.py b/pandas/tests/arrays/categorical/test_dtypes.py index 12654388de904..a2192b2810596 100644 --- a/pandas/tests/arrays/categorical/test_dtypes.py +++ b/pandas/tests/arrays/categorical/test_dtypes.py @@ -127,7 +127,7 @@ def test_astype(self, ordered): expected = np.array(cat) tm.assert_numpy_array_equal(result, expected) - msg = r"Cannot cast object dtype to " + msg = r"Cannot cast object dtype to float64" with pytest.raises(ValueError, match=msg): cat.astype(float) diff --git a/pandas/tests/dtypes/test_dtypes.py b/pandas/tests/dtypes/test_dtypes.py index 1afe389e86668..8ba8562affb67 100644 --- a/pandas/tests/dtypes/test_dtypes.py +++ b/pandas/tests/dtypes/test_dtypes.py @@ -90,9 +90,20 @@ def test_hash_vs_equality(self, dtype): assert hash(dtype) == hash(dtype2) def test_equality(self, dtype): + assert dtype == "category" assert is_dtype_equal(dtype, "category") + assert "category" == dtype + assert is_dtype_equal("category", dtype) + + assert dtype == CategoricalDtype() assert is_dtype_equal(dtype, CategoricalDtype()) + assert CategoricalDtype() == dtype + assert is_dtype_equal(CategoricalDtype(), dtype) + + assert dtype != "foo" assert not is_dtype_equal(dtype, "foo") + assert "foo" != dtype + assert not is_dtype_equal("foo", dtype) def test_construction_from_string(self, dtype): result = CategoricalDtype.construct_from_string("category") From c78aa7794627d42afe3fec99fbaf9eed68b9f497 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 16 Dec 2020 14:27:05 -0800 Subject: [PATCH 3/4] REF: avoid special-casing Categorical astype --- pandas/core/arrays/integer.py | 7 ++++++- pandas/core/indexes/category.py | 5 ----- pandas/core/indexes/extension.py | 11 +++++++---- pandas/core/internals/blocks.py | 10 +--------- 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/pandas/core/arrays/integer.py b/pandas/core/arrays/integer.py index 98e29f2062983..0d62051c116d1 100644 --- a/pandas/core/arrays/integer.py +++ b/pandas/core/arrays/integer.py @@ -9,7 +9,7 @@ from pandas.compat.numpy import function as nv from pandas.util._decorators import cache_readonly -from pandas.core.dtypes.base import register_extension_dtype +from pandas.core.dtypes.base import ExtensionDtype, register_extension_dtype from pandas.core.dtypes.common import ( is_bool_dtype, is_datetime64_dtype, @@ -409,6 +409,11 @@ def astype(self, dtype, copy: bool = True) -> ArrayLike: elif isinstance(dtype, StringDtype): return dtype.construct_array_type()._from_sequence(self, copy=False) + elif isinstance(dtype, ExtensionDtype): + # e.g. Categorical + cls = dtype.construct_array_type() + return cls._from_sequence(self, dtype=dtype, copy=copy) + # coerce if is_float_dtype(dtype): # In astype, we consider dtype=float to also mean na_value=np.nan diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index 35fd8af9cd36e..b068d34599efc 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -375,11 +375,6 @@ def __contains__(self, key: Any) -> bool: return contains(self, key, container=self._engine) - @doc(Index.astype) - def astype(self, dtype, copy=True): - res_data = self._data.astype(dtype, copy=copy) - return Index(res_data, name=self.name) - @doc(Index.fillna) def fillna(self, value, downcast=None): value = self._require_scalar(value) diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index 73f96b2f6ad41..661adde44089c 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -11,7 +11,7 @@ from pandas.errors import AbstractMethodError from pandas.util._decorators import cache_readonly, doc -from pandas.core.dtypes.common import is_dtype_equal, is_object_dtype +from pandas.core.dtypes.common import is_dtype_equal, is_object_dtype, pandas_dtype from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries from pandas.core.arrays import ExtensionArray @@ -294,9 +294,12 @@ def map(self, mapper, na_action=None): @doc(Index.astype) def astype(self, dtype, copy=True): - if is_dtype_equal(self.dtype, dtype) and copy is False: - # Ensure that self.astype(self.dtype) is self - return self + dtype = pandas_dtype(dtype) + if is_dtype_equal(self.dtype, dtype): + if not copy: + # Ensure that self.astype(self.dtype) is self + return self + return self.copy() new_values = self._data.astype(dtype, copy=copy) diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 2630c07814bb2..9844a8ad8f55a 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -642,15 +642,7 @@ def astype(self, dtype, copy: bool = False, errors: str = "raise"): def _astype(self, dtype: DtypeObj, copy: bool) -> ArrayLike: values = self.values - if is_categorical_dtype(dtype): - - if is_categorical_dtype(values.dtype): - # GH#10696/GH#18593: update an existing categorical efficiently - return values.astype(dtype, copy=copy) - - return Categorical(values, dtype=dtype) - - elif is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype): + if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype): # if we are passed a datetime64[ns, tz] if copy: # this should be the only copy From a346d356a13bfacdfe3d9536cbc48d880222ab7a Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 22 Dec 2020 09:47:07 -0800 Subject: [PATCH 4/4] remove duplicate check --- pandas/core/arrays/integer.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pandas/core/arrays/integer.py b/pandas/core/arrays/integer.py index de49a3ee024eb..d01f84b224a89 100644 --- a/pandas/core/arrays/integer.py +++ b/pandas/core/arrays/integer.py @@ -395,11 +395,6 @@ def astype(self, dtype, copy: bool = True) -> ArrayLike: if isinstance(dtype, ExtensionDtype): return super().astype(dtype, copy=copy) - elif isinstance(dtype, ExtensionDtype): - # e.g. Categorical - cls = dtype.construct_array_type() - return cls._from_sequence(self, dtype=dtype, copy=copy) - # coerce if is_float_dtype(dtype): # In astype, we consider dtype=float to also mean na_value=np.nan