Skip to content

Commit 2e8c993

Browse files
sinhrksjreback
authored andcommitted
BUG: Series contains NaT with object dtype comparison incorrect (pandas-dev#13592)
closes pandas-dev#9005
1 parent c2cc68d commit 2e8c993

File tree

8 files changed

+208
-44
lines changed

8 files changed

+208
-44
lines changed

doc/source/whatsnew/v0.19.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,8 @@ Bug Fixes
527527
- Bug in extension dtype creation where the created types were not is/identical (:issue:`13285`)
528528

529529
- Bug in ``NaT`` - ``Period`` raises ``AttributeError`` (:issue:`13071`)
530+
- Bug in ``Series`` comparison may output incorrect result if rhs contains ``NaT`` (:issue:`9005`)
531+
- Bug in ``Series`` and ``Index`` comparison may output incorrect result if it contains ``NaT`` with ``object`` dtype (:issue:`13592`)
530532
- Bug in ``Period`` addition raises ``TypeError`` if ``Period`` is on right hand side (:issue:`13069`)
531533
- Bug in ``Peirod`` and ``Series`` or ``Index`` comparison raises ``TypeError`` (:issue:`13200`)
532534
- Bug in ``pd.set_eng_float_format()`` that would prevent NaN's from formatting (:issue:`11981`)

pandas/core/ops.py

+21-14
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
is_integer_dtype, is_categorical_dtype,
2828
is_object_dtype, is_timedelta64_dtype,
2929
is_datetime64_dtype, is_datetime64tz_dtype,
30-
is_bool_dtype, PerformanceWarning, ABCSeries)
30+
is_bool_dtype, PerformanceWarning,
31+
ABCSeries, ABCIndex)
3132

3233
# -----------------------------------------------------------------------------
3334
# Functions that add arithmetic methods to objects, given arithmetic factory
@@ -664,6 +665,22 @@ def wrapper(left, right, name=name, na_op=na_op):
664665
return wrapper
665666

666667

668+
def _comp_method_OBJECT_ARRAY(op, x, y):
669+
if isinstance(y, list):
670+
y = lib.list_to_object_array(y)
671+
if isinstance(y, (np.ndarray, ABCSeries, ABCIndex)):
672+
if not is_object_dtype(y.dtype):
673+
y = y.astype(np.object_)
674+
675+
if isinstance(y, (ABCSeries, ABCIndex)):
676+
y = y.values
677+
678+
result = lib.vec_compare(x, y, op)
679+
else:
680+
result = lib.scalar_compare(x, y, op)
681+
return result
682+
683+
667684
def _comp_method_SERIES(op, name, str_rep, masker=False):
668685
"""
669686
Wrapper function for Series arithmetic operations, to avoid
@@ -680,16 +697,7 @@ def na_op(x, y):
680697
return op(y, x)
681698

682699
if is_object_dtype(x.dtype):
683-
if isinstance(y, list):
684-
y = lib.list_to_object_array(y)
685-
686-
if isinstance(y, (np.ndarray, ABCSeries)):
687-
if not is_object_dtype(y.dtype):
688-
result = lib.vec_compare(x, y.astype(np.object_), op)
689-
else:
690-
result = lib.vec_compare(x, y, op)
691-
else:
692-
result = lib.scalar_compare(x, y, op)
700+
result = _comp_method_OBJECT_ARRAY(op, x, y)
693701
else:
694702

695703
# we want to compare like types
@@ -713,12 +721,11 @@ def na_op(x, y):
713721
(not isscalar(y) and needs_i8_conversion(y))):
714722

715723
if isscalar(y):
724+
mask = isnull(x)
716725
y = _index.convert_scalar(x, _values_from_object(y))
717726
else:
727+
mask = isnull(x) | isnull(y)
718728
y = y.view('i8')
719-
720-
mask = isnull(x)
721-
722729
x = x.view('i8')
723730

724731
try:

pandas/indexes/base.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
is_list_like, is_bool_dtype,
3232
is_integer_dtype, is_float_dtype,
3333
needs_i8_conversion)
34+
from pandas.core.ops import _comp_method_OBJECT_ARRAY
3435
from pandas.core.strings import StringAccessorMixin
3536

3637
from pandas.core.config import get_option
@@ -3182,8 +3183,11 @@ def _evaluate_compare(self, other):
31823183
if needs_i8_conversion(self) and needs_i8_conversion(other):
31833184
return self._evaluate_compare(other, op)
31843185

3185-
func = getattr(self.values, op)
3186-
result = func(np.asarray(other))
3186+
if is_object_dtype(self) and self.nlevels == 1:
3187+
# don't pass MultiIndex
3188+
result = _comp_method_OBJECT_ARRAY(op, self.values, other)
3189+
else:
3190+
result = op(self.values, np.asarray(other))
31873191

31883192
# technically we could support bool dtyped Index
31893193
# for now just return the indexing array directly
@@ -3196,12 +3200,12 @@ def _evaluate_compare(self, other):
31963200

31973201
return _evaluate_compare
31983202

3199-
cls.__eq__ = _make_compare('__eq__')
3200-
cls.__ne__ = _make_compare('__ne__')
3201-
cls.__lt__ = _make_compare('__lt__')
3202-
cls.__gt__ = _make_compare('__gt__')
3203-
cls.__le__ = _make_compare('__le__')
3204-
cls.__ge__ = _make_compare('__ge__')
3203+
cls.__eq__ = _make_compare(operator.eq)
3204+
cls.__ne__ = _make_compare(operator.ne)
3205+
cls.__lt__ = _make_compare(operator.lt)
3206+
cls.__gt__ = _make_compare(operator.gt)
3207+
cls.__le__ = _make_compare(operator.le)
3208+
cls.__ge__ = _make_compare(operator.ge)
32053209

32063210
@classmethod
32073211
def _add_numericlike_set_methods_disabled(cls):

pandas/lib.pyx

+6-6
Original file line numberDiff line numberDiff line change
@@ -768,12 +768,12 @@ def scalar_compare(ndarray[object] values, object val, object op):
768768
raise ValueError('Unrecognized operator')
769769

770770
result = np.empty(n, dtype=bool).view(np.uint8)
771-
isnull_val = _checknull(val)
771+
isnull_val = checknull(val)
772772

773773
if flag == cpython.Py_NE:
774774
for i in range(n):
775775
x = values[i]
776-
if _checknull(x):
776+
if checknull(x):
777777
result[i] = True
778778
elif isnull_val:
779779
result[i] = True
@@ -785,7 +785,7 @@ def scalar_compare(ndarray[object] values, object val, object op):
785785
elif flag == cpython.Py_EQ:
786786
for i in range(n):
787787
x = values[i]
788-
if _checknull(x):
788+
if checknull(x):
789789
result[i] = False
790790
elif isnull_val:
791791
result[i] = False
@@ -798,7 +798,7 @@ def scalar_compare(ndarray[object] values, object val, object op):
798798
else:
799799
for i in range(n):
800800
x = values[i]
801-
if _checknull(x):
801+
if checknull(x):
802802
result[i] = False
803803
elif isnull_val:
804804
result[i] = False
@@ -864,7 +864,7 @@ def vec_compare(ndarray[object] left, ndarray[object] right, object op):
864864
x = left[i]
865865
y = right[i]
866866

867-
if _checknull(x) or _checknull(y):
867+
if checknull(x) or checknull(y):
868868
result[i] = True
869869
else:
870870
result[i] = cpython.PyObject_RichCompareBool(x, y, flag)
@@ -873,7 +873,7 @@ def vec_compare(ndarray[object] left, ndarray[object] right, object op):
873873
x = left[i]
874874
y = right[i]
875875

876-
if _checknull(x) or _checknull(y):
876+
if checknull(x) or checknull(y):
877877
result[i] = False
878878
else:
879879
result[i] = cpython.PyObject_RichCompareBool(x, y, flag)

pandas/tests/series/test_operators.py

+87-14
Original file line numberDiff line numberDiff line change
@@ -980,24 +980,97 @@ def test_comparison_invalid(self):
980980
self.assertRaises(TypeError, lambda: x <= y)
981981

982982
def test_more_na_comparisons(self):
983-
left = Series(['a', np.nan, 'c'])
984-
right = Series(['a', np.nan, 'd'])
983+
for dtype in [None, object]:
984+
left = Series(['a', np.nan, 'c'], dtype=dtype)
985+
right = Series(['a', np.nan, 'd'], dtype=dtype)
985986

986-
result = left == right
987-
expected = Series([True, False, False])
988-
assert_series_equal(result, expected)
987+
result = left == right
988+
expected = Series([True, False, False])
989+
assert_series_equal(result, expected)
989990

990-
result = left != right
991-
expected = Series([False, True, True])
992-
assert_series_equal(result, expected)
991+
result = left != right
992+
expected = Series([False, True, True])
993+
assert_series_equal(result, expected)
993994

994-
result = left == np.nan
995-
expected = Series([False, False, False])
996-
assert_series_equal(result, expected)
995+
result = left == np.nan
996+
expected = Series([False, False, False])
997+
assert_series_equal(result, expected)
997998

998-
result = left != np.nan
999-
expected = Series([True, True, True])
1000-
assert_series_equal(result, expected)
999+
result = left != np.nan
1000+
expected = Series([True, True, True])
1001+
assert_series_equal(result, expected)
1002+
1003+
def test_nat_comparisons(self):
1004+
data = [([pd.Timestamp('2011-01-01'), pd.NaT,
1005+
pd.Timestamp('2011-01-03')],
1006+
[pd.NaT, pd.NaT, pd.Timestamp('2011-01-03')]),
1007+
1008+
([pd.Timedelta('1 days'), pd.NaT,
1009+
pd.Timedelta('3 days')],
1010+
[pd.NaT, pd.NaT, pd.Timedelta('3 days')]),
1011+
1012+
([pd.Period('2011-01', freq='M'), pd.NaT,
1013+
pd.Period('2011-03', freq='M')],
1014+
[pd.NaT, pd.NaT, pd.Period('2011-03', freq='M')])]
1015+
1016+
# add lhs / rhs switched data
1017+
data = data + [(r, l) for l, r in data]
1018+
1019+
for l, r in data:
1020+
for dtype in [None, object]:
1021+
left = Series(l, dtype=dtype)
1022+
1023+
# Series, Index
1024+
for right in [Series(r, dtype=dtype), Index(r, dtype=dtype)]:
1025+
expected = Series([False, False, True])
1026+
assert_series_equal(left == right, expected)
1027+
1028+
expected = Series([True, True, False])
1029+
assert_series_equal(left != right, expected)
1030+
1031+
expected = Series([False, False, False])
1032+
assert_series_equal(left < right, expected)
1033+
1034+
expected = Series([False, False, False])
1035+
assert_series_equal(left > right, expected)
1036+
1037+
expected = Series([False, False, True])
1038+
assert_series_equal(left >= right, expected)
1039+
1040+
expected = Series([False, False, True])
1041+
assert_series_equal(left <= right, expected)
1042+
1043+
def test_nat_comparisons_scalar(self):
1044+
data = [[pd.Timestamp('2011-01-01'), pd.NaT,
1045+
pd.Timestamp('2011-01-03')],
1046+
1047+
[pd.Timedelta('1 days'), pd.NaT, pd.Timedelta('3 days')],
1048+
1049+
[pd.Period('2011-01', freq='M'), pd.NaT,
1050+
pd.Period('2011-03', freq='M')]]
1051+
1052+
for l in data:
1053+
for dtype in [None, object]:
1054+
left = Series(l, dtype=dtype)
1055+
1056+
expected = Series([False, False, False])
1057+
assert_series_equal(left == pd.NaT, expected)
1058+
assert_series_equal(pd.NaT == left, expected)
1059+
1060+
expected = Series([True, True, True])
1061+
assert_series_equal(left != pd.NaT, expected)
1062+
assert_series_equal(pd.NaT != left, expected)
1063+
1064+
expected = Series([False, False, False])
1065+
assert_series_equal(left < pd.NaT, expected)
1066+
assert_series_equal(pd.NaT > left, expected)
1067+
assert_series_equal(left <= pd.NaT, expected)
1068+
assert_series_equal(pd.NaT >= left, expected)
1069+
1070+
assert_series_equal(left > pd.NaT, expected)
1071+
assert_series_equal(pd.NaT < left, expected)
1072+
assert_series_equal(left >= pd.NaT, expected)
1073+
assert_series_equal(pd.NaT <= left, expected)
10011074

10021075
def test_comparison_different_length(self):
10031076
a = Series(['a', 'b', 'c'])

pandas/tseries/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def _evaluate_compare(self, other, op):
142142
other = type(self)(other)
143143

144144
# compare
145-
result = getattr(self.asi8, op)(other.asi8)
145+
result = op(self.asi8, other.asi8)
146146

147147
# technically we could support bool dtyped Index
148148
# for now just return the indexing array directly

pandas/tseries/tdi.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _td_index_cmp(opname, nat_result=False):
3636

3737
def wrapper(self, other):
3838
func = getattr(super(TimedeltaIndex, self), opname)
39-
if _is_convertible_to_td(other):
39+
if _is_convertible_to_td(other) or other is tslib.NaT:
4040
other = _to_m8(other)
4141
result = func(other)
4242
if com.isnull(other):

pandas/tseries/tests/test_base.py

+78
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,32 @@ def test_sub_period(self):
458458
with tm.assertRaises(TypeError):
459459
p - idx
460460

461+
def test_comp_nat(self):
462+
left = pd.DatetimeIndex([pd.Timestamp('2011-01-01'), pd.NaT,
463+
pd.Timestamp('2011-01-03')])
464+
right = pd.DatetimeIndex([pd.NaT, pd.NaT, pd.Timestamp('2011-01-03')])
465+
466+
for l, r in [(left, right), (left.asobject, right.asobject)]:
467+
result = l == r
468+
expected = np.array([False, False, True])
469+
tm.assert_numpy_array_equal(result, expected)
470+
471+
result = l != r
472+
expected = np.array([True, True, False])
473+
tm.assert_numpy_array_equal(result, expected)
474+
475+
expected = np.array([False, False, False])
476+
tm.assert_numpy_array_equal(l == pd.NaT, expected)
477+
tm.assert_numpy_array_equal(pd.NaT == r, expected)
478+
479+
expected = np.array([True, True, True])
480+
tm.assert_numpy_array_equal(l != pd.NaT, expected)
481+
tm.assert_numpy_array_equal(pd.NaT != l, expected)
482+
483+
expected = np.array([False, False, False])
484+
tm.assert_numpy_array_equal(l < pd.NaT, expected)
485+
tm.assert_numpy_array_equal(pd.NaT > l, expected)
486+
461487
def test_value_counts_unique(self):
462488
# GH 7735
463489
for tz in [None, 'UTC', 'Asia/Tokyo', 'US/Eastern']:
@@ -1238,6 +1264,32 @@ def test_addition_ops(self):
12381264
expected = Timestamp('20130102')
12391265
self.assertEqual(result, expected)
12401266

1267+
def test_comp_nat(self):
1268+
left = pd.TimedeltaIndex([pd.Timedelta('1 days'), pd.NaT,
1269+
pd.Timedelta('3 days')])
1270+
right = pd.TimedeltaIndex([pd.NaT, pd.NaT, pd.Timedelta('3 days')])
1271+
1272+
for l, r in [(left, right), (left.asobject, right.asobject)]:
1273+
result = l == r
1274+
expected = np.array([False, False, True])
1275+
tm.assert_numpy_array_equal(result, expected)
1276+
1277+
result = l != r
1278+
expected = np.array([True, True, False])
1279+
tm.assert_numpy_array_equal(result, expected)
1280+
1281+
expected = np.array([False, False, False])
1282+
tm.assert_numpy_array_equal(l == pd.NaT, expected)
1283+
tm.assert_numpy_array_equal(pd.NaT == r, expected)
1284+
1285+
expected = np.array([True, True, True])
1286+
tm.assert_numpy_array_equal(l != pd.NaT, expected)
1287+
tm.assert_numpy_array_equal(pd.NaT != l, expected)
1288+
1289+
expected = np.array([False, False, False])
1290+
tm.assert_numpy_array_equal(l < pd.NaT, expected)
1291+
tm.assert_numpy_array_equal(pd.NaT > l, expected)
1292+
12411293
def test_value_counts_unique(self):
12421294
# GH 7735
12431295

@@ -2039,6 +2091,32 @@ def test_sub_isub(self):
20392091
rng -= 1
20402092
tm.assert_index_equal(rng, expected)
20412093

2094+
def test_comp_nat(self):
2095+
left = pd.PeriodIndex([pd.Period('2011-01-01'), pd.NaT,
2096+
pd.Period('2011-01-03')])
2097+
right = pd.PeriodIndex([pd.NaT, pd.NaT, pd.Period('2011-01-03')])
2098+
2099+
for l, r in [(left, right), (left.asobject, right.asobject)]:
2100+
result = l == r
2101+
expected = np.array([False, False, True])
2102+
tm.assert_numpy_array_equal(result, expected)
2103+
2104+
result = l != r
2105+
expected = np.array([True, True, False])
2106+
tm.assert_numpy_array_equal(result, expected)
2107+
2108+
expected = np.array([False, False, False])
2109+
tm.assert_numpy_array_equal(l == pd.NaT, expected)
2110+
tm.assert_numpy_array_equal(pd.NaT == r, expected)
2111+
2112+
expected = np.array([True, True, True])
2113+
tm.assert_numpy_array_equal(l != pd.NaT, expected)
2114+
tm.assert_numpy_array_equal(pd.NaT != l, expected)
2115+
2116+
expected = np.array([False, False, False])
2117+
tm.assert_numpy_array_equal(l < pd.NaT, expected)
2118+
tm.assert_numpy_array_equal(pd.NaT > l, expected)
2119+
20422120
def test_value_counts_unique(self):
20432121
# GH 7735
20442122
idx = pd.period_range('2011-01-01 09:00', freq='H', periods=10)

0 commit comments

Comments
 (0)