Skip to content

Commit a6db928

Browse files
authored
BUG/API: tighter checks on DTI/TDI.equals (#36962)
1 parent f2ec69d commit a6db928

File tree

4 files changed

+60
-9
lines changed

4 files changed

+60
-9
lines changed

pandas/core/arrays/timedeltas.py

+4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pandas.core.dtypes.common import (
2424
DT64NS_DTYPE,
2525
TD64NS_DTYPE,
26+
is_categorical_dtype,
2627
is_dtype_equal,
2728
is_float_dtype,
2829
is_integer_dtype,
@@ -940,6 +941,9 @@ def sequence_to_td64ns(data, copy=False, unit=None, errors="raise"):
940941
data = data._data
941942
elif isinstance(data, IntegerArray):
942943
data = data.to_numpy("int64", na_value=tslibs.iNaT)
944+
elif is_categorical_dtype(data.dtype):
945+
data = data.categories.take(data.codes, fill_value=NaT)._values
946+
copy = False
943947

944948
# Convert whatever we have into timedelta64[ns] dtype
945949
if is_object_dtype(data.dtype) or is_string_dtype(data.dtype):

pandas/core/indexes/datetimelike.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Base and utility classes for tseries type pandas objects.
33
"""
44
from datetime import datetime, tzinfo
5-
from typing import Any, List, Optional, TypeVar, Union, cast
5+
from typing import TYPE_CHECKING, Any, List, Optional, TypeVar, Union, cast
66

77
import numpy as np
88

@@ -16,6 +16,7 @@
1616
from pandas.core.dtypes.common import (
1717
ensure_int64,
1818
is_bool_dtype,
19+
is_categorical_dtype,
1920
is_dtype_equal,
2021
is_integer,
2122
is_list_like,
@@ -41,6 +42,9 @@
4142
from pandas.core.ops import get_op_result_name
4243
from pandas.core.tools.timedeltas import to_timedelta
4344

45+
if TYPE_CHECKING:
46+
from pandas import CategoricalIndex
47+
4448
_index_doc_kwargs = dict(ibase._index_doc_kwargs)
4549

4650
_T = TypeVar("_T", bound="DatetimeIndexOpsMixin")
@@ -137,14 +141,31 @@ def equals(self, other: object) -> bool:
137141
elif other.dtype.kind in ["f", "i", "u", "c"]:
138142
return False
139143
elif not isinstance(other, type(self)):
140-
try:
141-
other = type(self)(other)
142-
except (ValueError, TypeError, OverflowError):
143-
# e.g.
144-
# ValueError -> cannot parse str entry, or OutOfBoundsDatetime
145-
# TypeError -> trying to convert IntervalIndex to DatetimeIndex
146-
# OverflowError -> Index([very_large_timedeltas])
147-
return False
144+
inferrable = [
145+
"timedelta",
146+
"timedelta64",
147+
"datetime",
148+
"datetime64",
149+
"date",
150+
"period",
151+
]
152+
153+
should_try = False
154+
if other.dtype == object:
155+
should_try = other.inferred_type in inferrable
156+
elif is_categorical_dtype(other.dtype):
157+
other = cast("CategoricalIndex", other)
158+
should_try = other.categories.inferred_type in inferrable
159+
160+
if should_try:
161+
try:
162+
other = type(self)(other)
163+
except (ValueError, TypeError, OverflowError):
164+
# e.g.
165+
# ValueError -> cannot parse str entry, or OutOfBoundsDatetime
166+
# TypeError -> trying to convert IntervalIndex to DatetimeIndex
167+
# OverflowError -> Index([very_large_timedeltas])
168+
return False
148169

149170
if not is_dtype_equal(self.dtype, other.dtype):
150171
# have different timezone

pandas/tests/indexes/datetimelike.py

+14
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,20 @@ def test_not_equals_numeric(self):
116116
assert not index.equals(pd.Index(index.asi8.astype("u8")))
117117
assert not index.equals(pd.Index(index.asi8).astype("f8"))
118118

119+
def test_equals(self):
120+
index = self.create_index()
121+
122+
assert index.equals(index.astype(object))
123+
assert index.equals(pd.CategoricalIndex(index))
124+
assert index.equals(pd.CategoricalIndex(index.astype(object)))
125+
126+
def test_not_equals_strings(self):
127+
index = self.create_index()
128+
129+
other = pd.Index([str(x) for x in index], dtype=object)
130+
assert not index.equals(other)
131+
assert not index.equals(pd.CategoricalIndex(other))
132+
119133
def test_where_cast_str(self):
120134
index = self.create_index()
121135

pandas/tests/indexes/timedeltas/test_constructors.py

+12
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,15 @@ def test_explicit_none_freq(self):
238238

239239
result = TimedeltaIndex(tdi._data, freq=None)
240240
assert result.freq is None
241+
242+
def test_from_categorical(self):
243+
tdi = timedelta_range(1, periods=5)
244+
245+
cat = pd.Categorical(tdi)
246+
247+
result = TimedeltaIndex(cat)
248+
tm.assert_index_equal(result, tdi)
249+
250+
ci = pd.CategoricalIndex(tdi)
251+
result = TimedeltaIndex(ci)
252+
tm.assert_index_equal(result, tdi)

0 commit comments

Comments
 (0)