Skip to content

Commit 13578bf

Browse files
authored
Fix obj arguments in assertions (#59460)
* Fix obj arguments in assertions * Add test for assert_interval_array_equal * Still say MultiIndex in some cases * Be a little braver and change the original obj definition
1 parent f3e1991 commit 13578bf

File tree

4 files changed

+30
-8
lines changed

4 files changed

+30
-8
lines changed

pandas/_testing/asserters.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def assert_index_equal(
188188
check_order: bool = True,
189189
rtol: float = 1.0e-5,
190190
atol: float = 1.0e-8,
191-
obj: str = "Index",
191+
obj: str | None = None,
192192
) -> None:
193193
"""
194194
Check that left and right Index are equal.
@@ -217,7 +217,7 @@ def assert_index_equal(
217217
Relative tolerance. Only used when check_exact is False.
218218
atol : float, default 1e-8
219219
Absolute tolerance. Only used when check_exact is False.
220-
obj : str, default 'Index'
220+
obj : str, default 'Index' or 'MultiIndex'
221221
Specify object name being compared, internally used to show appropriate
222222
assertion message.
223223
@@ -235,6 +235,9 @@ def assert_index_equal(
235235
"""
236236
__tracebackhide__ = True
237237

238+
if obj is None:
239+
obj = "MultiIndex" if isinstance(left, MultiIndex) else "Index"
240+
238241
def _check_types(left, right, obj: str = "Index") -> None:
239242
if not exact:
240243
return
@@ -283,7 +286,7 @@ def _check_types(left, right, obj: str = "Index") -> None:
283286
right = cast(MultiIndex, right)
284287

285288
for level in range(left.nlevels):
286-
lobj = f"MultiIndex level [{level}]"
289+
lobj = f"{obj} level [{level}]"
287290
try:
288291
# try comparison on levels/codes to avoid densifying MultiIndex
289292
assert_index_equal(
@@ -314,7 +317,7 @@ def _check_types(left, right, obj: str = "Index") -> None:
314317
obj=lobj,
315318
)
316319
# get_level_values may change dtype
317-
_check_types(left.levels[level], right.levels[level], obj=obj)
320+
_check_types(left.levels[level], right.levels[level], obj=lobj)
318321

319322
# skip exact index checking when `check_categorical` is False
320323
elif check_exact and check_categorical:
@@ -527,7 +530,7 @@ def assert_interval_array_equal(
527530
kwargs["check_freq"] = False
528531

529532
assert_equal(left._left, right._left, obj=f"{obj}.left", **kwargs)
530-
assert_equal(left._right, right._right, obj=f"{obj}.left", **kwargs)
533+
assert_equal(left._right, right._right, obj=f"{obj}.right", **kwargs)
531534

532535
assert_attr_equal("closed", left, right, obj=obj)
533536

pandas/tests/util/test_assert_frame_equal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def test_frame_equal_shape_mismatch(df1, df2, frame_or_series):
7979
DataFrame.from_records(
8080
{"a": [1.0, 2.0], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"]
8181
),
82-
"MultiIndex level \\[0\\] are different",
82+
"DataFrame\\.index level \\[0\\] are different",
8383
),
8484
],
8585
)

pandas/tests/util/test_assert_interval_array_equal.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import pytest
22

3-
from pandas import interval_range
3+
from pandas import (
4+
Interval,
5+
interval_range,
6+
)
47
import pandas._testing as tm
8+
from pandas.arrays import IntervalArray
59

610

711
@pytest.mark.parametrize(
@@ -79,3 +83,18 @@ def test_interval_array_equal_start_mismatch():
7983

8084
with pytest.raises(AssertionError, match=msg):
8185
tm.assert_interval_array_equal(arr1, arr2)
86+
87+
88+
def test_interval_array_equal_end_mismatch_only():
89+
arr1 = IntervalArray([Interval(0, 1), Interval(0, 5)])
90+
arr2 = IntervalArray([Interval(0, 1), Interval(0, 6)])
91+
92+
msg = """\
93+
IntervalArray.right are different
94+
95+
IntervalArray.right values are different \\(50.0 %\\)
96+
\\[left\\]: \\[1, 5\\]
97+
\\[right\\]: \\[1, 6\\]"""
98+
99+
with pytest.raises(AssertionError, match=msg):
100+
tm.assert_interval_array_equal(arr1, arr2)

pandas/tests/util/test_assert_series_equal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def test_less_precise(data1, data2, any_float_dtype, decimals):
137137
DataFrame.from_records(
138138
{"a": [1.0, 2.0], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"]
139139
).c,
140-
"MultiIndex level \\[0\\] are different",
140+
"Series\\.index level \\[0\\] are different",
141141
),
142142
],
143143
)

0 commit comments

Comments
 (0)