diff --git a/doc/source/whatsnew/v1.2.4.rst b/doc/source/whatsnew/v1.2.4.rst index 26d768f830830..9cef1307278e8 100644 --- a/doc/source/whatsnew/v1.2.4.rst +++ b/doc/source/whatsnew/v1.2.4.rst @@ -17,6 +17,7 @@ Fixed regressions - Fixed regression in :meth:`DataFrame.sum` when ``min_count`` greater than the :class:`DataFrame` shape was passed resulted in a ``ValueError`` (:issue:`39738`) - Fixed regression in :meth:`DataFrame.to_json` raising ``AttributeError`` when run on PyPy (:issue:`39837`) +- Fixed regression in (in)equality comparison of ``pd.NaT`` with a non-datetimelike numpy array returning a scalar instead of an array (:issue:`40722`) - Fixed regression in :meth:`DataFrame.where` not returning a copy in the case of an all True condition (:issue:`39595`) - Fixed regression in :meth:`DataFrame.replace` raising ``IndexError`` when ``regex`` was a multi-key dictionary (:issue:`39338`) - diff --git a/pandas/_libs/tslibs/nattype.pyx b/pandas/_libs/tslibs/nattype.pyx index d86d3261d404e..0c598beb6ad16 100644 --- a/pandas/_libs/tslibs/nattype.pyx +++ b/pandas/_libs/tslibs/nattype.pyx @@ -127,6 +127,10 @@ cdef class _NaT(datetime): result.fill(_nat_scalar_rules[op]) elif other.dtype.kind == "O": result = np.array([PyObject_RichCompare(self, x, op) for x in other]) + elif op == Py_EQ: + result = np.zeros(other.shape, dtype=bool) + elif op == Py_NE: + result = np.ones(other.shape, dtype=bool) else: return NotImplemented return result diff --git a/pandas/tests/scalar/test_nat.py b/pandas/tests/scalar/test_nat.py index 96aea4da9fac5..08c5ea706111a 100644 --- a/pandas/tests/scalar/test_nat.py +++ b/pandas/tests/scalar/test_nat.py @@ -590,6 +590,47 @@ def test_nat_comparisons_invalid(other_and_type, symbol_and_op): op(other, NaT) +@pytest.mark.parametrize( + "other", + [ + np.array(["foo"] * 2, dtype=object), + np.array([2, 3], dtype="int64"), + np.array([2.0, 3.5], dtype="float64"), + ], + ids=["str", "int", "float"], +) +def test_nat_comparisons_invalid_ndarray(other): + # GH#40722 + expected = np.array([False, False]) + result = NaT == other + tm.assert_numpy_array_equal(result, expected) + result = other == NaT + tm.assert_numpy_array_equal(result, expected) + + expected = np.array([True, True]) + result = NaT != other + tm.assert_numpy_array_equal(result, expected) + result = other != NaT + tm.assert_numpy_array_equal(result, expected) + + for symbol, op in [ + ("<=", operator.le), + ("<", operator.lt), + (">=", operator.ge), + (">", operator.gt), + ]: + msg = f"'{symbol}' not supported between" + + with pytest.raises(TypeError, match=msg): + op(NaT, other) + + if other.dtype == np.dtype("object"): + # uses the reverse operator, so symbol changes + msg = None + with pytest.raises(TypeError, match=msg): + op(other, NaT) + + def test_compare_date(): # GH#39151 comparing NaT with date object is deprecated # See also: tests.scalar.timestamps.test_comparisons::test_compare_date