@@ -19,37 +19,41 @@ def test_set_ops_error_cases(idx, case, sort, method):
19
19
20
20
21
21
@pytest .mark .parametrize ("sort" , [None , False ])
22
- def test_intersection_base (idx , sort ):
22
+ @pytest .mark .parametrize ("klass" , [MultiIndex , np .array , Series , list ])
23
+ def test_intersection_base (idx , sort , klass ):
23
24
first = idx [2 ::- 1 ] # first 3 elements reversed
24
25
second = idx [:5 ]
25
26
26
- array_like_cases = [klass (second .values ) for klass in [np .array , Series , list ]]
27
- for case in [second , * array_like_cases ]:
28
- intersect = first .intersection (case , sort = sort )
29
- if sort is None :
30
- expected = first .sort_values ()
31
- else :
32
- expected = first
33
- tm .assert_index_equal (intersect , expected )
27
+ if klass is not MultiIndex :
28
+ second = klass (second .values )
29
+
30
+ intersect = first .intersection (second , sort = sort )
31
+ if sort is None :
32
+ expected = first .sort_values ()
33
+ else :
34
+ expected = first
35
+ tm .assert_index_equal (intersect , expected )
34
36
35
37
msg = "other must be a MultiIndex or a list of tuples"
36
38
with pytest .raises (TypeError , match = msg ):
37
39
first .intersection ([1 , 2 , 3 ], sort = sort )
38
40
39
41
40
42
@pytest .mark .parametrize ("sort" , [None , False ])
41
- def test_union_base (idx , sort ):
43
+ @pytest .mark .parametrize ("klass" , [MultiIndex , np .array , Series , list ])
44
+ def test_union_base (idx , sort , klass ):
42
45
first = idx [::- 1 ]
43
46
second = idx [:5 ]
44
47
45
- array_like_cases = [klass (second .values ) for klass in [np .array , Series , list ]]
46
- for case in [second , * array_like_cases ]:
47
- union = first .union (case , sort = sort )
48
- if sort is None :
49
- expected = first .sort_values ()
50
- else :
51
- expected = first
52
- tm .assert_index_equal (union , expected )
48
+ if klass is not MultiIndex :
49
+ second = klass (second .values )
50
+
51
+ union = first .union (second , sort = sort )
52
+ if sort is None :
53
+ expected = first .sort_values ()
54
+ else :
55
+ expected = first
56
+ tm .assert_index_equal (union , expected )
53
57
54
58
msg = "other must be a MultiIndex or a list of tuples"
55
59
with pytest .raises (TypeError , match = msg ):
0 commit comments