Skip to content

Commit 6919013

Browse files
jbrockmendeljreback
authored andcommitted
BUG/TST: Fix TimedeltaIndex comparisons with invalid types (#24056)
1 parent 473f21a commit 6919013

File tree

8 files changed

+149
-57
lines changed

8 files changed

+149
-57
lines changed

doc/source/whatsnew/v0.24.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,7 @@ Timedelta
13121312
- Bug in :class:`TimedeltaIndex` where adding ``np.timedelta64('NaT')`` incorrectly returned an all-`NaT` :class:`DatetimeIndex` instead of an all-`NaT` :class:`TimedeltaIndex` (:issue:`23215`)
13131313
- Bug in :class:`Timedelta` and :func:`to_timedelta()` have inconsistencies in supported unit string (:issue:`21762`)
13141314
- Bug in :class:`TimedeltaIndex` division where dividing by another :class:`TimedeltaIndex` raised ``TypeError`` instead of returning a :class:`Float64Index` (:issue:`23829`, :issue:`22631`)
1315+
- Bug in :class:`TimedeltaIndex` comparison operations where comparing against non-``Timedelta``-like objects would raise ``TypeError`` instead of returning all-``False`` for ``__eq__`` and all-``True`` for ``__ne__`` (:issue:`24056`)
13151316

13161317
Timezones
13171318
^^^^^^^^^

pandas/core/arrays/datetimes.py

+9
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ class DatetimeArrayMixin(dtl.DatetimeLikeArrayMixin,
174174
# by returning NotImplemented
175175
timetuple = None
176176

177+
# Needed so that Timestamp.__richcmp__(DateTimeArray) operates pointwise
178+
ndim = 1
179+
177180
# ensure that operations with numpy arrays defer to our implementation
178181
__array_priority__ = 1000
179182

@@ -217,6 +220,12 @@ def __new__(cls, values, freq=None, tz=None, dtype=None):
217220
# if dtype has an embedded tz, capture it
218221
tz = dtl.validate_tz_from_dtype(dtype, tz)
219222

223+
if not hasattr(values, "dtype"):
224+
if np.ndim(values) == 0:
225+
# i.e. iterator
226+
values = list(values)
227+
values = np.array(values)
228+
220229
if is_object_dtype(values):
221230
# kludge; dispatch until the DatetimeArray constructor is complete
222231
from pandas import DatetimeIndex

pandas/core/arrays/timedeltas.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
ABCDataFrame, ABCIndexClass, ABCSeries, ABCTimedeltaIndex)
2323
from pandas.core.dtypes.missing import isna
2424

25+
from pandas.core import ops
2526
from pandas.core.algorithms import checked_add_with_arr, unique1d
2627
import pandas.core.common as com
2728

@@ -70,25 +71,29 @@ def _td_array_cmp(cls, op):
7071
opname = '__{name}__'.format(name=op.__name__)
7172
nat_result = True if opname == '__ne__' else False
7273

74+
meth = getattr(dtl.DatetimeLikeArrayMixin, opname)
75+
7376
def wrapper(self, other):
74-
msg = "cannot compare a {cls} with type {typ}"
75-
meth = getattr(dtl.DatetimeLikeArrayMixin, opname)
7677
if _is_convertible_to_td(other) or other is NaT:
7778
try:
7879
other = _to_m8(other)
7980
except ValueError:
8081
# failed to parse as timedelta
81-
raise TypeError(msg.format(cls=type(self).__name__,
82-
typ=type(other).__name__))
82+
return ops.invalid_comparison(self, other, op)
83+
8384
result = meth(self, other)
8485
if isna(other):
8586
result.fill(nat_result)
8687

8788
elif not is_list_like(other):
88-
raise TypeError(msg.format(cls=type(self).__name__,
89-
typ=type(other).__name__))
89+
return ops.invalid_comparison(self, other, op)
90+
9091
else:
91-
other = type(self)(other)._data
92+
try:
93+
other = type(self)(other)._data
94+
except (ValueError, TypeError):
95+
return ops.invalid_comparison(self, other, op)
96+
9297
result = meth(self, other)
9398
result = com.values_from_object(result)
9499

@@ -108,6 +113,9 @@ class TimedeltaArrayMixin(dtl.DatetimeLikeArrayMixin, dtl.TimelikeOps):
108113
_typ = "timedeltaarray"
109114
__array_priority__ = 1000
110115

116+
# Needed so that NaT.__richcmp__(DateTimeArray) operates pointwise
117+
ndim = 1
118+
111119
@property
112120
def _box_func(self):
113121
return lambda x: Timedelta(x, unit='ns')

pandas/core/generic.py

+4
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ class NDFrame(PandasObject, SelectionMixin):
111111
_metadata = []
112112
_is_copy = None
113113

114+
# dummy attribute so that datetime.__eq__(Series/DataFrame) defers
115+
# by returning NotImplemented
116+
timetuple = None
117+
114118
# ----------------------------------------------------------------------
115119
# Constructors
116120

pandas/core/ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1115,7 +1115,7 @@ def dispatch_to_series(left, right, func, str_rep=None, axis=None):
11151115
import pandas.core.computation.expressions as expressions
11161116

11171117
right = lib.item_from_zerodim(right)
1118-
if lib.is_scalar(right):
1118+
if lib.is_scalar(right) or np.ndim(right) == 0:
11191119

11201120
def column_op(a, b):
11211121
return {i: func(a.iloc[:, i], b)

pandas/tests/arithmetic/test_datetime64.py

+99-41
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,17 @@
2525
DatetimeIndex, TimedeltaIndex)
2626

2727

28+
def assert_all(obj):
29+
"""
30+
Test helper to call call obj.all() the appropriate number of times on
31+
a Series or DataFrame.
32+
"""
33+
if isinstance(obj, pd.DataFrame):
34+
assert obj.all().all()
35+
else:
36+
assert obj.all()
37+
38+
2839
# ------------------------------------------------------------------
2940
# Comparisons
3041

@@ -86,11 +97,16 @@ def test_comparison_invalid(self, box_with_array):
8697
[Period('2011-01', freq='M'), NaT, Period('2011-03', freq='M')]
8798
])
8899
@pytest.mark.parametrize('dtype', [None, object])
89-
def test_nat_comparisons_scalar(self, dtype, data, box):
90-
xbox = box if box is not pd.Index else np.ndarray
100+
def test_nat_comparisons_scalar(self, dtype, data, box_with_array):
101+
if box_with_array is tm.to_array and dtype is object:
102+
# dont bother testing ndarray comparison methods as this fails
103+
# on older numpys (since they check object identity)
104+
return
105+
106+
xbox = box_with_array if box_with_array is not pd.Index else np.ndarray
91107

92108
left = Series(data, dtype=dtype)
93-
left = tm.box_expected(left, box)
109+
left = tm.box_expected(left, box_with_array)
94110

95111
expected = [False, False, False]
96112
expected = tm.box_expected(expected, xbox)
@@ -290,23 +306,24 @@ def test_dti_cmp_datetimelike(self, other, tz_naive_fixture):
290306
expected = np.array([True, False])
291307
tm.assert_numpy_array_equal(result, expected)
292308

293-
def dti_cmp_non_datetime(self, tz_naive_fixture):
309+
def dt64arr_cmp_non_datetime(self, tz_naive_fixture, box_with_array):
294310
# GH#19301 by convention datetime.date is not considered comparable
295311
# to Timestamp or DatetimeIndex. This may change in the future.
296312
tz = tz_naive_fixture
297313
dti = pd.date_range('2016-01-01', periods=2, tz=tz)
314+
dtarr = tm.box_expected(dti, box_with_array)
298315

299316
other = datetime(2016, 1, 1).date()
300-
assert not (dti == other).any()
301-
assert (dti != other).all()
317+
assert not (dtarr == other).any()
318+
assert (dtarr != other).all()
302319
with pytest.raises(TypeError):
303-
dti < other
320+
dtarr < other
304321
with pytest.raises(TypeError):
305-
dti <= other
322+
dtarr <= other
306323
with pytest.raises(TypeError):
307-
dti > other
324+
dtarr > other
308325
with pytest.raises(TypeError):
309-
dti >= other
326+
dtarr >= other
310327

311328
@pytest.mark.parametrize('other', [None, np.nan, pd.NaT])
312329
def test_dti_eq_null_scalar(self, other, tz_naive_fixture):
@@ -323,49 +340,67 @@ def test_dti_ne_null_scalar(self, other, tz_naive_fixture):
323340
assert (dti != other).all()
324341

325342
@pytest.mark.parametrize('other', [None, np.nan])
326-
def test_dti_cmp_null_scalar_inequality(self, tz_naive_fixture, other):
343+
def test_dti_cmp_null_scalar_inequality(self, tz_naive_fixture, other,
344+
box_with_array):
327345
# GH#19301
328346
tz = tz_naive_fixture
329347
dti = pd.date_range('2016-01-01', periods=2, tz=tz)
348+
# FIXME: ValueError with transpose
349+
dtarr = tm.box_expected(dti, box_with_array, transpose=False)
330350

331351
with pytest.raises(TypeError):
332-
dti < other
352+
dtarr < other
333353
with pytest.raises(TypeError):
334-
dti <= other
354+
dtarr <= other
335355
with pytest.raises(TypeError):
336-
dti > other
356+
dtarr > other
337357
with pytest.raises(TypeError):
338-
dti >= other
358+
dtarr >= other
339359

340360
@pytest.mark.parametrize('dtype', [None, object])
341-
def test_dti_cmp_nat(self, dtype):
361+
def test_dti_cmp_nat(self, dtype, box_with_array):
362+
if box_with_array is tm.to_array and dtype is object:
363+
# dont bother testing ndarray comparison methods as this fails
364+
# on older numpys (since they check object identity)
365+
return
366+
367+
xbox = box_with_array if box_with_array is not pd.Index else np.ndarray
368+
342369
left = pd.DatetimeIndex([pd.Timestamp('2011-01-01'), pd.NaT,
343370
pd.Timestamp('2011-01-03')])
344371
right = pd.DatetimeIndex([pd.NaT, pd.NaT, pd.Timestamp('2011-01-03')])
345372

373+
left = tm.box_expected(left, box_with_array)
374+
right = tm.box_expected(right, box_with_array)
375+
346376
lhs, rhs = left, right
347377
if dtype is object:
348378
lhs, rhs = left.astype(object), right.astype(object)
349379

350380
result = rhs == lhs
351381
expected = np.array([False, False, True])
352-
tm.assert_numpy_array_equal(result, expected)
382+
expected = tm.box_expected(expected, xbox)
383+
tm.assert_equal(result, expected)
353384

354385
result = lhs != rhs
355386
expected = np.array([True, True, False])
356-
tm.assert_numpy_array_equal(result, expected)
387+
expected = tm.box_expected(expected, xbox)
388+
tm.assert_equal(result, expected)
357389

358390
expected = np.array([False, False, False])
359-
tm.assert_numpy_array_equal(lhs == pd.NaT, expected)
360-
tm.assert_numpy_array_equal(pd.NaT == rhs, expected)
391+
expected = tm.box_expected(expected, xbox)
392+
tm.assert_equal(lhs == pd.NaT, expected)
393+
tm.assert_equal(pd.NaT == rhs, expected)
361394

362395
expected = np.array([True, True, True])
363-
tm.assert_numpy_array_equal(lhs != pd.NaT, expected)
364-
tm.assert_numpy_array_equal(pd.NaT != lhs, expected)
396+
expected = tm.box_expected(expected, xbox)
397+
tm.assert_equal(lhs != pd.NaT, expected)
398+
tm.assert_equal(pd.NaT != lhs, expected)
365399

366400
expected = np.array([False, False, False])
367-
tm.assert_numpy_array_equal(lhs < pd.NaT, expected)
368-
tm.assert_numpy_array_equal(pd.NaT > lhs, expected)
401+
expected = tm.box_expected(expected, xbox)
402+
tm.assert_equal(lhs < pd.NaT, expected)
403+
tm.assert_equal(pd.NaT > lhs, expected)
369404

370405
def test_dti_cmp_nat_behaves_like_float_cmp_nan(self):
371406
fidx1 = pd.Index([1.0, np.nan, 3.0, np.nan, 5.0, 7.0])
@@ -459,36 +494,47 @@ def test_dti_cmp_nat_behaves_like_float_cmp_nan(self):
459494
@pytest.mark.parametrize('op', [operator.eq, operator.ne,
460495
operator.gt, operator.ge,
461496
operator.lt, operator.le])
462-
def test_comparison_tzawareness_compat(self, op):
497+
def test_comparison_tzawareness_compat(self, op, box_with_array):
463498
# GH#18162
464499
dr = pd.date_range('2016-01-01', periods=6)
465500
dz = dr.tz_localize('US/Pacific')
466501

502+
# FIXME: ValueError with transpose
503+
dr = tm.box_expected(dr, box_with_array, transpose=False)
504+
dz = tm.box_expected(dz, box_with_array, transpose=False)
505+
467506
with pytest.raises(TypeError):
468507
op(dr, dz)
469-
with pytest.raises(TypeError):
470-
op(dr, list(dz))
508+
if box_with_array is not pd.DataFrame:
509+
# DataFrame op is invalid until transpose bug is fixed
510+
with pytest.raises(TypeError):
511+
op(dr, list(dz))
471512
with pytest.raises(TypeError):
472513
op(dz, dr)
473-
with pytest.raises(TypeError):
474-
op(dz, list(dr))
514+
if box_with_array is not pd.DataFrame:
515+
# DataFrame op is invalid until transpose bug is fixed
516+
with pytest.raises(TypeError):
517+
op(dz, list(dr))
475518

476519
# Check that there isn't a problem aware-aware and naive-naive do not
477520
# raise
478-
assert (dr == dr).all()
479-
assert (dr == list(dr)).all()
480-
assert (dz == dz).all()
481-
assert (dz == list(dz)).all()
521+
assert_all(dr == dr)
522+
assert_all(dz == dz)
523+
if box_with_array is not pd.DataFrame:
524+
# DataFrame doesn't align the lists correctly unless we transpose,
525+
# which we cannot do at the moment
526+
assert (dr == list(dr)).all()
527+
assert (dz == list(dz)).all()
482528

483529
# Check comparisons against scalar Timestamps
484530
ts = pd.Timestamp('2000-03-14 01:59')
485531
ts_tz = pd.Timestamp('2000-03-14 01:59', tz='Europe/Amsterdam')
486532

487-
assert (dr > ts).all()
533+
assert_all(dr > ts)
488534
with pytest.raises(TypeError):
489535
op(dr, ts_tz)
490536

491-
assert (dz > ts_tz).all()
537+
assert_all(dz > ts_tz)
492538
with pytest.raises(TypeError):
493539
op(dz, ts)
494540

@@ -502,13 +548,18 @@ def test_comparison_tzawareness_compat(self, op):
502548
@pytest.mark.parametrize('other', [datetime(2016, 1, 1),
503549
Timestamp('2016-01-01'),
504550
np.datetime64('2016-01-01')])
505-
def test_scalar_comparison_tzawareness(self, op, other, tz_aware_fixture):
551+
def test_scalar_comparison_tzawareness(self, op, other, tz_aware_fixture,
552+
box_with_array):
506553
tz = tz_aware_fixture
507554
dti = pd.date_range('2016-01-01', periods=2, tz=tz)
555+
556+
# FIXME: ValueError with transpose
557+
dtarr = tm.box_expected(dti, box_with_array, transpose=False)
558+
508559
with pytest.raises(TypeError):
509-
op(dti, other)
560+
op(dtarr, other)
510561
with pytest.raises(TypeError):
511-
op(other, dti)
562+
op(other, dtarr)
512563

513564
@pytest.mark.parametrize('op', [operator.eq, operator.ne,
514565
operator.gt, operator.ge,
@@ -558,18 +609,25 @@ def test_dti_cmp_str(self, tz_naive_fixture):
558609

559610
@pytest.mark.parametrize('other', ['foo', 99, 4.0,
560611
object(), timedelta(days=2)])
561-
def test_dti_cmp_scalar_invalid(self, other, tz_naive_fixture):
612+
def test_dt64arr_cmp_scalar_invalid(self, other, tz_naive_fixture,
613+
box_with_array):
562614
# GH#22074
563615
tz = tz_naive_fixture
616+
xbox = box_with_array if box_with_array is not pd.Index else np.ndarray
617+
564618
rng = date_range('1/1/2000', periods=10, tz=tz)
619+
# FIXME: ValueError with transpose
620+
rng = tm.box_expected(rng, box_with_array, transpose=False)
565621

566622
result = rng == other
567623
expected = np.array([False] * 10)
568-
tm.assert_numpy_array_equal(result, expected)
624+
expected = tm.box_expected(expected, xbox, transpose=False)
625+
tm.assert_equal(result, expected)
569626

570627
result = rng != other
571628
expected = np.array([True] * 10)
572-
tm.assert_numpy_array_equal(result, expected)
629+
expected = tm.box_expected(expected, xbox, transpose=False)
630+
tm.assert_equal(result, expected)
573631

574632
with pytest.raises(TypeError):
575633
rng < other

0 commit comments

Comments
 (0)