Skip to content

Commit 694d058

Browse files
BUG: fix comparison of NaT with numpy array (#40723)
1 parent d925376 commit 694d058

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

doc/source/whatsnew/v1.2.4.rst

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Fixed regressions
1717

1818
- Fixed regression in :meth:`DataFrame.sum` when ``min_count`` greater than the :class:`DataFrame` shape was passed resulted in a ``ValueError`` (:issue:`39738`)
1919
- Fixed regression in :meth:`DataFrame.to_json` raising ``AttributeError`` when run on PyPy (:issue:`39837`)
20+
- Fixed regression in (in)equality comparison of ``pd.NaT`` with a non-datetimelike numpy array returning a scalar instead of an array (:issue:`40722`)
2021
- Fixed regression in :meth:`DataFrame.where` not returning a copy in the case of an all True condition (:issue:`39595`)
2122
- Fixed regression in :meth:`DataFrame.replace` raising ``IndexError`` when ``regex`` was a multi-key dictionary (:issue:`39338`)
2223
-

pandas/_libs/tslibs/nattype.pyx

+4
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ cdef class _NaT(datetime):
127127
result.fill(_nat_scalar_rules[op])
128128
elif other.dtype.kind == "O":
129129
result = np.array([PyObject_RichCompare(self, x, op) for x in other])
130+
elif op == Py_EQ:
131+
result = np.zeros(other.shape, dtype=bool)
132+
elif op == Py_NE:
133+
result = np.ones(other.shape, dtype=bool)
130134
else:
131135
return NotImplemented
132136
return result

pandas/tests/scalar/test_nat.py

+41
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,47 @@ def test_nat_comparisons_invalid(other_and_type, symbol_and_op):
590590
op(other, NaT)
591591

592592

593+
@pytest.mark.parametrize(
594+
"other",
595+
[
596+
np.array(["foo"] * 2, dtype=object),
597+
np.array([2, 3], dtype="int64"),
598+
np.array([2.0, 3.5], dtype="float64"),
599+
],
600+
ids=["str", "int", "float"],
601+
)
602+
def test_nat_comparisons_invalid_ndarray(other):
603+
# GH#40722
604+
expected = np.array([False, False])
605+
result = NaT == other
606+
tm.assert_numpy_array_equal(result, expected)
607+
result = other == NaT
608+
tm.assert_numpy_array_equal(result, expected)
609+
610+
expected = np.array([True, True])
611+
result = NaT != other
612+
tm.assert_numpy_array_equal(result, expected)
613+
result = other != NaT
614+
tm.assert_numpy_array_equal(result, expected)
615+
616+
for symbol, op in [
617+
("<=", operator.le),
618+
("<", operator.lt),
619+
(">=", operator.ge),
620+
(">", operator.gt),
621+
]:
622+
msg = f"'{symbol}' not supported between"
623+
624+
with pytest.raises(TypeError, match=msg):
625+
op(NaT, other)
626+
627+
if other.dtype == np.dtype("object"):
628+
# uses the reverse operator, so symbol changes
629+
msg = None
630+
with pytest.raises(TypeError, match=msg):
631+
op(other, NaT)
632+
633+
593634
def test_compare_date():
594635
# GH#39151 comparing NaT with date object is deprecated
595636
# See also: tests.scalar.timestamps.test_comparisons::test_compare_date

0 commit comments

Comments
 (0)