Skip to content

Commit c68e4d8

Browse files
author
Jean-Francois Zinque
committed
Parametrize MultiIndex intersection and union tests
1 parent af5de36 commit c68e4d8

File tree

1 file changed

+22
-18
lines changed

1 file changed

+22
-18
lines changed

pandas/tests/indexes/multi/test_setops.py

+22-18
Original file line numberDiff line numberDiff line change
@@ -19,37 +19,41 @@ def test_set_ops_error_cases(idx, case, sort, method):
1919

2020

2121
@pytest.mark.parametrize("sort", [None, False])
22-
def test_intersection_base(idx, sort):
22+
@pytest.mark.parametrize("klass", [MultiIndex, np.array, Series, list])
23+
def test_intersection_base(idx, sort, klass):
2324
first = idx[2::-1] # first 3 elements reversed
2425
second = idx[:5]
2526

26-
array_like_cases = [klass(second.values) for klass in [np.array, Series, list]]
27-
for case in [second, *array_like_cases]:
28-
intersect = first.intersection(case, sort=sort)
29-
if sort is None:
30-
expected = first.sort_values()
31-
else:
32-
expected = first
33-
tm.assert_index_equal(intersect, expected)
27+
if klass is not MultiIndex:
28+
second = klass(second.values)
29+
30+
intersect = first.intersection(second, sort=sort)
31+
if sort is None:
32+
expected = first.sort_values()
33+
else:
34+
expected = first
35+
tm.assert_index_equal(intersect, expected)
3436

3537
msg = "other must be a MultiIndex or a list of tuples"
3638
with pytest.raises(TypeError, match=msg):
3739
first.intersection([1, 2, 3], sort=sort)
3840

3941

4042
@pytest.mark.parametrize("sort", [None, False])
41-
def test_union_base(idx, sort):
43+
@pytest.mark.parametrize("klass", [MultiIndex, np.array, Series, list])
44+
def test_union_base(idx, sort, klass):
4245
first = idx[::-1]
4346
second = idx[:5]
4447

45-
array_like_cases = [klass(second.values) for klass in [np.array, Series, list]]
46-
for case in [second, *array_like_cases]:
47-
union = first.union(case, sort=sort)
48-
if sort is None:
49-
expected = first.sort_values()
50-
else:
51-
expected = first
52-
tm.assert_index_equal(union, expected)
48+
if klass is not MultiIndex:
49+
second = klass(second.values)
50+
51+
union = first.union(second, sort=sort)
52+
if sort is None:
53+
expected = first.sort_values()
54+
else:
55+
expected = first
56+
tm.assert_index_equal(union, expected)
5357

5458
msg = "other must be a MultiIndex or a list of tuples"
5559
with pytest.raises(TypeError, match=msg):

0 commit comments

Comments
 (0)