Skip to content

Commit 87fefe2

Browse files
jbrockmendeljreback
authored andcommitted
dispatch Series[datetime64] comparison ops to DatetimeIndex (pandas-dev#19800)
1 parent 9242248 commit 87fefe2

File tree

5 files changed

+55
-31
lines changed

5 files changed

+55
-31
lines changed

pandas/core/indexes/datetimes.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,9 @@ def wrapper(self, other):
138138
result = func(np.asarray(other))
139139
result = com._values_from_object(result)
140140

141-
if isinstance(other, Index):
142-
o_mask = other.values.view('i8') == libts.iNaT
143-
else:
144-
o_mask = other.view('i8') == libts.iNaT
145-
141+
# Make sure to pass an array to result[...]; indexing with
142+
# Series breaks with older version of numpy
143+
o_mask = np.array(isna(other))
146144
if o_mask.any():
147145
result[o_mask] = nat_result
148146

pandas/core/ops.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
import numpy as np
1111
import pandas as pd
1212

13-
from pandas._libs import (lib, index as libindex,
14-
algos as libalgos, ops as libops)
13+
from pandas._libs import algos as libalgos, ops as libops
1514

1615
from pandas import compat
1716
from pandas.util._decorators import Appender
@@ -1127,24 +1126,20 @@ def na_op(x, y):
11271126
# integer comparisons
11281127

11291128
# we have a datetime/timedelta and may need to convert
1129+
assert not needs_i8_conversion(x)
11301130
mask = None
1131-
if (needs_i8_conversion(x) or
1132-
(not is_scalar(y) and needs_i8_conversion(y))):
1133-
1134-
if is_scalar(y):
1135-
mask = isna(x)
1136-
y = libindex.convert_scalar(x, com._values_from_object(y))
1137-
else:
1138-
mask = isna(x) | isna(y)
1139-
y = y.view('i8')
1131+
if not is_scalar(y) and needs_i8_conversion(y):
1132+
mask = isna(x) | isna(y)
1133+
y = y.view('i8')
11401134
x = x.view('i8')
11411135

1142-
try:
1136+
method = getattr(x, name, None)
1137+
if method is not None:
11431138
with np.errstate(all='ignore'):
1144-
result = getattr(x, name)(y)
1139+
result = method(y)
11451140
if result is NotImplemented:
11461141
raise TypeError("invalid type comparison")
1147-
except AttributeError:
1142+
else:
11481143
result = op(x, y)
11491144

11501145
if mask is not None and mask.any():
@@ -1174,6 +1169,14 @@ def wrapper(self, other, axis=None):
11741169
return self._constructor(res_values, index=self.index,
11751170
name=res_name)
11761171

1172+
if is_datetime64_dtype(self) or is_datetime64tz_dtype(self):
1173+
# Dispatch to DatetimeIndex to ensure identical
1174+
# Series/Index behavior
1175+
res_values = dispatch_to_index_op(op, self, other,
1176+
pd.DatetimeIndex)
1177+
return self._constructor(res_values, index=self.index,
1178+
name=res_name)
1179+
11771180
elif is_timedelta64_dtype(self):
11781181
res_values = dispatch_to_index_op(op, self, other,
11791182
pd.TimedeltaIndex)
@@ -1191,8 +1194,7 @@ def wrapper(self, other, axis=None):
11911194
elif isinstance(other, (np.ndarray, pd.Index)):
11921195
# do not check length of zerodim array
11931196
# as it will broadcast
1194-
if (not is_scalar(lib.item_from_zerodim(other)) and
1195-
len(self) != len(other)):
1197+
if other.ndim != 0 and len(self) != len(other):
11961198
raise ValueError('Lengths must match to compare')
11971199

11981200
res_values = na_op(self.values, np.asarray(other))

pandas/tests/indexes/datetimes/test_partial_slicing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from datetime import datetime, date
5+
from datetime import datetime
66
import numpy as np
77
import pandas as pd
88
import operator as op
@@ -349,7 +349,7 @@ def test_loc_datetime_length_one(self):
349349

350350
@pytest.mark.parametrize('datetimelike', [
351351
Timestamp('20130101'), datetime(2013, 1, 1),
352-
date(2013, 1, 1), np.datetime64('2013-01-01T00:00', 'ns')])
352+
np.datetime64('2013-01-01T00:00', 'ns')])
353353
@pytest.mark.parametrize('op,expected', [
354354
(op.lt, [True, False, False, False]),
355355
(op.le, [True, True, False, False]),

pandas/tests/series/test_arithmetic.py

+17
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,23 @@ def test_ser_cmp_result_names(self, names, op):
8888

8989

9090
class TestTimestampSeriesComparison(object):
91+
def test_dt64ser_cmp_date_invalid(self):
92+
# GH#19800 datetime.date comparison raises to
93+
# match DatetimeIndex/Timestamp. This also matches the behavior
94+
# of stdlib datetime.datetime
95+
ser = pd.Series(pd.date_range('20010101', periods=10), name='dates')
96+
date = ser.iloc[0].to_pydatetime().date()
97+
assert not (ser == date).any()
98+
assert (ser != date).all()
99+
with pytest.raises(TypeError):
100+
ser > date
101+
with pytest.raises(TypeError):
102+
ser < date
103+
with pytest.raises(TypeError):
104+
ser >= date
105+
with pytest.raises(TypeError):
106+
ser <= date
107+
91108
def test_dt64ser_cmp_period_scalar(self):
92109
ser = Series(pd.period_range('2000-01-01', periods=10, freq='D'))
93110
val = Period('2000-01-04', freq='D')

pandas/tests/test_base.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pandas as pd
1111
import pandas.compat as compat
1212
from pandas.core.dtypes.common import (
13-
is_object_dtype, is_datetimetz,
13+
is_object_dtype, is_datetimetz, is_datetime64_dtype,
1414
needs_i8_conversion)
1515
import pandas.util.testing as tm
1616
from pandas import (Series, Index, DatetimeIndex, TimedeltaIndex,
@@ -296,14 +296,21 @@ def test_none_comparison(self):
296296
# result = None != o # noqa
297297
# assert result.iat[0]
298298
# assert result.iat[1]
299+
if (is_datetime64_dtype(o) or is_datetimetz(o)):
300+
# Following DatetimeIndex (and Timestamp) convention,
301+
# inequality comparisons with Series[datetime64] raise
302+
with pytest.raises(TypeError):
303+
None > o
304+
with pytest.raises(TypeError):
305+
o > None
306+
else:
307+
result = None > o
308+
assert not result.iat[0]
309+
assert not result.iat[1]
299310

300-
result = None > o
301-
assert not result.iat[0]
302-
assert not result.iat[1]
303-
304-
result = o < None
305-
assert not result.iat[0]
306-
assert not result.iat[1]
311+
result = o < None
312+
assert not result.iat[0]
313+
assert not result.iat[1]
307314

308315
def test_ndarray_compat_properties(self):
309316

0 commit comments

Comments
 (0)