Skip to content

Commit 158ae5b

Browse files
gfyoungjreback
authored andcommitted
COMPAT, TST: allow numpy array comparisons with complex dtypes (#13392)
Traces back to bug in NumPy v1.7.1 in which the 'array_equivalent' method could not compare NumPy arrays with complicated dtypes. As pandas relies on this function to check NumPy array equality during testing, this commit adds a fallback method for doing so. Closes gh-13388.
1 parent 67b72e3 commit 158ae5b

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

pandas/core/common.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,24 @@ def array_equivalent(left, right, strict_nan=False):
349349
right = right.view('i8')
350350

351351
# NaNs cannot occur otherwise.
352-
return np.array_equal(left, right)
352+
try:
353+
return np.array_equal(left, right)
354+
except AttributeError:
355+
# see gh-13388
356+
#
357+
# NumPy v1.7.1 has a bug in its array_equal
358+
# function that prevents it from correctly
359+
# comparing two arrays with complex dtypes.
360+
# This bug is corrected in v1.8.0, so remove
361+
# this try-except block as soon as we stop
362+
# supporting NumPy versions < 1.8.0
363+
if not is_dtype_equal(left.dtype, right.dtype):
364+
return False
365+
366+
left = left.tolist()
367+
right = right.tolist()
368+
369+
return left == right
353370

354371

355372
def _iterable_not_string(x):

pandas/tests/test_common.py

+18
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,24 @@ def test_is_timedelta():
832832
assert (not com.is_timedelta64_ns_dtype(tdi.astype('timedelta64[h]')))
833833

834834

835+
def test_array_equivalent_compat():
836+
# see gh-13388
837+
m = np.array([(1, 2), (3, 4)], dtype=[('a', int), ('b', float)])
838+
n = np.array([(1, 2), (3, 4)], dtype=[('a', int), ('b', float)])
839+
assert (com.array_equivalent(m, n, strict_nan=True))
840+
assert (com.array_equivalent(m, n, strict_nan=False))
841+
842+
m = np.array([(1, 2), (3, 4)], dtype=[('a', int), ('b', float)])
843+
n = np.array([(1, 2), (4, 3)], dtype=[('a', int), ('b', float)])
844+
assert (not com.array_equivalent(m, n, strict_nan=True))
845+
assert (not com.array_equivalent(m, n, strict_nan=False))
846+
847+
m = np.array([(1, 2), (3, 4)], dtype=[('a', int), ('b', float)])
848+
n = np.array([(1, 2), (3, 4)], dtype=[('b', int), ('a', float)])
849+
assert (not com.array_equivalent(m, n, strict_nan=True))
850+
assert (not com.array_equivalent(m, n, strict_nan=False))
851+
852+
835853
if __name__ == '__main__':
836854
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],
837855
exit=False)

0 commit comments

Comments
 (0)