Skip to content

Commit 8abbf39

Browse files
committed
Timestamp comparisons for object arrays, closes pandas-dev#15183
1 parent b6a7cc9 commit 8abbf39

File tree

3 files changed

+50
-1
lines changed

3 files changed

+50
-1
lines changed

pandas/_libs/tslibs/timestamps.pyx

+4-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ from cpython.datetime cimport (datetime,
1717
PyDateTime_IMPORT
1818

1919
from util cimport (is_datetime64_object, is_timedelta64_object,
20-
is_integer_object, is_string_object,
20+
is_integer_object, is_string_object, is_array,
2121
INT64_MAX)
2222

2323
cimport ccalendar
@@ -108,6 +108,9 @@ cdef class _Timestamp(datetime):
108108
raise TypeError('Cannot compare type %r with type %r' %
109109
(type(self).__name__,
110110
type(other).__name__))
111+
elif is_array(other):
112+
# avoid recursion error GH#15183
113+
return PyObject_RichCompare(np.array([self]), other, op)
111114
return PyObject_RichCompare(other, self, reverse_ops[op])
112115
else:
113116
if op == Py_EQ:

pandas/tests/scalar/test_timedelta.py

+21
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,27 @@ def test_binary_ops_with_timedelta(self):
163163
pytest.raises(TypeError, lambda: td * td)
164164

165165

166+
class TestTimedeltaComparison(object):
167+
def test_comparison_object_array(self):
168+
# analogous to GH#15183
169+
td = Timedelta('2 days')
170+
other = Timedelta('3 hours')
171+
172+
arr = np.array([other, td], dtype=object)
173+
res = arr == td
174+
expected = np.array([False, True], dtype=bool)
175+
assert (res == expected).all()
176+
177+
# 2D case
178+
arr = np.array([[other, td],
179+
[td, other]],
180+
dtype=object)
181+
res = arr != td
182+
expected = np.array([[True, False], [False, True]], dtype=bool)
183+
assert res.shape == expected.shape
184+
assert (res == expected).all()
185+
186+
166187
class TestTimedeltas(object):
167188
_multiprocess_can_split_ = True
168189

pandas/tests/scalar/test_timestamp.py

+25
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,31 @@ def test_timestamp(self):
969969

970970

971971
class TestTimestampComparison(object):
972+
def test_comparison_object_array(self):
973+
# GH#15183
974+
ts = Timestamp('2011-01-03 00:00:00-0500', tz='US/Eastern')
975+
other = Timestamp('2011-01-01 00:00:00-0500', tz='US/Eastern')
976+
naive = Timestamp('2011-01-01 00:00:00')
977+
978+
arr = np.array([other, ts], dtype=object)
979+
res = arr == ts
980+
expected = np.array([False, True], dtype=bool)
981+
assert (res == expected).all()
982+
983+
# 2D case
984+
arr = np.array([[other, ts],
985+
[ts, other]],
986+
dtype=object)
987+
res = arr != ts
988+
expected = np.array([[True, False], [False, True]], dtype=bool)
989+
assert res.shape == expected.shape
990+
assert (res == expected).all()
991+
992+
# tzaware mismatch
993+
arr = np.array([naive], dtype=object)
994+
with pytest.raises(TypeError):
995+
arr < ts
996+
972997
def test_comparison(self):
973998
# 5-18-2012 00:00:00.000
974999
stamp = long(1337299200000000000)

0 commit comments

Comments
 (0)