Skip to content

Commit 6e70bb5

Browse files
topper-123JulianWgs
authored andcommitted
REF: Refactor assert_index_equal (pandas-dev#41980)
1 parent e48244a commit 6e70bb5

File tree

2 files changed

+36
-22
lines changed

2 files changed

+36
-22
lines changed

pandas/_testing/asserters.py

+17-18
Original file line numberDiff line numberDiff line change
@@ -314,18 +314,16 @@ def _check_types(left, right, obj="Index") -> None:
314314
return
315315

316316
assert_class_equal(left, right, exact=exact, obj=obj)
317+
assert_attr_equal("inferred_type", left, right, obj=obj)
317318

318319
# Skip exact dtype checking when `check_categorical` is False
319-
if check_categorical:
320-
assert_attr_equal("dtype", left, right, obj=obj)
321-
if is_categorical_dtype(left.dtype) and is_categorical_dtype(right.dtype):
320+
if is_categorical_dtype(left.dtype) and is_categorical_dtype(right.dtype):
321+
if check_categorical:
322+
assert_attr_equal("dtype", left, right, obj=obj)
322323
assert_index_equal(left.categories, right.categories, exact=exact)
324+
return
323325

324-
# allow string-like to have different inferred_types
325-
if left.inferred_type in ("string"):
326-
assert right.inferred_type in ("string")
327-
else:
328-
assert_attr_equal("inferred_type", left, right, obj=obj)
326+
assert_attr_equal("dtype", left, right, obj=obj)
329327

330328
def _get_ilevel_values(index, level):
331329
# accept level number only
@@ -437,6 +435,8 @@ def assert_class_equal(left, right, exact: bool | str = True, obj="Input"):
437435
"""
438436
Checks classes are equal.
439437
"""
438+
from pandas.core.indexes.numeric import NumericIndex
439+
440440
__tracebackhide__ = True
441441

442442
def repr_class(x):
@@ -446,17 +446,16 @@ def repr_class(x):
446446

447447
return type(x).__name__
448448

449+
if type(left) == type(right):
450+
return
451+
449452
if exact == "equiv":
450-
if type(left) != type(right):
451-
# allow equivalence of Int64Index/RangeIndex
452-
types = {type(left).__name__, type(right).__name__}
453-
if len(types - {"Int64Index", "RangeIndex"}):
454-
msg = f"{obj} classes are not equivalent"
455-
raise_assert_detail(obj, msg, repr_class(left), repr_class(right))
456-
elif exact:
457-
if type(left) != type(right):
458-
msg = f"{obj} classes are different"
459-
raise_assert_detail(obj, msg, repr_class(left), repr_class(right))
453+
# accept equivalence of NumericIndex (sub-)classes
454+
if isinstance(left, NumericIndex) and isinstance(right, NumericIndex):
455+
return
456+
457+
msg = f"{obj} classes are different"
458+
raise_assert_detail(obj, msg, repr_class(left), repr_class(right))
460459

461460

462461
def assert_attr_equal(attr: str, left, right, obj: str = "Attributes"):

pandas/tests/util/test_assert_index_equal.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,30 @@ def test_index_equal_length_mismatch(check_exact):
5858
tm.assert_index_equal(idx1, idx2, check_exact=check_exact)
5959

6060

61-
def test_index_equal_class_mismatch(check_exact):
62-
msg = """Index are different
61+
@pytest.mark.parametrize("exact", [False, "equiv"])
62+
def test_index_equal_class(exact):
63+
idx1 = Index([0, 1, 2])
64+
idx2 = RangeIndex(3)
65+
66+
tm.assert_index_equal(idx1, idx2, exact=exact)
67+
68+
69+
@pytest.mark.parametrize(
70+
"idx_values, msg_str",
71+
[
72+
[[1, 2, 3.0], "Float64Index\\(\\[1\\.0, 2\\.0, 3\\.0\\], dtype='float64'\\)"],
73+
[range(3), "RangeIndex\\(start=0, stop=3, step=1\\)"],
74+
],
75+
)
76+
def test_index_equal_class_mismatch(check_exact, idx_values, msg_str):
77+
msg = f"""Index are different
6378
6479
Index classes are different
6580
\\[left\\]: Int64Index\\(\\[1, 2, 3\\], dtype='int64'\\)
66-
\\[right\\]: Float64Index\\(\\[1\\.0, 2\\.0, 3\\.0\\], dtype='float64'\\)"""
81+
\\[right\\]: {msg_str}"""
6782

6883
idx1 = Index([1, 2, 3])
69-
idx2 = Index([1, 2, 3.0])
84+
idx2 = Index(idx_values)
7085

7186
with pytest.raises(AssertionError, match=msg):
7287
tm.assert_index_equal(idx1, idx2, exact=True, check_exact=check_exact)

0 commit comments

Comments
 (0)