Skip to content

Commit 8d5c51b

Browse files
jbrockmendeljreback
authored andcommitted
[Bug] Fix various DatetimeIndex comparison bugs (#22074)
1 parent 57c7daa commit 8d5c51b

File tree

7 files changed

+208
-24
lines changed

7 files changed

+208
-24
lines changed

doc/source/whatsnew/v0.24.0.txt

+5
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,9 @@ Datetimelike
533533
- Fixed bug where two :class:`DateOffset` objects with different ``normalize`` attributes could evaluate as equal (:issue:`21404`)
534534
- Fixed bug where :meth:`Timestamp.resolution` incorrectly returned 1-microsecond ``timedelta`` instead of 1-nanosecond :class:`Timedelta` (:issue:`21336`,:issue:`21365`)
535535
- Bug in :func:`to_datetime` that did not consistently return an :class:`Index` when ``box=True`` was specified (:issue:`21864`)
536+
- Bug in :class:`DatetimeIndex` comparisons where string comparisons incorrectly raises ``TypeError`` (:issue:`22074`)
537+
- Bug in :class:`DatetimeIndex` comparisons when comparing against ``timedelta64[ns]`` dtyped arrays; in some cases ``TypeError`` was incorrectly raised, in others it incorrectly failed to raise (:issue:`22074`)
538+
- Bug in :class:`DatetimeIndex` comparisons when comparing against object-dtyped arrays (:issue:`22074`)
536539

537540
Timedelta
538541
^^^^^^^^^
@@ -555,6 +558,7 @@ Timezones
555558
- Bug in :class:`Index` with ``datetime64[ns, tz]`` dtype that did not localize integer data correctly (:issue:`20964`)
556559
- Bug in :class:`DatetimeIndex` where constructing with an integer and tz would not localize correctly (:issue:`12619`)
557560
- Fixed bug where :meth:`DataFrame.describe` and :meth:`Series.describe` on tz-aware datetimes did not show `first` and `last` result (:issue:`21328`)
561+
- Bug in :class:`DatetimeIndex` comparisons failing to raise ``TypeError`` when comparing timezone-aware ``DatetimeIndex`` against ``np.datetime64`` (:issue:`22074`)
558562

559563
Offsets
560564
^^^^^^^
@@ -572,6 +576,7 @@ Numeric
572576
- Bug in :meth:`DataFrame.agg`, :meth:`DataFrame.transform` and :meth:`DataFrame.apply` where,
573577
when supplied with a list of functions and ``axis=1`` (e.g. ``df.apply(['sum', 'mean'], axis=1)``),
574578
a ``TypeError`` was wrongly raised. For all three methods such calculation are now done correctly. (:issue:`16679`).
579+
- Bug in :class:`Series` comparison against datetime-like scalars and arrays (:issue:`22074`)
575580
-
576581

577582
Strings

pandas/core/arrays/datetimes.py

+27-13
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from pytz import utc
77

8-
from pandas._libs import tslib
8+
from pandas._libs import lib, tslib
99
from pandas._libs.tslib import Timestamp, NaT, iNaT
1010
from pandas._libs.tslibs import (
1111
normalize_date,
@@ -18,7 +18,7 @@
1818

1919
from pandas.core.dtypes.common import (
2020
_NS_DTYPE,
21-
is_datetimelike,
21+
is_object_dtype,
2222
is_datetime64tz_dtype,
2323
is_datetime64_dtype,
2424
is_timedelta64_dtype,
@@ -29,6 +29,7 @@
2929

3030
import pandas.core.common as com
3131
from pandas.core.algorithms import checked_add_with_arr
32+
from pandas.core import ops
3233

3334
from pandas.tseries.frequencies import to_offset
3435
from pandas.tseries.offsets import Tick, Day, generate_range
@@ -99,31 +100,40 @@ def wrapper(self, other):
99100
meth = getattr(dtl.DatetimeLikeArrayMixin, opname)
100101

101102
if isinstance(other, (datetime, np.datetime64, compat.string_types)):
102-
if isinstance(other, datetime):
103+
if isinstance(other, (datetime, np.datetime64)):
103104
# GH#18435 strings get a pass from tzawareness compat
104105
self._assert_tzawareness_compat(other)
105106

106-
other = _to_m8(other, tz=self.tz)
107+
try:
108+
other = _to_m8(other, tz=self.tz)
109+
except ValueError:
110+
# string that cannot be parsed to Timestamp
111+
return ops.invalid_comparison(self, other, op)
112+
107113
result = meth(self, other)
108114
if isna(other):
109115
result.fill(nat_result)
116+
elif lib.is_scalar(other):
117+
return ops.invalid_comparison(self, other, op)
110118
else:
111119
if isinstance(other, list):
120+
# FIXME: This can break for object-dtype with mixed types
112121
other = type(self)(other)
113122
elif not isinstance(other, (np.ndarray, ABCIndexClass, ABCSeries)):
114123
# Following Timestamp convention, __eq__ is all-False
115124
# and __ne__ is all True, others raise TypeError.
116-
if opname == '__eq__':
117-
return np.zeros(shape=self.shape, dtype=bool)
118-
elif opname == '__ne__':
119-
return np.ones(shape=self.shape, dtype=bool)
120-
raise TypeError('%s type object %s' %
121-
(type(other), str(other)))
122-
123-
if is_datetimelike(other):
125+
return ops.invalid_comparison(self, other, op)
126+
127+
if is_object_dtype(other):
128+
result = op(self.astype('O'), np.array(other))
129+
elif not (is_datetime64_dtype(other) or
130+
is_datetime64tz_dtype(other)):
131+
# e.g. is_timedelta64_dtype(other)
132+
return ops.invalid_comparison(self, other, op)
133+
else:
124134
self._assert_tzawareness_compat(other)
135+
result = meth(self, np.asarray(other))
125136

126-
result = meth(self, np.asarray(other))
127137
result = com.values_from_object(result)
128138

129139
# Make sure to pass an array to result[...]; indexing with
@@ -152,6 +162,10 @@ class DatetimeArrayMixin(dtl.DatetimeLikeArrayMixin):
152162
'is_year_end', 'is_leap_year']
153163
_object_ops = ['weekday_name', 'freq', 'tz']
154164

165+
# dummy attribute so that datetime.__eq__(DatetimeArray) defers
166+
# by returning NotImplemented
167+
timetuple = None
168+
155169
# -----------------------------------------------------------------
156170
# Constructors
157171

pandas/core/ops.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,35 @@ def mask_cmp_op(x, y, op, allowed_types):
788788
return result
789789

790790

791+
def invalid_comparison(left, right, op):
792+
"""
793+
If a comparison has mismatched types and is not necessarily meaningful,
794+
follow python3 conventions by:
795+
796+
- returning all-False for equality
797+
- returning all-True for inequality
798+
- raising TypeError otherwise
799+
800+
Parameters
801+
----------
802+
left : array-like
803+
right : scalar, array-like
804+
op : operator.{eq, ne, lt, le, gt}
805+
806+
Raises
807+
------
808+
TypeError : on inequality comparisons
809+
"""
810+
if op is operator.eq:
811+
res_values = np.zeros(left.shape, dtype=bool)
812+
elif op is operator.ne:
813+
res_values = np.ones(left.shape, dtype=bool)
814+
else:
815+
raise TypeError("Invalid comparison between dtype={dtype} and {typ}"
816+
.format(dtype=left.dtype, typ=type(right).__name__))
817+
return res_values
818+
819+
791820
# -----------------------------------------------------------------------------
792821
# Functions that add arithmetic methods to objects, given arithmetic factory
793822
# methods
@@ -1259,7 +1288,7 @@ def na_op(x, y):
12591288
result = _comp_method_OBJECT_ARRAY(op, x, y)
12601289

12611290
elif is_datetimelike_v_numeric(x, y):
1262-
raise TypeError("invalid type comparison")
1291+
return invalid_comparison(x, y, op)
12631292

12641293
else:
12651294

pandas/tests/frame/test_operators.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,20 @@ def test_comparison_invalid(self):
156156
def check(df, df2):
157157

158158
for (x, y) in [(df, df2), (df2, df)]:
159-
pytest.raises(TypeError, lambda: x == y)
160-
pytest.raises(TypeError, lambda: x != y)
159+
# we expect the result to match Series comparisons for
160+
# == and !=, inequalities should raise
161+
result = x == y
162+
expected = DataFrame({col: x[col] == y[col]
163+
for col in x.columns},
164+
index=x.index, columns=x.columns)
165+
assert_frame_equal(result, expected)
166+
167+
result = x != y
168+
expected = DataFrame({col: x[col] != y[col]
169+
for col in x.columns},
170+
index=x.index, columns=x.columns)
171+
assert_frame_equal(result, expected)
172+
161173
pytest.raises(TypeError, lambda: x >= y)
162174
pytest.raises(TypeError, lambda: x > y)
163175
pytest.raises(TypeError, lambda: x < y)

pandas/tests/frame/test_query_eval.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -463,9 +463,13 @@ def test_date_query_with_non_date(self):
463463
df = DataFrame({'dates': date_range('1/1/2012', periods=n),
464464
'nondate': np.arange(n)})
465465

466-
ops = '==', '!=', '<', '>', '<=', '>='
466+
result = df.query('dates == nondate', parser=parser, engine=engine)
467+
assert len(result) == 0
467468

468-
for op in ops:
469+
result = df.query('dates != nondate', parser=parser, engine=engine)
470+
assert_frame_equal(result, df)
471+
472+
for op in ['<', '>', '<=', '>=']:
469473
with pytest.raises(TypeError):
470474
df.query('dates %s nondate' % op, parser=parser, engine=engine)
471475

pandas/tests/indexes/datetimes/test_arithmetic.py

+117-4
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,20 @@ def test_comparison_tzawareness_compat(self, op):
275275
with pytest.raises(TypeError):
276276
op(ts, dz)
277277

278+
@pytest.mark.parametrize('op', [operator.eq, operator.ne,
279+
operator.gt, operator.ge,
280+
operator.lt, operator.le])
281+
@pytest.mark.parametrize('other', [datetime(2016, 1, 1),
282+
Timestamp('2016-01-01'),
283+
np.datetime64('2016-01-01')])
284+
def test_scalar_comparison_tzawareness(self, op, other, tz_aware_fixture):
285+
tz = tz_aware_fixture
286+
dti = pd.date_range('2016-01-01', periods=2, tz=tz)
287+
with pytest.raises(TypeError):
288+
op(dti, other)
289+
with pytest.raises(TypeError):
290+
op(other, dti)
291+
278292
@pytest.mark.parametrize('op', [operator.eq, operator.ne,
279293
operator.gt, operator.ge,
280294
operator.lt, operator.le])
@@ -290,12 +304,60 @@ def test_nat_comparison_tzawareness(self, op):
290304
result = op(dti.tz_localize('US/Pacific'), pd.NaT)
291305
tm.assert_numpy_array_equal(result, expected)
292306

293-
def test_dti_cmp_int_raises(self):
294-
rng = date_range('1/1/2000', periods=10)
307+
def test_dti_cmp_str(self, tz_naive_fixture):
308+
# GH#22074
309+
# regardless of tz, we expect these comparisons are valid
310+
tz = tz_naive_fixture
311+
rng = date_range('1/1/2000', periods=10, tz=tz)
312+
other = '1/1/2000'
313+
314+
result = rng == other
315+
expected = np.array([True] + [False] * 9)
316+
tm.assert_numpy_array_equal(result, expected)
317+
318+
result = rng != other
319+
expected = np.array([False] + [True] * 9)
320+
tm.assert_numpy_array_equal(result, expected)
321+
322+
result = rng < other
323+
expected = np.array([False] * 10)
324+
tm.assert_numpy_array_equal(result, expected)
325+
326+
result = rng <= other
327+
expected = np.array([True] + [False] * 9)
328+
tm.assert_numpy_array_equal(result, expected)
329+
330+
result = rng > other
331+
expected = np.array([False] + [True] * 9)
332+
tm.assert_numpy_array_equal(result, expected)
333+
334+
result = rng >= other
335+
expected = np.array([True] * 10)
336+
tm.assert_numpy_array_equal(result, expected)
337+
338+
@pytest.mark.parametrize('other', ['foo', 99, 4.0,
339+
object(), timedelta(days=2)])
340+
def test_dti_cmp_scalar_invalid(self, other, tz_naive_fixture):
341+
# GH#22074
342+
tz = tz_naive_fixture
343+
rng = date_range('1/1/2000', periods=10, tz=tz)
344+
345+
result = rng == other
346+
expected = np.array([False] * 10)
347+
tm.assert_numpy_array_equal(result, expected)
348+
349+
result = rng != other
350+
expected = np.array([True] * 10)
351+
tm.assert_numpy_array_equal(result, expected)
295352

296-
# raise TypeError for now
297353
with pytest.raises(TypeError):
298-
rng < rng[3].value
354+
rng < other
355+
with pytest.raises(TypeError):
356+
rng <= other
357+
with pytest.raises(TypeError):
358+
rng > other
359+
with pytest.raises(TypeError):
360+
rng >= other
299361

300362
def test_dti_cmp_list(self):
301363
rng = date_range('1/1/2000', periods=10)
@@ -304,6 +366,57 @@ def test_dti_cmp_list(self):
304366
expected = rng == rng
305367
tm.assert_numpy_array_equal(result, expected)
306368

369+
@pytest.mark.parametrize('other', [
370+
pd.timedelta_range('1D', periods=10),
371+
pd.timedelta_range('1D', periods=10).to_series(),
372+
pd.timedelta_range('1D', periods=10).asi8.view('m8[ns]')
373+
], ids=lambda x: type(x).__name__)
374+
def test_dti_cmp_tdi_tzawareness(self, other):
375+
# GH#22074
376+
# reversion test that we _don't_ call _assert_tzawareness_compat
377+
# when comparing against TimedeltaIndex
378+
dti = date_range('2000-01-01', periods=10, tz='Asia/Tokyo')
379+
380+
result = dti == other
381+
expected = np.array([False] * 10)
382+
tm.assert_numpy_array_equal(result, expected)
383+
384+
result = dti != other
385+
expected = np.array([True] * 10)
386+
tm.assert_numpy_array_equal(result, expected)
387+
388+
with pytest.raises(TypeError):
389+
dti < other
390+
with pytest.raises(TypeError):
391+
dti <= other
392+
with pytest.raises(TypeError):
393+
dti > other
394+
with pytest.raises(TypeError):
395+
dti >= other
396+
397+
def test_dti_cmp_object_dtype(self):
398+
# GH#22074
399+
dti = date_range('2000-01-01', periods=10, tz='Asia/Tokyo')
400+
401+
other = dti.astype('O')
402+
403+
result = dti == other
404+
expected = np.array([True] * 10)
405+
tm.assert_numpy_array_equal(result, expected)
406+
407+
other = dti.tz_localize(None)
408+
with pytest.raises(TypeError):
409+
# tzawareness failure
410+
dti != other
411+
412+
other = np.array(list(dti[:5]) + [Timedelta(days=1)] * 5)
413+
result = dti == other
414+
expected = np.array([True] * 5 + [False] * 5)
415+
tm.assert_numpy_array_equal(result, expected)
416+
417+
with pytest.raises(TypeError):
418+
dti >= other
419+
307420

308421
class TestDatetimeIndexArithmetic(object):
309422

pandas/tests/series/test_operators.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,15 @@ def test_comparison_invalid(self):
243243
s2 = Series(date_range('20010101', periods=5))
244244

245245
for (x, y) in [(s, s2), (s2, s)]:
246-
pytest.raises(TypeError, lambda: x == y)
247-
pytest.raises(TypeError, lambda: x != y)
246+
247+
result = x == y
248+
expected = Series([False] * 5)
249+
assert_series_equal(result, expected)
250+
251+
result = x != y
252+
expected = Series([True] * 5)
253+
assert_series_equal(result, expected)
254+
248255
pytest.raises(TypeError, lambda: x >= y)
249256
pytest.raises(TypeError, lambda: x > y)
250257
pytest.raises(TypeError, lambda: x < y)

0 commit comments

Comments
 (0)