Skip to content

Commit cd6510d

Browse files
jbrockmendel authored and jreback committed
Fix DTI comparison with None, datetime.date (#19301)
1 parent 601b8c9 commit cd6510d

File tree

3 files changed

+156
-74
lines changed

3 files changed

+156
-74
lines changed

doc/source/whatsnew/v0.23.0.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,8 @@ Datetimelike
453453
- Bug in subtracting :class:`Series` from ``NaT`` incorrectly returning ``NaT`` (:issue:`19158`)
454454
- Bug in :func:`Series.truncate` which raises ``TypeError`` with a monotonic ``PeriodIndex`` (:issue:`17717`)
455455
- Bug in :func:`~DataFrame.pct_change` using ``periods`` and ``freq`` returned different length outputs (:issue:`7292`)
456+
- Bug in comparison of :class:`DatetimeIndex` against ``None`` or ``datetime.date`` objects raising ``TypeError`` for ``==`` and ``!=`` comparisons instead of all-``False`` and all-``True``, respectively (:issue:`19301`)
457+
-
456458

457459
Timezones
458460
^^^^^^^^^
@@ -484,8 +486,6 @@ Numeric
484486
- Bug in the :class:`DataFrame` constructor in which data containing very large positive or very large negative numbers was causing ``OverflowError`` (:issue:`18584`)
485487
- Bug in :class:`Index` constructor with ``dtype='uint64'`` where int-like floats were not coerced to :class:`UInt64Index` (:issue:`18400`)
486488

487-
-
488-
489489

490490
Indexing
491491
^^^^^^^^

pandas/core/indexes/datetimes.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,16 @@ def wrapper(self, other):
120120
else:
121121
if isinstance(other, list):
122122
other = DatetimeIndex(other)
123-
elif not isinstance(other, (np.ndarray, Index, ABCSeries)):
124-
other = _ensure_datetime64(other)
123+
elif not isinstance(other, (np.datetime64, np.ndarray,
124+
Index, ABCSeries)):
125+
# Following Timestamp convention, __eq__ is all-False
126+
# and __ne__ is all True, others raise TypeError.
127+
if opname == '__eq__':
128+
return np.zeros(shape=self.shape, dtype=bool)
129+
elif opname == '__ne__':
130+
return np.ones(shape=self.shape, dtype=bool)
131+
raise TypeError('%s type object %s' %
132+
(type(other), str(other)))
125133

126134
if is_datetimelike(other):
127135
self._assert_tzawareness_compat(other)
@@ -148,12 +156,6 @@ def wrapper(self, other):
148156
return compat.set_function_name(wrapper, opname, cls)
149157

150158

151-
def _ensure_datetime64(other):
152-
if isinstance(other, np.datetime64):
153-
return other
154-
raise TypeError('%s type object %s' % (type(other), str(other)))
155-
156-
157159
_midnight = time(0, 0)
158160

159161

pandas/tests/indexes/datetimes/test_arithmetic.py

+144-64
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pandas import (Timestamp, Timedelta, Series,
1515
DatetimeIndex, TimedeltaIndex,
1616
date_range)
17+
from pandas._libs import tslib
1718

1819

1920
@pytest.fixture(params=[None, 'UTC', 'Asia/Tokyo',
@@ -44,7 +45,83 @@ def addend(request):
4445

4546

4647
class TestDatetimeIndexComparisons(object):
47-
# TODO: De-duplicate with test_comparisons_nat below
48+
@pytest.mark.parametrize('other', [datetime(2016, 1, 1),
49+
Timestamp('2016-01-01'),
50+
np.datetime64('2016-01-01')])
51+
def test_dti_cmp_datetimelike(self, other, tz):
52+
dti = pd.date_range('2016-01-01', periods=2, tz=tz)
53+
if tz is not None:
54+
if isinstance(other, np.datetime64):
55+
# no tzaware version available
56+
return
57+
elif isinstance(other, Timestamp):
58+
other = other.tz_localize(dti.tzinfo)
59+
else:
60+
other = tslib._localize_pydatetime(other, dti.tzinfo)
61+
62+
result = dti == other
63+
expected = np.array([True, False])
64+
tm.assert_numpy_array_equal(result, expected)
65+
66+
result = dti > other
67+
expected = np.array([False, True])
68+
tm.assert_numpy_array_equal(result, expected)
69+
70+
result = dti >= other
71+
expected = np.array([True, True])
72+
tm.assert_numpy_array_equal(result, expected)
73+
74+
result = dti < other
75+
expected = np.array([False, False])
76+
tm.assert_numpy_array_equal(result, expected)
77+
78+
result = dti <= other
79+
expected = np.array([True, False])
80+
tm.assert_numpy_array_equal(result, expected)
81+
82+
def test_dti_cmp_non_datetime(self, tz):
    """Compare a DatetimeIndex against a ``datetime.date`` scalar.

    Equality comparisons must not raise: ``==`` is all-False and ``!=``
    is all-True. Ordering comparisons (``<``, ``<=``, ``>``, ``>=``)
    must raise ``TypeError``.

    The method name carries the ``test_`` prefix so that pytest's default
    discovery collects it; without the prefix the checks never run.

    Parameters
    ----------
    tz : passed straight to ``pd.date_range(..., tz=tz)``; presumably
        supplied by this module's ``tz`` fixture -- confirm against the
        fixture definition.
    """
    # GH#19301 by convention datetime.date is not considered comparable
    # to Timestamp or DatetimeIndex. This may change in the future.
    dti = pd.date_range('2016-01-01', periods=2, tz=tz)

    other = datetime(2016, 1, 1).date()

    # Equality-style comparisons return boolean arrays instead of raising.
    assert not (dti == other).any()
    assert (dti != other).all()

    # Ordering comparisons are undefined for datetime.date and must raise.
    with pytest.raises(TypeError):
        dti < other
    with pytest.raises(TypeError):
        dti <= other
    with pytest.raises(TypeError):
        dti > other
    with pytest.raises(TypeError):
        dti >= other
99+
@pytest.mark.parametrize('other', [None, np.nan, pd.NaT])
100+
def test_dti_eq_null_scalar(self, other, tz):
101+
# GH#19301
102+
dti = pd.date_range('2016-01-01', periods=2, tz=tz)
103+
assert not (dti == other).any()
104+
105+
@pytest.mark.parametrize('other', [None, np.nan, pd.NaT])
106+
def test_dti_ne_null_scalar(self, other, tz):
107+
# GH#19301
108+
dti = pd.date_range('2016-01-01', periods=2, tz=tz)
109+
assert (dti != other).all()
110+
111+
@pytest.mark.parametrize('other', [None, np.nan])
112+
def test_dti_cmp_null_scalar_inequality(self, tz, other):
113+
# GH#19301
114+
dti = pd.date_range('2016-01-01', periods=2, tz=tz)
115+
116+
with pytest.raises(TypeError):
117+
dti < other
118+
with pytest.raises(TypeError):
119+
dti <= other
120+
with pytest.raises(TypeError):
121+
dti > other
122+
with pytest.raises(TypeError):
123+
dti >= other
124+
48125
def test_dti_cmp_nat(self):
49126
left = pd.DatetimeIndex([pd.Timestamp('2011-01-01'), pd.NaT,
50127
pd.Timestamp('2011-01-03')])
@@ -72,69 +149,7 @@ def test_dti_cmp_nat(self):
72149
tm.assert_numpy_array_equal(lhs < pd.NaT, expected)
73150
tm.assert_numpy_array_equal(pd.NaT > lhs, expected)
74151

75-
@pytest.mark.parametrize('op', [operator.eq, operator.ne,
76-
operator.gt, operator.ge,
77-
operator.lt, operator.le])
78-
def test_comparison_tzawareness_compat(self, op):
79-
# GH#18162
80-
dr = pd.date_range('2016-01-01', periods=6)
81-
dz = dr.tz_localize('US/Pacific')
82-
83-
with pytest.raises(TypeError):
84-
op(dr, dz)
85-
with pytest.raises(TypeError):
86-
op(dr, list(dz))
87-
with pytest.raises(TypeError):
88-
op(dz, dr)
89-
with pytest.raises(TypeError):
90-
op(dz, list(dr))
91-
92-
# Check that there isn't a problem aware-aware and naive-naive do not
93-
# raise
94-
assert (dr == dr).all()
95-
assert (dr == list(dr)).all()
96-
assert (dz == dz).all()
97-
assert (dz == list(dz)).all()
98-
99-
# Check comparisons against scalar Timestamps
100-
ts = pd.Timestamp('2000-03-14 01:59')
101-
ts_tz = pd.Timestamp('2000-03-14 01:59', tz='Europe/Amsterdam')
102-
103-
assert (dr > ts).all()
104-
with pytest.raises(TypeError):
105-
op(dr, ts_tz)
106-
107-
assert (dz > ts_tz).all()
108-
with pytest.raises(TypeError):
109-
op(dz, ts)
110-
111-
@pytest.mark.parametrize('op', [operator.eq, operator.ne,
112-
operator.gt, operator.ge,
113-
operator.lt, operator.le])
114-
def test_nat_comparison_tzawareness(self, op):
115-
# GH#19276
116-
# tzaware DatetimeIndex should not raise when compared to NaT
117-
dti = pd.DatetimeIndex(['2014-01-01', pd.NaT, '2014-03-01', pd.NaT,
118-
'2014-05-01', '2014-07-01'])
119-
expected = np.array([op == operator.ne] * len(dti))
120-
result = op(dti, pd.NaT)
121-
tm.assert_numpy_array_equal(result, expected)
122-
123-
result = op(dti.tz_localize('US/Pacific'), pd.NaT)
124-
tm.assert_numpy_array_equal(result, expected)
125-
126-
def test_comparisons_coverage(self):
127-
rng = date_range('1/1/2000', periods=10)
128-
129-
# raise TypeError for now
130-
pytest.raises(TypeError, rng.__lt__, rng[3].value)
131-
132-
result = rng == list(rng)
133-
exp = rng == rng
134-
tm.assert_numpy_array_equal(result, exp)
135-
136-
def test_comparisons_nat(self):
137-
152+
def test_dti_cmp_nat_behaves_like_float_cmp_nan(self):
138153
fidx1 = pd.Index([1.0, np.nan, 3.0, np.nan, 5.0, 7.0])
139154
fidx2 = pd.Index([2.0, 3.0, np.nan, np.nan, 6.0, 7.0])
140155

@@ -223,6 +238,71 @@ def test_comparisons_nat(self):
223238
expected = np.array([True, True, False, True, True, True])
224239
tm.assert_numpy_array_equal(result, expected)
225240

241+
@pytest.mark.parametrize('op', [operator.eq, operator.ne,
242+
operator.gt, operator.ge,
243+
operator.lt, operator.le])
244+
def test_comparison_tzawareness_compat(self, op):
245+
# GH#18162
246+
dr = pd.date_range('2016-01-01', periods=6)
247+
dz = dr.tz_localize('US/Pacific')
248+
249+
with pytest.raises(TypeError):
250+
op(dr, dz)
251+
with pytest.raises(TypeError):
252+
op(dr, list(dz))
253+
with pytest.raises(TypeError):
254+
op(dz, dr)
255+
with pytest.raises(TypeError):
256+
op(dz, list(dr))
257+
258+
# Check that there isn't a problem aware-aware and naive-naive do not
259+
# raise
260+
assert (dr == dr).all()
261+
assert (dr == list(dr)).all()
262+
assert (dz == dz).all()
263+
assert (dz == list(dz)).all()
264+
265+
# Check comparisons against scalar Timestamps
266+
ts = pd.Timestamp('2000-03-14 01:59')
267+
ts_tz = pd.Timestamp('2000-03-14 01:59', tz='Europe/Amsterdam')
268+
269+
assert (dr > ts).all()
270+
with pytest.raises(TypeError):
271+
op(dr, ts_tz)
272+
273+
assert (dz > ts_tz).all()
274+
with pytest.raises(TypeError):
275+
op(dz, ts)
276+
277+
@pytest.mark.parametrize('op', [operator.eq, operator.ne,
278+
operator.gt, operator.ge,
279+
operator.lt, operator.le])
280+
def test_nat_comparison_tzawareness(self, op):
281+
# GH#19276
282+
# tzaware DatetimeIndex should not raise when compared to NaT
283+
dti = pd.DatetimeIndex(['2014-01-01', pd.NaT, '2014-03-01', pd.NaT,
284+
'2014-05-01', '2014-07-01'])
285+
expected = np.array([op == operator.ne] * len(dti))
286+
result = op(dti, pd.NaT)
287+
tm.assert_numpy_array_equal(result, expected)
288+
289+
result = op(dti.tz_localize('US/Pacific'), pd.NaT)
290+
tm.assert_numpy_array_equal(result, expected)
291+
292+
def test_dti_cmp_int_raises(self):
293+
rng = date_range('1/1/2000', periods=10)
294+
295+
# raise TypeError for now
296+
with pytest.raises(TypeError):
297+
rng < rng[3].value
298+
299+
def test_dti_cmp_list(self):
300+
rng = date_range('1/1/2000', periods=10)
301+
302+
result = rng == list(rng)
303+
expected = rng == rng
304+
tm.assert_numpy_array_equal(result, expected)
305+
226306

227307
class TestDatetimeIndexArithmetic(object):
228308

0 commit comments

Comments
 (0)