diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index f87dac0669e00..6c4bd35c8f183 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -276,6 +276,7 @@ Numeric - Bug in :meth:`DataFrame.any` with ``axis=1`` and ``bool_only=True`` ignoring the ``bool_only`` keyword (:issue:`32432`) - Bug in :meth:`Series.equals` where a ``ValueError`` was raised when numpy arrays were compared to scalars (:issue:`35267`) - Bug in :class:`Series` where two :class:`Series` each have a :class:`DatetimeIndex` with different timezones having those indexes incorrectly changed when performing arithmetic operations (:issue:`33671`) +- Bug in :meth:`pd._testing.assert_almost_equal` was incorrect for complex numeric types (:issue:`28235`) - Conversion diff --git a/pandas/_libs/testing.pyx b/pandas/_libs/testing.pyx index 64fc8d615ea9c..b2f19fcf5f5da 100644 --- a/pandas/_libs/testing.pyx +++ b/pandas/_libs/testing.pyx @@ -1,3 +1,4 @@ +import cmath import math import numpy as np @@ -7,6 +8,7 @@ from numpy cimport import_array import_array() from pandas._libs.util cimport is_array +from pandas._libs.lib import is_complex from pandas.core.dtypes.common import is_dtype_equal from pandas.core.dtypes.missing import array_equivalent, isna @@ -210,4 +212,14 @@ cpdef assert_almost_equal(a, b, f"with rtol={rtol}, atol={atol}") return True + if is_complex(a) and is_complex(b): + if array_equivalent(a, b, strict_nan=True): + # inf comparison + return True + + if not cmath.isclose(a, b, rel_tol=rtol, abs_tol=atol): + assert False, (f"expected {b:.5f} but got {a:.5f}, " + f"with rtol={rtol}, atol={atol}") + return True + raise AssertionError(f"{a} != {b}") diff --git a/pandas/tests/util/test_assert_almost_equal.py b/pandas/tests/util/test_assert_almost_equal.py index c25668c33bfc4..c4bc3b7ee352d 100644 --- a/pandas/tests/util/test_assert_almost_equal.py +++ b/pandas/tests/util/test_assert_almost_equal.py @@ -146,6 +146,37 @@ def test_assert_not_almost_equal_numbers_rtol(a, b): _assert_not_almost_equal_both(a, b, rtol=0.05) +@pytest.mark.parametrize( + "a,b,rtol", + [ + (1.00001, 1.00005, 0.001), + (-0.908356 + 0.2j, -0.908358 + 0.2j, 1e-3), + (0.1 + 1.009j, 0.1 + 1.006j, 0.1), + (0.1001 + 2.0j, 0.1 + 2.001j, 0.01), + ], +) +def test_assert_almost_equal_complex_numbers(a, b, rtol): + _assert_almost_equal_both(a, b, rtol=rtol) + _assert_almost_equal_both(np.complex64(a), np.complex64(b), rtol=rtol) + _assert_almost_equal_both(np.complex128(a), np.complex128(b), rtol=rtol) + + +@pytest.mark.parametrize( + "a,b,rtol", + [ + (0.58310768, 0.58330768, 1e-7), + (-0.908 + 0.2j, -0.978 + 0.2j, 0.001), + (0.1 + 1j, 0.1 + 2j, 0.01), + (-0.132 + 1.001j, -0.132 + 1.005j, 1e-5), + (0.58310768j, 0.58330768j, 1e-9), + ], +) +def test_assert_not_almost_equal_complex_numbers(a, b, rtol): + _assert_not_almost_equal_both(a, b, rtol=rtol) + _assert_not_almost_equal_both(np.complex64(a), np.complex64(b), rtol=rtol) + _assert_not_almost_equal_both(np.complex128(a), np.complex128(b), rtol=rtol) + + @pytest.mark.parametrize("a,b", [(0, 0), (0, 0.0), (0, np.float64(0)), (0.00000001, 0)]) def test_assert_almost_equal_numbers_with_zeros(a, b): _assert_almost_equal_both(a, b)