Skip to content

Commit c46b2d6

Browse files
committed
check_exact only takes effect for floating dtypes
1 parent 8f608a0 commit c46b2d6

File tree

7 files changed

+71
-35
lines changed

7 files changed

+71
-35
lines changed

doc/source/whatsnew/v2.2.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ See :ref:`install.dependencies` and :ref:`install.optional_dependencies` for mor
315315

316316
Other API changes
317317
^^^^^^^^^^^^^^^^^
318-
-
318+
- ``check_exact`` now only takes effect for floating-point dtypes in :func:`testing.assert_frame_equal` and :func:`testing.assert_series_equal`. In particular, integer dtypes are always checked exactly (:issue:`55882`)
319319
-
320320

321321
.. ---------------------------------------------------------------------------

pandas/_testing/asserters.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from pandas.core.dtypes.common import (
1818
is_bool,
19+
is_float_dtype,
1920
is_integer_dtype,
2021
is_number,
2122
is_numeric_dtype,
@@ -713,7 +714,7 @@ def assert_extension_array_equal(
713714
index_values : Index | numpy.ndarray, default None
714715
Optional index (shared by both left and right), used in output.
715716
check_exact : bool, default False
716-
Whether to compare number exactly.
717+
Whether to compare numbers exactly. Only takes effect for float dtypes.
717718
rtol : float, default 1e-5
718719
Relative tolerance. Only used when check_exact is False.
719720
atol : float, default 1e-8
@@ -782,7 +783,10 @@ def assert_extension_array_equal(
782783

783784
left_valid = left[~left_na].to_numpy(dtype=object)
784785
right_valid = right[~right_na].to_numpy(dtype=object)
785-
if check_exact:
786+
if check_exact or (
787+
(is_numeric_dtype(left.dtype) and not is_float_dtype(left.dtype))
788+
or (is_numeric_dtype(right.dtype) and not is_float_dtype(right.dtype))
789+
):
786790
assert_numpy_array_equal(
787791
left_valid, right_valid, obj=obj, index_values=index_values
788792
)
@@ -836,8 +840,7 @@ def assert_series_equal(
836840
check_names : bool, default True
837841
Whether to check the Series and Index names attribute.
838842
check_exact : bool, default False
839-
Whether to compare number exactly.
840-
Note: Will be set to True if dtype is int.
843+
Whether to compare numbers exactly. Only takes effect for float dtypes.
841844
check_datetimelike_compat : bool, default False
842845
Compare datetime-like which is comparable ignoring dtype.
843846
check_categorical : bool, default True
@@ -930,16 +933,10 @@ def assert_series_equal(
930933
pass
931934
else:
932935
assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}")
933-
934-
if (
935-
is_integer_dtype(left.dtype)
936-
and is_integer_dtype(right.dtype)
937-
and isinstance(left._values, type(right._values))
938-
and isinstance(right._values, type(left._values))
936+
if check_exact or (
937+
(is_numeric_dtype(left.dtype) and not is_float_dtype(left.dtype))
938+
or (is_numeric_dtype(right.dtype) and not is_float_dtype(right.dtype))
939939
):
940-
check_exact = True
941-
942-
if check_exact and is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype):
943940
left_values = left._values
944941
right_values = right._values
945942
# Only check exact if dtype is numeric
@@ -1102,7 +1099,7 @@ def assert_frame_equal(
11021099
Specify how to compare internal data. If False, compare by columns.
11031100
If True, compare by blocks.
11041101
check_exact : bool, default False
1105-
Whether to compare number exactly.
1102+
Whether to compare numbers exactly. Only takes effect for float dtypes.
11061103
check_datetimelike_compat : bool, default False
11071104
Compare datetime-like which is comparable ignoring dtype.
11081105
check_categorical : bool, default True

pandas/tests/extension/base/methods.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pandas._typing import Dtype
88

99
from pandas.core.dtypes.common import is_bool_dtype
10+
from pandas.core.dtypes.dtypes import NumpyEADtype
1011
from pandas.core.dtypes.missing import na_value_for_dtype
1112

1213
import pandas as pd
@@ -331,7 +332,8 @@ def test_fillna_length_mismatch(self, data_missing):
331332
data_missing.fillna(data_missing.take([1]))
332333

333334
# Subclasses can override if we expect e.g. Sparse[bool], boolean, pyarrow[bool]
334-
_combine_le_expected_dtype: Dtype = np.dtype(bool)
335+
# _combine_le_expected_dtype: Dtype = np.dtype(bool)
336+
_combine_le_expected_dtype: Dtype = NumpyEADtype("bool")
335337

336338
def test_combine_le(self, data_repeated):
337339
# GH 20825
@@ -341,16 +343,20 @@ def test_combine_le(self, data_repeated):
341343
s2 = pd.Series(orig_data2)
342344
result = s1.combine(s2, lambda x1, x2: x1 <= x2)
343345
expected = pd.Series(
344-
[a <= b for (a, b) in zip(list(orig_data1), list(orig_data2))],
345-
dtype=self._combine_le_expected_dtype,
346+
pd.array(
347+
[a <= b for (a, b) in zip(list(orig_data1), list(orig_data2))],
348+
dtype=self._combine_le_expected_dtype,
349+
)
346350
)
347351
tm.assert_series_equal(result, expected)
348352

349353
val = s1.iloc[0]
350354
result = s1.combine(val, lambda x1, x2: x1 <= x2)
351355
expected = pd.Series(
352-
[a <= val for a in list(orig_data1)],
353-
dtype=self._combine_le_expected_dtype,
356+
pd.array(
357+
[a <= val for a in list(orig_data1)],
358+
dtype=self._combine_le_expected_dtype,
359+
)
354360
)
355361
tm.assert_series_equal(result, expected)
356362

pandas/tests/series/test_constructors.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,10 @@ def test_constructor_maskedarray(self):
572572
data[1] = 1
573573
result = Series(data, index=index)
574574
expected = Series([0, 1, 2], index=index, dtype=int)
575-
tm.assert_series_equal(result, expected)
575+
with pytest.raises(AssertionError, match="Series classes are different"):
576+
# TODO should this be raising at all?
577+
# https://github.com/pandas-dev/pandas/issues/56131
578+
tm.assert_series_equal(result, expected)
576579

577580
data = ma.masked_all((3,), dtype=bool)
578581
result = Series(data)
@@ -589,7 +592,10 @@ def test_constructor_maskedarray(self):
589592
data[1] = True
590593
result = Series(data, index=index)
591594
expected = Series([True, True, False], index=index, dtype=bool)
592-
tm.assert_series_equal(result, expected)
595+
with pytest.raises(AssertionError, match="Series classes are different"):
596+
# TODO should this be raising at all?
597+
# https://github.com/pandas-dev/pandas/issues/56131
598+
tm.assert_series_equal(result, expected)
593599

594600
data = ma.masked_all((3,), dtype="M8[ns]")
595601
result = Series(data)

pandas/tests/tools/test_to_datetime.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2116,7 +2116,13 @@ def test_float_to_datetime_raise_near_bounds(self):
21162116
expected = (should_succeed * oneday_in_ns).astype(np.int64)
21172117
for error_mode in ["raise", "coerce", "ignore"]:
21182118
result1 = to_datetime(should_succeed, unit="D", errors=error_mode)
2119-
tm.assert_almost_equal(result1.astype(np.int64), expected, rtol=1e-10)
2119+
# Cast to `np.float64` so that `rtol` and inexact checking kick in
2120+
# (`check_exact` doesn't take effect for integer dtypes)
2121+
tm.assert_almost_equal(
2122+
result1.astype(np.int64).astype(np.float64),
2123+
expected.astype(np.float64),
2124+
rtol=1e-10,
2125+
)
21202126
# just out of bounds
21212127
should_fail1 = Series([0, tsmax_in_days + 0.005], dtype=float)
21222128
should_fail2 = Series([0, -tsmax_in_days - 0.005], dtype=float)

pandas/tests/util/test_assert_frame_equal.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,10 @@ def test_assert_frame_equal_extension_dtype_mismatch():
203203
"\\[right\\]: int[32|64]"
204204
)
205205

206-
tm.assert_frame_equal(left, right, check_dtype=False)
206+
# TODO: this shouldn't raise (or should raise a better error message)
207+
# https://github.com/pandas-dev/pandas/issues/56131
208+
with pytest.raises(AssertionError, match="classes are different"):
209+
tm.assert_frame_equal(left, right, check_dtype=False)
207210

208211
with pytest.raises(AssertionError, match=msg):
209212
tm.assert_frame_equal(left, right, check_dtype=True)
@@ -228,11 +231,18 @@ def test_assert_frame_equal_interval_dtype_mismatch():
228231
tm.assert_frame_equal(left, right, check_dtype=True)
229232

230233

231-
@pytest.mark.parametrize("right_dtype", ["Int32", "int64"])
232-
def test_assert_frame_equal_ignore_extension_dtype_mismatch(right_dtype):
234+
def test_assert_frame_equal_ignore_extension_dtype_mismatch():
235+
# https://github.com/pandas-dev/pandas/issues/35715
236+
left = DataFrame({"a": [1, 2, 3]}, dtype="Int64")
237+
right = DataFrame({"a": [1, 2, 3]}, dtype="Int32")
238+
tm.assert_frame_equal(left, right, check_dtype=False)
239+
240+
241+
@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/56131")
242+
def test_assert_frame_equal_ignore_extension_dtype_mismatch_cross_class():
233243
# https://github.com/pandas-dev/pandas/issues/35715
234244
left = DataFrame({"a": [1, 2, 3]}, dtype="Int64")
235-
right = DataFrame({"a": [1, 2, 3]}, dtype=right_dtype)
245+
right = DataFrame({"a": [1, 2, 3]}, dtype="int64")
236246
tm.assert_frame_equal(left, right, check_dtype=False)
237247

238248

pandas/tests/util/test_assert_series_equal.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,10 @@ def test_assert_series_equal_extension_dtype_mismatch():
276276
\\[left\\]: Int64
277277
\\[right\\]: int[32|64]"""
278278

279-
tm.assert_series_equal(left, right, check_dtype=False)
279+
# TODO: this shouldn't raise (or should raise a better error message)
280+
# https://github.com/pandas-dev/pandas/issues/56131
281+
with pytest.raises(AssertionError, match="Series classes are different"):
282+
tm.assert_series_equal(left, right, check_dtype=False)
280283

281284
with pytest.raises(AssertionError, match=msg):
282285
tm.assert_series_equal(left, right, check_dtype=True)
@@ -348,11 +351,18 @@ def test_series_equal_exact_for_nonnumeric():
348351
tm.assert_series_equal(s3, s1, check_exact=True)
349352

350353

351-
@pytest.mark.parametrize("right_dtype", ["Int32", "int64"])
352-
def test_assert_series_equal_ignore_extension_dtype_mismatch(right_dtype):
354+
def test_assert_series_equal_ignore_extension_dtype_mismatch():
355+
# https://github.com/pandas-dev/pandas/issues/35715
356+
left = Series([1, 2, 3], dtype="Int64")
357+
right = Series([1, 2, 3], dtype="Int32")
358+
tm.assert_series_equal(left, right, check_dtype=False)
359+
360+
361+
@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/56131")
362+
def test_assert_series_equal_ignore_extension_dtype_mismatch_cross_class():
353363
# https://github.com/pandas-dev/pandas/issues/35715
354364
left = Series([1, 2, 3], dtype="Int64")
355-
right = Series([1, 2, 3], dtype=right_dtype)
365+
right = Series([1, 2, 3], dtype="int64")
356366
tm.assert_series_equal(left, right, check_dtype=False)
357367

358368

@@ -425,9 +435,10 @@ def test_check_dtype_false_different_reso(dtype):
425435
tm.assert_series_equal(ser_s, ser_ms, check_dtype=False)
426436

427437

428-
def test_check_exact_true_for_int_dtype():
429-
# GH 55882
430-
left = Series([1577840521123000])
431-
right = Series([1577840521123543])
438+
@pytest.mark.parametrize("dtype", ["Int64", "int64"])
439+
def test_large_unequal_ints(dtype):
440+
# https://github.com/pandas-dev/pandas/issues/55882
441+
left = Series([1577840521123000], dtype=dtype)
442+
right = Series([1577840521123543], dtype=dtype)
432443
with pytest.raises(AssertionError, match="Series are different"):
433444
tm.assert_series_equal(left, right)

0 commit comments

Comments
 (0)