Skip to content

Commit cb04202

Browse files
authored
ENH: Add CoW mechanism to replace_regex (#51669)
* ENH: Add CoW mechanism to replace_regex * Fix test * Fix test
1 parent 4d18871 commit cb04202

File tree

3 files changed

+64
-4
lines changed

3 files changed

+64
-4
lines changed

pandas/core/internals/blocks.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,7 @@ def _replace_regex(
651651
value,
652652
inplace: bool = False,
653653
mask=None,
654+
using_cow: bool = False,
654655
) -> list[Block]:
655656
"""
656657
Replace elements by the given value.
@@ -665,6 +666,8 @@ def _replace_regex(
665666
Perform inplace modification.
666667
mask : array-like of bool, optional
667668
True indicate corresponding element is ignored.
669+
using_cow: bool, default False
670+
Specifying if copy on write is enabled.
668671
669672
Returns
670673
-------
@@ -673,15 +676,27 @@ def _replace_regex(
673676
if not self._can_hold_element(to_replace):
674677
# i.e. only ObjectBlock, but could in principle include a
675678
# String ExtensionBlock
679+
if using_cow:
680+
return [self.copy(deep=False)]
676681
return [self] if inplace else [self.copy()]
677682

678683
rx = re.compile(to_replace)
679684

680-
new_values = self.values if inplace else self.values.copy()
685+
if using_cow:
686+
if inplace and not self.refs.has_reference():
687+
refs = self.refs
688+
new_values = self.values
689+
else:
690+
refs = None
691+
new_values = self.values.copy()
692+
else:
693+
refs = None
694+
new_values = self.values if inplace else self.values.copy()
695+
681696
replace_regex(new_values, rx, value, mask)
682697

683-
block = self.make_block(new_values)
684-
return block.convert(copy=False)
698+
block = self.make_block(new_values, refs=refs)
699+
return block.convert(copy=False, using_cow=using_cow)
685700

686701
@final
687702
def replace_list(

pandas/core/internals/managers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def replace(self: T, to_replace, value, inplace: bool) -> T:
467467
)
468468

469469
def replace_regex(self, **kwargs):
470-
return self.apply("_replace_regex", **kwargs)
470+
return self.apply("_replace_regex", **kwargs, using_cow=using_copy_on_write())
471471

472472
def replace_list(
473473
self: T,

pandas/tests/copy_view/test_replace.py

+45
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,51 @@ def test_replace(using_copy_on_write, replace_kwargs):
4747
tm.assert_frame_equal(df, df_orig)
4848

4949

50+
def test_replace_regex_inplace_refs(using_copy_on_write):
51+
df = DataFrame({"a": ["aaa", "bbb"]})
52+
df_orig = df.copy()
53+
view = df[:]
54+
arr = get_array(df, "a")
55+
df.replace(to_replace=r"^a.*$", value="new", inplace=True, regex=True)
56+
if using_copy_on_write:
57+
assert not np.shares_memory(arr, get_array(df, "a"))
58+
assert df._mgr._has_no_reference(0)
59+
tm.assert_frame_equal(view, df_orig)
60+
else:
61+
assert np.shares_memory(arr, get_array(df, "a"))
62+
63+
64+
def test_replace_regex_inplace(using_copy_on_write):
65+
df = DataFrame({"a": ["aaa", "bbb"]})
66+
arr = get_array(df, "a")
67+
df.replace(to_replace=r"^a.*$", value="new", inplace=True, regex=True)
68+
if using_copy_on_write:
69+
assert df._mgr._has_no_reference(0)
70+
assert np.shares_memory(arr, get_array(df, "a"))
71+
72+
df_orig = df.copy()
73+
df2 = df.replace(to_replace=r"^b.*$", value="new", regex=True)
74+
tm.assert_frame_equal(df_orig, df)
75+
assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a"))
76+
77+
78+
def test_replace_regex_inplace_no_op(using_copy_on_write):
79+
df = DataFrame({"a": [1, 2]})
80+
arr = get_array(df, "a")
81+
df.replace(to_replace=r"^a.$", value="new", inplace=True, regex=True)
82+
if using_copy_on_write:
83+
assert df._mgr._has_no_reference(0)
84+
assert np.shares_memory(arr, get_array(df, "a"))
85+
86+
df_orig = df.copy()
87+
df2 = df.replace(to_replace=r"^x.$", value="new", regex=True)
88+
tm.assert_frame_equal(df_orig, df)
89+
if using_copy_on_write:
90+
assert np.shares_memory(get_array(df2, "a"), get_array(df, "a"))
91+
else:
92+
assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a"))
93+
94+
5095
def test_replace_mask_all_false_second_block(using_copy_on_write):
5196
df = DataFrame({"a": [1.5, 2, 3], "b": 100.5, "c": 1, "d": 2})
5297
df_orig = df.copy()

0 commit comments

Comments
 (0)