@@ -553,3 +553,29 @@ def test_replace_dtype(self, dtype, input_data, to_replace, expected_data):
553
553
result = ser .replace (to_replace )
554
554
expected = pd .Series (expected_data , dtype = dtype )
555
555
tm .assert_series_equal (result , expected )
556
+
557
+ def test_replace_string_dtype (self ):
558
+ # GH#40732, GH#44940
559
+ ser = pd .Series (["one" , "two" , np .nan ], dtype = "string" )
560
+ res = ser .replace ({"one" : "1" , "two" : "2" })
561
+ expected = pd .Series (["1" , "2" , np .nan ], dtype = "string" )
562
+ tm .assert_series_equal (res , expected )
563
+
564
+ def test_replace_nullable_numeric (self ):
565
+ # GH#40732, GH#44940
566
+
567
+ floats = pd .Series ([1.0 , 2.0 , 3.999 , 4.4 ], dtype = pd .Float64Dtype ())
568
+ assert floats .replace ({1.0 : 9 }).dtype == floats .dtype
569
+ assert floats .replace (1.0 , 9 ).dtype == floats .dtype
570
+ assert floats .replace ({1.0 : 9.0 }).dtype == floats .dtype
571
+ assert floats .replace (1.0 , 9.0 ).dtype == floats .dtype
572
+
573
+ res = floats .replace (to_replace = [1.0 , 2.0 ], value = [9.0 , 10.0 ])
574
+ assert res .dtype == floats .dtype
575
+
576
+ ints = pd .Series ([1 , 2 , 3 , 4 ], dtype = pd .Int64Dtype ())
577
+ assert ints .replace ({1 : 9 }).dtype == ints .dtype
578
+ assert ints .replace (1 , 9 ).dtype == ints .dtype
579
+ assert ints .replace ({1 : 9.0 }).dtype == ints .dtype
580
+ assert ints .replace (1 , 9.0 ).dtype == ints .dtype
581
+ # FIXME: ints.replace({1: 9.5}) raises bc of incorrect _can_hold_element
0 commit comments