@@ -459,6 +459,26 @@ def test_corrwith_mixed_dtypes(self):
459
459
expected = pd .Series (data = corrs , index = ['a' , 'b' ])
460
460
tm .assert_series_equal (result , expected )
461
461
462
+ def test_corrwith_index_intersection (self ):
463
+ df1 = pd .DataFrame (np .random .random (size = (10 , 2 )),
464
+ columns = ["a" , "b" ])
465
+ df2 = pd .DataFrame (np .random .random (size = (10 , 3 )),
466
+ columns = ["a" , "b" , "c" ])
467
+
468
+ result = df1 .corrwith (df2 , drop = True ).index .sort_values ()
469
+ expected = df1 .columns .intersection (df2 .columns ).sort_values ()
470
+ tm .assert_index_equal (result , expected )
471
+
472
+ def test_corrwith_index_union (self ):
473
+ df1 = pd .DataFrame (np .random .random (size = (10 , 2 )),
474
+ columns = ["a" , "b" ])
475
+ df2 = pd .DataFrame (np .random .random (size = (10 , 3 )),
476
+ columns = ["a" , "b" , "c" ])
477
+
478
+ result = df1 .corrwith (df2 , drop = False ).index .sort_values ()
479
+ expected = df1 .columns .union (df2 .columns ).sort_values ()
480
+ tm .assert_index_equal (result , expected )
481
+
462
482
def test_corrwith_dup_cols (self ):
463
483
# GH 21925
464
484
df1 = pd .DataFrame (np .vstack ([np .arange (10 )] * 3 ).T )
0 commit comments