From dbd1e7cccd5fbe8df852a0b8d2258b9da3948ae3 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 7 Oct 2020 16:15:17 -0700 Subject: [PATCH 1/2] BUG/API: tighter checks on DTI/TDI.equals --- pandas/core/arrays/timedeltas.py | 4 +++ pandas/core/indexes/datetimelike.py | 33 ++++++++++++++----- pandas/tests/indexes/datetimelike.py | 14 ++++++++ .../indexes/timedeltas/test_constructors.py | 12 +++++++ 4 files changed, 55 insertions(+), 8 deletions(-) diff --git a/pandas/core/arrays/timedeltas.py b/pandas/core/arrays/timedeltas.py index c97c7da375fd4..ed5586c57a5b7 100644 --- a/pandas/core/arrays/timedeltas.py +++ b/pandas/core/arrays/timedeltas.py @@ -23,6 +23,7 @@ from pandas.core.dtypes.common import ( DT64NS_DTYPE, TD64NS_DTYPE, + is_categorical_dtype, is_dtype_equal, is_float_dtype, is_integer_dtype, @@ -953,6 +954,9 @@ def sequence_to_td64ns(data, copy=False, unit=None, errors="raise"): data = data._data elif isinstance(data, IntegerArray): data = data.to_numpy("int64", na_value=tslibs.iNaT) + elif is_categorical_dtype(data.dtype): + data = data.categories.take(data.codes, fill_value=NaT)._values + copy = False # Convert whatever we have into timedelta64[ns] dtype if is_object_dtype(data.dtype) or is_string_dtype(data.dtype): diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index d2162d987ccd6..aedd04bbfcb98 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -16,6 +16,7 @@ from pandas.core.dtypes.common import ( ensure_int64, is_bool_dtype, + is_categorical_dtype, is_dtype_equal, is_integer, is_list_like, @@ -137,14 +138,30 @@ def equals(self, other: object) -> bool: elif other.dtype.kind in ["f", "i", "u", "c"]: return False elif not isinstance(other, type(self)): - try: - other = type(self)(other) - except (ValueError, TypeError, OverflowError): - # e.g. - # ValueError -> cannot parse str entry, or OutOfBoundsDatetime - # TypeError -> trying to convert IntervalIndex to DatetimeIndex - # OverflowError -> Index([very_large_timedeltas]) - return False + inferrable = [ + "timedelta", + "timedelta64", + "datetime", + "datetime64", + "date", + "period", + ] + + should_try = False + if other.dtype == object: + should_try = other.inferred_type in inferrable + elif is_categorical_dtype(other.dtype): + should_try = other.categories.inferred_type in inferrable + + if should_try: + try: + other = type(self)(other) + except (ValueError, TypeError, OverflowError): + # e.g. + # ValueError -> cannot parse str entry, or OutOfBoundsDatetime + # TypeError -> trying to convert IntervalIndex to DatetimeIndex + # OverflowError -> Index([very_large_timedeltas]) + return False if not is_dtype_equal(self.dtype, other.dtype): # have different timezone diff --git a/pandas/tests/indexes/datetimelike.py b/pandas/tests/indexes/datetimelike.py index 71ae1d6bda9c7..6d0b228bb861b 100644 --- a/pandas/tests/indexes/datetimelike.py +++ b/pandas/tests/indexes/datetimelike.py @@ -115,3 +115,17 @@ def test_not_equals_numeric(self): assert not index.equals(pd.Index(index.asi8)) assert not index.equals(pd.Index(index.asi8.astype("u8"))) assert not index.equals(pd.Index(index.asi8).astype("f8")) + + def test_equals(self): + index = self.create_index() + + assert index.equals(index.astype(object)) + assert index.equals(pd.CategoricalIndex(index)) + assert index.equals(pd.CategoricalIndex(index.astype(object))) + + def test_not_equals_strings(self): + index = self.create_index() + + other = pd.Index([str(x) for x in index], dtype=object) + assert not index.equals(other) + assert not index.equals(pd.CategoricalIndex(other)) diff --git a/pandas/tests/indexes/timedeltas/test_constructors.py b/pandas/tests/indexes/timedeltas/test_constructors.py index 41e4e220c999c..09344bb5054f6 100644 --- a/pandas/tests/indexes/timedeltas/test_constructors.py +++ b/pandas/tests/indexes/timedeltas/test_constructors.py @@ -238,3 +238,15 @@ def test_explicit_none_freq(self): result = TimedeltaIndex(tdi._data, freq=None) assert result.freq is None + + def test_from_categorical(self): + tdi = timedelta_range(1, periods=5) + + cat = pd.Categorical(tdi) + + result = TimedeltaIndex(cat) + tm.assert_index_equal(result, tdi) + + ci = pd.CategoricalIndex(tdi) + result = TimedeltaIndex(ci) + tm.assert_index_equal(result, tdi) From 7792f7fdc5c467805ab91df2894d47bf01f5dcca Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 7 Oct 2020 18:56:39 -0700 Subject: [PATCH 2/2] mypy fixup --- pandas/core/indexes/datetimelike.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index cb467fb10efe5..5baa103a25d51 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -2,7 +2,7 @@ Base and utility classes for tseries type pandas objects. """ from datetime import datetime, tzinfo -from typing import Any, List, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, List, Optional, TypeVar, Union, cast import numpy as np @@ -42,6 +42,9 @@ from pandas.core.ops import get_op_result_name from pandas.core.tools.timedeltas import to_timedelta +if TYPE_CHECKING: + from pandas import CategoricalIndex + _index_doc_kwargs = dict(ibase._index_doc_kwargs) _T = TypeVar("_T", bound="DatetimeIndexOpsMixin") @@ -151,6 +154,7 @@ def equals(self, other: object) -> bool: if other.dtype == object: should_try = other.inferred_type in inferrable elif is_categorical_dtype(other.dtype): + other = cast("CategoricalIndex", other) should_try = other.categories.inferred_type in inferrable if should_try: