Skip to content

Commit 602502e

Browse files
authored
BUG: equals/assert_numpy_array_equals with non-singleton NAs (#39650)
1 parent a2f42ac commit 602502e

File tree

7 files changed

+151
-5
lines changed

7 files changed

+151
-5
lines changed

doc/source/whatsnew/v1.3.0.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,8 @@ Other
454454
- Bug in :func:`pandas.testing.assert_series_equal`, :func:`pandas.testing.assert_frame_equal`, :func:`pandas.testing.assert_index_equal` and :func:`pandas.testing.assert_extension_array_equal` incorrectly raising when an attribute has an unrecognized NA type (:issue:`39461`)
455455
- Bug in :class:`Styler` where ``subset`` arg in methods raised an error for some valid multiindex slices (:issue:`33562`)
456456
- :class:`Styler` rendered HTML output minor alterations to support w3 good code standard (:issue:`39626`)
457-
-
457+
- Bug in :meth:`DataFrame.equals`, :meth:`Series.equals`, :meth:`Index.equals` with object-dtype containing ``np.datetime64("NaT")`` or ``np.timedelta64("NaT")`` (:issue:`39650`)
458+
458459

459460
.. ---------------------------------------------------------------------------
460461

pandas/_libs/lib.pyx

+5-2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ from pandas._libs.tslib import array_to_datetime
7373
from pandas._libs.missing cimport (
7474
C_NA,
7575
checknull,
76+
is_matching_na,
7677
is_null_datetime64,
7778
is_null_timedelta64,
7879
isnaobj,
@@ -584,8 +585,10 @@ def array_equivalent_object(left: object[:], right: object[:]) -> bool:
584585
return False
585586
elif (x is C_NA) ^ (y is C_NA):
586587
return False
587-
elif not (PyObject_RichCompareBool(x, y, Py_EQ) or
588-
(x is None or is_nan(x)) and (y is None or is_nan(y))):
588+
elif not (
589+
PyObject_RichCompareBool(x, y, Py_EQ)
590+
or is_matching_na(x, y, nan_matches_none=True)
591+
):
589592
return False
590593
except ValueError:
591594
# Avoid raising ValueError when comparing Numpy arrays to other types

pandas/_libs/missing.pxd

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from numpy cimport ndarray, uint8_t
22

33

4+
cpdef bint is_matching_na(object left, object right, bint nan_matches_none=*)
5+
46
cpdef bint checknull(object val)
57
cpdef bint checknull_old(object val)
68
cpdef ndarray[uint8_t] isnaobj(ndarray arr)

pandas/_libs/missing.pyx

+52
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,58 @@ cdef:
2929
bint is_32bit = not IS64
3030

3131

32+
cpdef bint is_matching_na(object left, object right, bint nan_matches_none=False):
33+
"""
34+
Check if two scalars are both NA of matching types.
35+
36+
Parameters
37+
----------
38+
left : Any
39+
right : Any
40+
nan_matches_none : bool, default False
41+
For backwards compatibility, consider NaN as matching None.
42+
43+
Returns
44+
-------
45+
bool
46+
"""
47+
if left is None:
48+
if nan_matches_none and util.is_nan(right):
49+
return True
50+
return right is None
51+
elif left is C_NA:
52+
return right is C_NA
53+
elif left is NaT:
54+
return right is NaT
55+
elif util.is_float_object(left):
56+
if nan_matches_none and right is None:
57+
return True
58+
return (
59+
util.is_nan(left)
60+
and util.is_float_object(right)
61+
and util.is_nan(right)
62+
)
63+
elif util.is_complex_object(left):
64+
return (
65+
util.is_nan(left)
66+
and util.is_complex_object(right)
67+
and util.is_nan(right)
68+
)
69+
elif util.is_datetime64_object(left):
70+
return (
71+
get_datetime64_value(left) == NPY_NAT
72+
and util.is_datetime64_object(right)
73+
and get_datetime64_value(right) == NPY_NAT
74+
)
75+
elif util.is_timedelta64_object(left):
76+
return (
77+
get_timedelta64_value(left) == NPY_NAT
78+
and util.is_timedelta64_object(right)
79+
and get_timedelta64_value(right) == NPY_NAT
80+
)
81+
return False
82+
83+
3284
cpdef bint checknull(object val):
3385
"""
3486
Return boolean describing of the input is NA-like, defined here as any

pandas/tests/dtypes/test_missing.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pandas._libs import missing as libmissing
1111
from pandas._libs.tslibs import iNaT, is_null_datetimelike
1212

13-
from pandas.core.dtypes.common import is_scalar
13+
from pandas.core.dtypes.common import is_float, is_scalar
1414
from pandas.core.dtypes.dtypes import DatetimeTZDtype, IntervalDtype, PeriodDtype
1515
from pandas.core.dtypes.missing import (
1616
array_equivalent,
@@ -653,3 +653,25 @@ def test_is_null_datetimelike(self):
653653

654654
for value in never_na_vals:
655655
assert not is_null_datetimelike(value)
656+
657+
def test_is_matching_na(self, nulls_fixture, nulls_fixture2):
658+
left = nulls_fixture
659+
right = nulls_fixture2
660+
661+
assert libmissing.is_matching_na(left, left)
662+
663+
if left is right:
664+
assert libmissing.is_matching_na(left, right)
665+
elif is_float(left) and is_float(right):
666+
# np.nan vs float("NaN") we consider as matching
667+
assert libmissing.is_matching_na(left, right)
668+
else:
669+
assert not libmissing.is_matching_na(left, right)
670+
671+
def test_is_matching_na_nan_matches_none(self):
672+
673+
assert not libmissing.is_matching_na(None, np.nan)
674+
assert not libmissing.is_matching_na(np.nan, None)
675+
676+
assert libmissing.is_matching_na(None, np.nan, nan_matches_none=True)
677+
assert libmissing.is_matching_na(np.nan, None, nan_matches_none=True)

pandas/tests/series/methods/test_equals.py

+57-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
from contextlib import nullcontext
2+
import copy
23

34
import numpy as np
45
import pytest
56

6-
from pandas import MultiIndex, Series
7+
from pandas._libs.missing import is_matching_na
8+
9+
from pandas.core.dtypes.common import is_float
10+
11+
from pandas import Index, MultiIndex, Series
712
import pandas._testing as tm
813

914

@@ -65,3 +70,54 @@ def test_equals_false_negative():
6570
assert s1.equals(s4)
6671
assert s1.equals(s5)
6772
assert s5.equals(s6)
73+
74+
75+
def test_equals_matching_nas():
76+
# matching but not identical NAs
77+
left = Series([np.datetime64("NaT")], dtype=object)
78+
right = Series([np.datetime64("NaT")], dtype=object)
79+
assert left.equals(right)
80+
assert Index(left).equals(Index(right))
81+
assert left.array.equals(right.array)
82+
83+
left = Series([np.timedelta64("NaT")], dtype=object)
84+
right = Series([np.timedelta64("NaT")], dtype=object)
85+
assert left.equals(right)
86+
assert Index(left).equals(Index(right))
87+
assert left.array.equals(right.array)
88+
89+
left = Series([np.float64("NaN")], dtype=object)
90+
right = Series([np.float64("NaN")], dtype=object)
91+
assert left.equals(right)
92+
assert Index(left).equals(Index(right))
93+
assert left.array.equals(right.array)
94+
95+
96+
def test_equals_mismatched_nas(nulls_fixture, nulls_fixture2):
97+
# GH#39650
98+
left = nulls_fixture
99+
right = nulls_fixture2
100+
if hasattr(right, "copy"):
101+
right = right.copy()
102+
else:
103+
right = copy.copy(right)
104+
105+
ser = Series([left], dtype=object)
106+
ser2 = Series([right], dtype=object)
107+
108+
if is_matching_na(left, right):
109+
assert ser.equals(ser2)
110+
elif (left is None and is_float(right)) or (right is None and is_float(left)):
111+
assert ser.equals(ser2)
112+
else:
113+
assert not ser.equals(ser2)
114+
115+
116+
def test_equals_none_vs_nan():
117+
# GH#39650
118+
ser = Series([1, None], dtype=object)
119+
ser2 = Series([1, np.nan], dtype=object)
120+
121+
assert ser.equals(ser2)
122+
assert Index(ser).equals(Index(ser2))
123+
assert ser.array.equals(ser2.array)

pandas/tests/util/test_assert_numpy_array_equal.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import copy
2+
13
import numpy as np
24
import pytest
35

@@ -198,6 +200,14 @@ def test_numpy_array_equal_identical_na(nulls_fixture):
198200

199201
tm.assert_numpy_array_equal(a, a)
200202

203+
# matching but not the identical object
204+
if hasattr(nulls_fixture, "copy"):
205+
other = nulls_fixture.copy()
206+
else:
207+
other = copy.copy(nulls_fixture)
208+
b = np.array([other], dtype=object)
209+
tm.assert_numpy_array_equal(a, b)
210+
201211

202212
def test_numpy_array_equal_different_na():
203213
a = np.array([np.nan], dtype=object)

0 commit comments

Comments
 (0)