@@ -399,13 +399,36 @@ def test_concat_range_index_result(self):
399
399
expected_index = pd .RangeIndex (0 , 2 )
400
400
tm .assert_index_equal (result .index , expected_index , exact = True )
401
401
402
- @pytest .mark .parametrize ("dtype" , ["Int64" , "object" ])
403
- def test_concat_index_keep_dtype (self , dtype ):
402
+ def test_concat_index_keep_dtype (self ):
403
+ # GH#47329
404
+ df1 = DataFrame ([[0 , 1 , 1 ]], columns = Index ([1 , 2 , 3 ], dtype = "object" ))
405
+ df2 = DataFrame ([[0 , 1 ]], columns = Index ([1 , 2 ], dtype = "object" ))
406
+ result = concat ([df1 , df2 ], ignore_index = True , join = "outer" , sort = True )
407
+ expected = DataFrame (
408
+ [[0 , 1 , 1.0 ], [0 , 1 , np .nan ]], columns = Index ([1 , 2 , 3 ], dtype = "object" )
409
+ )
410
+ tm .assert_frame_equal (result , expected )
411
+
412
+ def test_concat_index_keep_dtype_ea_numeric (self , any_numeric_ea_dtype ):
413
+ # GH#47329
414
+ df1 = DataFrame (
415
+ [[0 , 1 , 1 ]], columns = Index ([1 , 2 , 3 ], dtype = any_numeric_ea_dtype )
416
+ )
417
+ df2 = DataFrame ([[0 , 1 ]], columns = Index ([1 , 2 ], dtype = any_numeric_ea_dtype ))
418
+ result = concat ([df1 , df2 ], ignore_index = True , join = "outer" , sort = True )
419
+ expected = DataFrame (
420
+ [[0 , 1 , 1.0 ], [0 , 1 , np .nan ]],
421
+ columns = Index ([1 , 2 , 3 ], dtype = any_numeric_ea_dtype ),
422
+ )
423
+ tm .assert_frame_equal (result , expected )
424
+
425
+ @pytest .mark .parametrize ("dtype" , ["Int8" , "Int16" , "Int32" ])
426
+ def test_concat_index_find_common (self , dtype ):
404
427
# GH#47329
405
428
df1 = DataFrame ([[0 , 1 , 1 ]], columns = Index ([1 , 2 , 3 ], dtype = dtype ))
406
- df2 = DataFrame ([[0 , 1 ]], columns = Index ([1 , 2 ], dtype = dtype ))
429
+ df2 = DataFrame ([[0 , 1 ]], columns = Index ([1 , 2 ], dtype = "Int32" ))
407
430
result = concat ([df1 , df2 ], ignore_index = True , join = "outer" , sort = True )
408
431
expected = DataFrame (
409
- [[0 , 1 , 1.0 ], [0 , 1 , np .nan ]], columns = Index ([1 , 2 , 3 ], dtype = dtype )
432
+ [[0 , 1 , 1.0 ], [0 , 1 , np .nan ]], columns = Index ([1 , 2 , 3 ], dtype = "Int32" )
410
433
)
411
434
tm .assert_frame_equal (result , expected )
0 commit comments