Skip to content

Commit f066cda

Browse files
authored
REF: avoid object-dtype casting in Block.replace (#40082)
1 parent 739f550 commit f066cda

File tree

3 files changed

+28
-14
lines changed

3 files changed

+28
-14
lines changed

pandas/core/internals/blocks.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,6 @@ def replace(
796796
It is used in ObjectBlocks. It is here for API compatibility.
797797
"""
798798
inplace = validate_bool_kwarg(inplace, "inplace")
799-
original_to_replace = to_replace
800799

801800
if not self._can_hold_element(to_replace):
802801
# We cannot hold `to_replace`, so we know immediately that
@@ -814,17 +813,28 @@ def replace(
814813
return [self] if inplace else [self.copy()]
815814

816815
if not self._can_hold_element(value):
817-
blk = self.astype(object)
816+
if self.ndim == 2 and self.shape[0] > 1:
817+
# split so that we only upcast where necessary
818+
nbs = self._split()
819+
res_blocks = extend_blocks(
820+
[
821+
blk.replace(to_replace, value, inplace=inplace, regex=regex)
822+
for blk in nbs
823+
]
824+
)
825+
return res_blocks
826+
827+
blk = self.coerce_to_target_dtype(value)
818828
return blk.replace(
819-
to_replace=original_to_replace,
829+
to_replace=to_replace,
820830
value=value,
821831
inplace=True,
822832
regex=regex,
823833
)
824834

825835
blk = self if inplace else self.copy()
826836
putmask_inplace(blk.values, mask, value)
827-
blocks = blk.convert(numeric=False, copy=not inplace)
837+
blocks = blk.convert(numeric=False, copy=False)
828838
return blocks
829839

830840
@final
@@ -867,11 +877,7 @@ def _replace_regex(
867877
replace_regex(new_values, rx, value, mask)
868878

869879
block = self.make_block(new_values)
870-
if convert:
871-
nbs = block.convert(numeric=False)
872-
else:
873-
nbs = [block]
874-
return nbs
880+
return [block]
875881

876882
@final
877883
def _replace_list(

pandas/tests/frame/methods/test_fillna.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -265,12 +265,13 @@ def test_fillna_dtype_conversion(self):
265265
expected = DataFrame("nan", index=range(3), columns=["A", "B"])
266266
tm.assert_frame_equal(result, expected)
267267

268-
# equiv of replace
268+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) object upcasting
269+
@pytest.mark.parametrize("val", ["", 1, np.nan, 1.0])
270+
def test_fillna_dtype_conversion_equiv_replace(self, val):
269271
df = DataFrame({"A": [1, np.nan], "B": [1.0, 2.0]})
270-
for v in ["", 1, np.nan, 1.0]:
271-
expected = df.replace(np.nan, v)
272-
result = df.fillna(v)
273-
tm.assert_frame_equal(result, expected)
272+
expected = df.replace(np.nan, val)
273+
result = df.fillna(val)
274+
tm.assert_frame_equal(result, expected)
274275

275276
@td.skip_array_manager_invalid_test
276277
def test_fillna_datetime_columns(self):

pandas/tests/frame/methods/test_replace.py

+7
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,8 @@ def test_replace_mixed(self, float_string_frame):
783783
tm.assert_frame_equal(result, expected)
784784
tm.assert_frame_equal(result.replace(-1e8, np.nan), float_string_frame)
785785

786+
def test_replace_mixed_int_block_upcasting(self):
787+
786788
# int block upcasting
787789
df = DataFrame(
788790
{
@@ -803,6 +805,8 @@ def test_replace_mixed(self, float_string_frame):
803805
assert return_value is None
804806
tm.assert_frame_equal(df, expected)
805807

808+
def test_replace_mixed_int_block_splitting(self):
809+
806810
# int block splitting
807811
df = DataFrame(
808812
{
@@ -821,6 +825,8 @@ def test_replace_mixed(self, float_string_frame):
821825
result = df.replace(0, 0.5)
822826
tm.assert_frame_equal(result, expected)
823827

828+
def test_replace_mixed2(self):
829+
824830
# to object block upcasting
825831
df = DataFrame(
826832
{
@@ -846,6 +852,7 @@ def test_replace_mixed(self, float_string_frame):
846852
result = df.replace([1, 2], ["foo", "bar"])
847853
tm.assert_frame_equal(result, expected)
848854

855+
def test_replace_mixed3(self):
849856
# test case from
850857
df = DataFrame(
851858
{"A": Series([3, 0], dtype="int64"), "B": Series([0, 3], dtype="int64")}

0 commit comments

Comments
 (0)