Skip to content

Commit 5377aef

Browse files
Backport PR #31910: BUG: Handle NA in assert_numpy_array_equal (#31947)
Co-authored-by: Daniel Saxton <[email protected]>
1 parent 95e83c5 commit 5377aef

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

Diff for: pandas/_libs/lib.pyx

+2
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,8 @@ def array_equivalent_object(left: object[:], right: object[:]) -> bool:
527527
if PyArray_Check(x) and PyArray_Check(y):
528528
if not array_equivalent_object(x, y):
529529
return False
530+
elif (x is C_NA) ^ (y is C_NA):
531+
return False
530532
elif not (PyObject_RichCompareBool(x, y, Py_EQ) or
531533
(x is None or is_nan(x)) and (y is None or is_nan(y))):
532534
return False

Diff for: pandas/tests/util/test_assert_numpy_array_equal.py

+36
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
import pandas as pd
45
from pandas import Timestamp
56
import pandas._testing as tm
67

@@ -175,3 +176,38 @@ def test_numpy_array_equal_copy_flag(other_type, check_same):
175176
tm.assert_numpy_array_equal(a, other, check_same=check_same)
176177
else:
177178
tm.assert_numpy_array_equal(a, other, check_same=check_same)
179+
180+
181+
def test_numpy_array_equal_contains_na():
182+
# https://github.com/pandas-dev/pandas/issues/31881
183+
a = np.array([True, False])
184+
b = np.array([True, pd.NA], dtype=object)
185+
186+
msg = """numpy array are different
187+
188+
numpy array values are different \\(50.0 %\\)
189+
\\[left\\]: \\[True, False\\]
190+
\\[right\\]: \\[True, <NA>\\]"""
191+
192+
with pytest.raises(AssertionError, match=msg):
193+
tm.assert_numpy_array_equal(a, b)
194+
195+
196+
def test_numpy_array_equal_identical_na(nulls_fixture):
197+
a = np.array([nulls_fixture], dtype=object)
198+
199+
tm.assert_numpy_array_equal(a, a)
200+
201+
202+
def test_numpy_array_equal_different_na():
203+
a = np.array([np.nan], dtype=object)
204+
b = np.array([pd.NA], dtype=object)
205+
206+
msg = """numpy array are different
207+
208+
numpy array values are different \\(100.0 %\\)
209+
\\[left\\]: \\[nan\\]
210+
\\[right\\]: \\[<NA>\\]"""
211+
212+
with pytest.raises(AssertionError, match=msg):
213+
tm.assert_numpy_array_equal(a, b)

0 commit comments

Comments
 (0)