Skip to content

Commit c8c57a8

Browse files
Backport PR pandas-dev#40723: BUG: fix comparison of NaT with numpy array (pandas-dev#40734)
Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent fd2a85a commit c8c57a8

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

doc/source/whatsnew/v1.2.4.rst

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Fixed regressions
1717

1818
- Fixed regression in :meth:`DataFrame.sum` when ``min_count`` greater than the :class:`DataFrame` shape was passed resulted in a ``ValueError`` (:issue:`39738`)
1919
- Fixed regression in :meth:`DataFrame.to_json` raising ``AttributeError`` when run on PyPy (:issue:`39837`)
20+
- Fixed regression in (in)equality comparison of ``pd.NaT`` with a non-datetimelike numpy array returning a scalar instead of an array (:issue:`40722`)
2021
- Fixed regression in :meth:`DataFrame.where` not returning a copy in the case of an all True condition (:issue:`39595`)
2122
- Fixed regression in :meth:`DataFrame.replace` raising ``IndexError`` when ``regex`` was a multi-key dictionary (:issue:`39338`)
2223
-

pandas/_libs/tslibs/nattype.pyx

+4
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ cdef class _NaT(datetime):
124124
result.fill(_nat_scalar_rules[op])
125125
elif other.dtype.kind == "O":
126126
result = np.array([PyObject_RichCompare(self, x, op) for x in other])
127+
elif op == Py_EQ:
128+
result = np.zeros(other.shape, dtype=bool)
129+
elif op == Py_NE:
130+
result = np.ones(other.shape, dtype=bool)
127131
else:
128132
return NotImplemented
129133
return result

pandas/tests/scalar/test_nat.py

+41
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,47 @@ def test_nat_comparisons_invalid(other, op):
575575
op(other, NaT)
576576

577577

578+
@pytest.mark.parametrize(
579+
"other",
580+
[
581+
np.array(["foo"] * 2, dtype=object),
582+
np.array([2, 3], dtype="int64"),
583+
np.array([2.0, 3.5], dtype="float64"),
584+
],
585+
ids=["str", "int", "float"],
586+
)
587+
def test_nat_comparisons_invalid_ndarray(other):
588+
# GH#40722
589+
expected = np.array([False, False])
590+
result = NaT == other
591+
tm.assert_numpy_array_equal(result, expected)
592+
result = other == NaT
593+
tm.assert_numpy_array_equal(result, expected)
594+
595+
expected = np.array([True, True])
596+
result = NaT != other
597+
tm.assert_numpy_array_equal(result, expected)
598+
result = other != NaT
599+
tm.assert_numpy_array_equal(result, expected)
600+
601+
for symbol, op in [
602+
("<=", operator.le),
603+
("<", operator.lt),
604+
(">=", operator.ge),
605+
(">", operator.gt),
606+
]:
607+
msg = f"'{symbol}' not supported between"
608+
609+
with pytest.raises(TypeError, match=msg):
610+
op(NaT, other)
611+
612+
if other.dtype == np.dtype("object"):
613+
# uses the reverse operator, so symbol changes
614+
msg = None
615+
with pytest.raises(TypeError, match=msg):
616+
op(other, NaT)
617+
618+
578619
def test_compare_date():
579620
# GH#39151 comparing NaT with date object is deprecated
580621
# See also: tests.scalar.timestamps.test_comparisons::test_compare_date

0 commit comments

Comments
 (0)