Skip to content

Commit fe42954

Browse files
authored
BUG: Timedelta == ndarray[td64] (#33441)
1 parent 991f784 commit fe42954

File tree

5 files changed

+48
-31
lines changed

5 files changed

+48
-31
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ Timedelta
397397
- Bug in dividing ``np.nan`` or ``None`` by :class:`Timedelta`` incorrectly returning ``NaT`` (:issue:`31869`)
398398
- Timedeltas now understand ``µs`` as identifier for microsecond (:issue:`32899`)
399399
- :class:`Timedelta` string representation now includes nanoseconds, when nanoseconds are non-zero (:issue:`9309`)
400+
- Bug in comparing a :class:`Timedelta`` object against a ``np.ndarray`` with ``timedelta64`` dtype incorrectly viewing all entries as unequal (:issue:`33441`)
400401

401402
Timezones
402403
^^^^^^^^^

pandas/_libs/tslibs/timedeltas.pyx

+24-28
Original file line numberDiff line numberDiff line change
@@ -778,36 +778,32 @@ cdef class _Timedelta(timedelta):
778778

779779
if isinstance(other, _Timedelta):
780780
ots = other
781-
elif PyDelta_Check(other) or isinstance(other, Tick):
781+
elif (is_timedelta64_object(other) or PyDelta_Check(other)
782+
or isinstance(other, Tick)):
782783
ots = Timedelta(other)
783-
else:
784-
ndim = getattr(other, "ndim", -1)
784+
# TODO: watch out for overflows
785785

786-
if ndim != -1:
787-
if ndim == 0:
788-
if is_timedelta64_object(other):
789-
other = Timedelta(other)
790-
else:
791-
if op == Py_EQ:
792-
return False
793-
elif op == Py_NE:
794-
return True
795-
# only allow ==, != ops
796-
raise TypeError(f'Cannot compare type '
797-
f'{type(self).__name__} with '
798-
f'type {type(other).__name__}')
799-
if util.is_array(other):
800-
return PyObject_RichCompare(np.array([self]), other, op)
801-
return PyObject_RichCompare(other, self, reverse_ops[op])
802-
else:
803-
if other is NaT:
804-
return PyObject_RichCompare(other, self, reverse_ops[op])
805-
elif op == Py_EQ:
806-
return False
807-
elif op == Py_NE:
808-
return True
809-
raise TypeError(f'Cannot compare type {type(self).__name__} with '
810-
f'type {type(other).__name__}')
786+
elif other is NaT:
787+
return op == Py_NE
788+
789+
elif util.is_array(other):
790+
# TODO: watch out for zero-dim
791+
if other.dtype.kind == "m":
792+
return PyObject_RichCompare(self.asm8, other, op)
793+
elif other.dtype.kind == "O":
794+
# operate element-wise
795+
return np.array(
796+
[PyObject_RichCompare(self, x, op) for x in other],
797+
dtype=bool,
798+
)
799+
if op == Py_EQ:
800+
return np.zeros(other.shape, dtype=bool)
801+
elif op == Py_NE:
802+
return np.ones(other.shape, dtype=bool)
803+
return NotImplemented # let other raise TypeError
804+
805+
else:
806+
return NotImplemented
811807

812808
return cmp_scalar(self.value, ots.value, op)
813809

pandas/core/internals/managers.py

+1
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,7 @@ def to_dict(self, copy: bool = True):
880880
for b in self.blocks:
881881
bd.setdefault(str(b.dtype), []).append(b)
882882

883+
# TODO(EA2D): the combine will be unnecessary with 2D EAs
883884
return {dtype: self._combine(blocks, copy=copy) for dtype, blocks in bd.items()}
884885

885886
def fast_xs(self, loc: int) -> ArrayLike:

pandas/tests/arithmetic/test_datetime64.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ def test_dti_cmp_object_dtype(self):
734734
result = dti == other
735735
expected = np.array([True] * 5 + [False] * 5)
736736
tm.assert_numpy_array_equal(result, expected)
737-
msg = "Cannot compare type"
737+
msg = ">=' not supported between instances of 'Timestamp' and 'Timedelta'"
738738
with pytest.raises(TypeError, match=msg):
739739
dti >= other
740740

pandas/tests/scalar/timedelta/test_arithmetic.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,25 @@ def test_compare_timedelta_ndarray(self):
904904
expected = np.array([False, False])
905905
tm.assert_numpy_array_equal(result, expected)
906906

907+
def test_compare_td64_ndarray(self):
908+
# GG#33441
909+
arr = np.arange(5).astype("timedelta64[ns]")
910+
td = pd.Timedelta(arr[1])
911+
912+
expected = np.array([False, True, False, False, False], dtype=bool)
913+
914+
result = td == arr
915+
tm.assert_numpy_array_equal(result, expected)
916+
917+
result = arr == td
918+
tm.assert_numpy_array_equal(result, expected)
919+
920+
result = td != arr
921+
tm.assert_numpy_array_equal(result, ~expected)
922+
923+
result = arr != td
924+
tm.assert_numpy_array_equal(result, ~expected)
925+
907926
@pytest.mark.skip(reason="GH#20829 is reverted until after 0.24.0")
908927
def test_compare_custom_object(self):
909928
"""
@@ -943,7 +962,7 @@ def __gt__(self, other):
943962
def test_compare_unknown_type(self, val):
944963
# GH#20829
945964
t = Timedelta("1s")
946-
msg = "Cannot compare type Timedelta with type (int|str)"
965+
msg = "not supported between instances of 'Timedelta' and '(int|str)'"
947966
with pytest.raises(TypeError, match=msg):
948967
t >= val
949968
with pytest.raises(TypeError, match=msg):
@@ -984,7 +1003,7 @@ def test_ops_error_str():
9841003
with pytest.raises(TypeError, match=msg):
9851004
left + right
9861005

987-
msg = "Cannot compare type"
1006+
msg = "not supported between instances of"
9881007
with pytest.raises(TypeError, match=msg):
9891008
left > right
9901009

0 commit comments

Comments
 (0)