diff --git a/pandas/core/arrays/timedeltas.py b/pandas/core/arrays/timedeltas.py index f85e3e716bbf9..0853664c766bb 100644 --- a/pandas/core/arrays/timedeltas.py +++ b/pandas/core/arrays/timedeltas.py @@ -23,6 +23,7 @@ from pandas.core.dtypes.common import ( DT64NS_DTYPE, TD64NS_DTYPE, + is_categorical_dtype, is_dtype_equal, is_float_dtype, is_integer_dtype, @@ -940,6 +941,9 @@ def sequence_to_td64ns(data, copy=False, unit=None, errors="raise"): data = data._data elif isinstance(data, IntegerArray): data = data.to_numpy("int64", na_value=tslibs.iNaT) + elif is_categorical_dtype(data.dtype): + data = data.categories.take(data.codes, fill_value=NaT)._values + copy = False # Convert whatever we have into timedelta64[ns] dtype if is_object_dtype(data.dtype) or is_string_dtype(data.dtype): diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 5821ff0aca3c2..5baa103a25d51 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -2,7 +2,7 @@ Base and utility classes for tseries type pandas objects. """ from datetime import datetime, tzinfo -from typing import Any, List, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, List, Optional, TypeVar, Union, cast import numpy as np @@ -16,6 +16,7 @@ from pandas.core.dtypes.common import ( ensure_int64, is_bool_dtype, + is_categorical_dtype, is_dtype_equal, is_integer, is_list_like, @@ -41,6 +42,9 @@ from pandas.core.ops import get_op_result_name from pandas.core.tools.timedeltas import to_timedelta +if TYPE_CHECKING: + from pandas import CategoricalIndex + _index_doc_kwargs = dict(ibase._index_doc_kwargs) _T = TypeVar("_T", bound="DatetimeIndexOpsMixin") @@ -137,14 +141,31 @@ def equals(self, other: object) -> bool: elif other.dtype.kind in ["f", "i", "u", "c"]: return False elif not isinstance(other, type(self)): - try: - other = type(self)(other) - except (ValueError, TypeError, OverflowError): - # e.g. - # ValueError -> cannot parse str entry, or OutOfBoundsDatetime - # TypeError -> trying to convert IntervalIndex to DatetimeIndex - # OverflowError -> Index([very_large_timedeltas]) - return False + inferrable = [ + "timedelta", + "timedelta64", + "datetime", + "datetime64", + "date", + "period", + ] + + should_try = False + if other.dtype == object: + should_try = other.inferred_type in inferrable + elif is_categorical_dtype(other.dtype): + other = cast("CategoricalIndex", other) + should_try = other.categories.inferred_type in inferrable + + if should_try: + try: + other = type(self)(other) + except (ValueError, TypeError, OverflowError): + # e.g. + # ValueError -> cannot parse str entry, or OutOfBoundsDatetime + # TypeError -> trying to convert IntervalIndex to DatetimeIndex + # OverflowError -> Index([very_large_timedeltas]) + return False if not is_dtype_equal(self.dtype, other.dtype): # have different timezone diff --git a/pandas/tests/indexes/datetimelike.py b/pandas/tests/indexes/datetimelike.py index df857cce05bbb..be8ca61f1a730 100644 --- a/pandas/tests/indexes/datetimelike.py +++ b/pandas/tests/indexes/datetimelike.py @@ -116,6 +116,20 @@ def test_not_equals_numeric(self): assert not index.equals(pd.Index(index.asi8.astype("u8"))) assert not index.equals(pd.Index(index.asi8).astype("f8")) + def test_equals(self): + index = self.create_index() + + assert index.equals(index.astype(object)) + assert index.equals(pd.CategoricalIndex(index)) + assert index.equals(pd.CategoricalIndex(index.astype(object))) + + def test_not_equals_strings(self): + index = self.create_index() + + other = pd.Index([str(x) for x in index], dtype=object) + assert not index.equals(other) + assert not index.equals(pd.CategoricalIndex(other)) + def test_where_cast_str(self): index = self.create_index() diff --git a/pandas/tests/indexes/timedeltas/test_constructors.py b/pandas/tests/indexes/timedeltas/test_constructors.py index 41e4e220c999c..09344bb5054f6 100644 --- a/pandas/tests/indexes/timedeltas/test_constructors.py +++ b/pandas/tests/indexes/timedeltas/test_constructors.py @@ -238,3 +238,15 @@ def test_explicit_none_freq(self): result = TimedeltaIndex(tdi._data, freq=None) assert result.freq is None + + def test_from_categorical(self): + tdi = timedelta_range(1, periods=5) + + cat = pd.Categorical(tdi) + + result = TimedeltaIndex(cat) + tm.assert_index_equal(result, tdi) + + ci = pd.CategoricalIndex(tdi) + result = TimedeltaIndex(ci) + tm.assert_index_equal(result, tdi)