@@ -40,6 +40,44 @@ def right():
40
40
columns = ['j_one' , 'j_two' , 'j_three' ])
41
41
42
42
43
+ @pytest .fixture
44
+ def left_multi ():
45
+ return (
46
+ DataFrame (
47
+ dict (Origin = ['A' , 'A' , 'B' , 'B' , 'C' ],
48
+ Destination = ['A' , 'B' , 'A' , 'C' , 'A' ],
49
+ Period = ['AM' , 'AM' , 'IP' , 'AM' , 'OP' ],
50
+ TripPurp = ['hbw' , 'nhb' , 'hbo' , 'nhb' , 'hbw' ],
51
+ Trips = [1987 , 3647 , 2470 , 4296 , 4444 ]),
52
+ columns = ['Origin' , 'Destination' , 'Period' ,
53
+ 'TripPurp' , 'Trips' ])
54
+ .set_index (['Origin' , 'Destination' , 'Period' , 'TripPurp' ]))
55
+
56
+
57
+ @pytest .fixture
58
+ def right_multi ():
59
+ return (
60
+ DataFrame (
61
+ dict (Origin = ['A' , 'A' , 'B' , 'B' , 'C' , 'C' , 'E' ],
62
+ Destination = ['A' , 'B' , 'A' , 'B' , 'A' , 'B' , 'F' ],
63
+ Period = ['AM' , 'AM' , 'IP' , 'AM' , 'OP' , 'IP' , 'AM' ],
64
+ LinkType = ['a' , 'b' , 'c' , 'b' , 'a' , 'b' , 'a' ],
65
+ Distance = [100 , 80 , 90 , 80 , 75 , 35 , 55 ]),
66
+ columns = ['Origin' , 'Destination' , 'Period' ,
67
+ 'LinkType' , 'Distance' ])
68
+ .set_index (['Origin' , 'Destination' , 'Period' , 'LinkType' ]))
69
+
70
+
71
+ @pytest .fixture
72
+ def on_cols_multi ():
73
+ return ['Origin' , 'Destination' , 'Period' ]
74
+
75
+
76
+ @pytest .fixture
77
+ def idx_cols_multi ():
78
+ return ['Origin' , 'Destination' , 'Period' , 'TripPurp' , 'LinkType' ]
79
+
80
+
43
81
class TestMergeMulti (object ):
44
82
45
83
def setup_method (self ):
@@ -549,66 +587,28 @@ def test_join_multi_levels2(self):
549
587
tm .assert_frame_equal (result , expected )
550
588
551
589
552
- @pytest .fixture
553
- def left_multi ():
554
- return (
555
- DataFrame (
556
- dict (Origin = ['A' , 'A' , 'B' , 'B' , 'C' ],
557
- Destination = ['A' , 'B' , 'A' , 'C' , 'A' ],
558
- Period = ['AM' , 'AM' , 'IP' , 'AM' , 'OP' ],
559
- TripPurp = ['hbw' , 'nhb' , 'hbo' , 'nhb' , 'hbw' ],
560
- Trips = [1987 , 3647 , 2470 , 4296 , 4444 ]),
561
- columns = ['Origin' , 'Destination' , 'Period' ,
562
- 'TripPurp' , 'Trips' ])
563
- .set_index (['Origin' , 'Destination' , 'Period' , 'TripPurp' ]))
564
-
565
-
566
- @pytest .fixture
567
- def right_multi ():
568
- return (
569
- DataFrame (
570
- dict (Origin = ['A' , 'A' , 'B' , 'B' , 'C' , 'C' , 'E' ],
571
- Destination = ['A' , 'B' , 'A' , 'B' , 'A' , 'B' , 'F' ],
572
- Period = ['AM' , 'AM' , 'IP' , 'AM' , 'OP' , 'IP' , 'AM' ],
573
- LinkType = ['a' , 'b' , 'c' , 'b' , 'a' , 'b' , 'a' ],
574
- Distance = [100 , 80 , 90 , 80 , 75 , 35 , 55 ]),
575
- columns = ['Origin' , 'Destination' , 'Period' ,
576
- 'LinkType' , 'Distance' ])
577
- .set_index (['Origin' , 'Destination' , 'Period' , 'LinkType' ]))
578
-
579
-
580
- @pytest .fixture
581
- def on_cols ():
582
- return ['Origin' , 'Destination' , 'Period' ]
583
-
584
-
585
- @pytest .fixture
586
- def idx_cols ():
587
- return ['Origin' , 'Destination' , 'Period' , 'TripPurp' , 'LinkType' ]
588
-
589
-
590
590
class TestJoinMultiMulti (object ):
591
591
592
592
def test_join_multi_multi (self , left_multi , right_multi , join_type ,
593
- on_cols , idx_cols ):
593
+ on_cols_multi , idx_cols_multi ):
594
594
# Multi-index join tests
595
595
expected = (pd .merge (left_multi .reset_index (),
596
596
right_multi .reset_index (),
597
- how = join_type , on = on_cols ).set_index (idx_cols )
597
+ how = join_type , on = on_cols_multi ).set_index (idx_cols_multi )
598
598
.sort_index ())
599
599
600
600
result = left_multi .join (right_multi , how = join_type ).sort_index ()
601
601
tm .assert_frame_equal (result , expected )
602
602
603
603
def test_join_multi_empty_frames (self , left_multi , right_multi , join_type ,
604
- on_cols , idx_cols ):
604
+ on_cols_multi , idx_cols_multi ):
605
605
606
606
left_multi = left_multi .drop (columns = left_multi .columns )
607
607
right_multi = right_multi .drop (columns = right_multi .columns )
608
608
609
609
expected = (pd .merge (left_multi .reset_index (),
610
610
right_multi .reset_index (),
611
- how = join_type , on = on_cols ).set_index (idx_cols )
611
+ how = join_type , on = on_cols_multi ).set_index (idx_cols_multi )
612
612
.sort_index ())
613
613
614
614
result = left_multi .join (right_multi , how = join_type ).sort_index ()
0 commit comments