Skip to content

REF: helpers to de-duplicate CoW checks #53882

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 27, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 43 additions & 41 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,43 @@ def copy(self, deep: bool = True) -> Self:
refs = self.refs
return type(self)(values, placement=self._mgr_locs, ndim=self.ndim, refs=refs)

# ---------------------------------------------------------------------
# Copy-on-Write Helpers

@final
def _maybe_copy(self, using_cow: bool, inplace: bool) -> Self:
if using_cow and inplace:
deep = self.refs.has_reference()
blk = self.copy(deep=deep)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a general note here (it's not a problem right now): this could have side effects, creating a shallow copy of the block sets up another reference, which will yield True for blk.refs.has_reference() and self.refs.has_reference(), so have to be a bit careful when using

else:
blk = self if inplace else self.copy()
return blk

@final
def _get_values_and_refs(self, using_cow, inplace):
if using_cow:
if inplace and not self.refs.has_reference():
refs = self.refs
new_values = self.values
else:
refs = None
new_values = self.values.copy()
else:
refs = None
new_values = self.values if inplace else self.values.copy()
return new_values, refs

@final
def _get_refs_and_copy(self, using_cow: bool, inplace: bool):
refs = None
arr_inplace = inplace
if inplace:
if using_cow and self.refs.has_reference():
arr_inplace = False
else:
refs = self.refs
return arr_inplace, refs

# ---------------------------------------------------------------------
# Replace

Expand All @@ -597,12 +634,7 @@ def replace(
if isinstance(values, Categorical):
# TODO: avoid special-casing
# GH49404
if using_cow and (self.refs.has_reference() or not inplace):
blk = self.copy()
elif using_cow:
blk = self.copy(deep=False)
else:
blk = self if inplace else self.copy()
blk = self._maybe_copy(using_cow, inplace)
values = cast(Categorical, blk.values)
values._replace(to_replace=to_replace, value=value, inplace=True)
return [blk]
Expand Down Expand Up @@ -630,13 +662,7 @@ def replace(
elif self._can_hold_element(value):
# TODO(CoW): Maybe split here as well into columns where mask has True
# and rest?
if using_cow:
if inplace:
blk = self.copy(deep=self.refs.has_reference())
else:
blk = self.copy()
else:
blk = self if inplace else self.copy()
blk = self._maybe_copy(using_cow, inplace)
putmask_inplace(blk.values, mask, value)
if not (self.is_object and value is None):
# if the user *explicitly* gave None, we keep None, otherwise
Expand Down Expand Up @@ -712,16 +738,7 @@ def _replace_regex(

rx = re.compile(to_replace)

if using_cow:
if inplace and not self.refs.has_reference():
refs = self.refs
new_values = self.values
else:
refs = None
new_values = self.values.copy()
else:
refs = None
new_values = self.values if inplace else self.values.copy()
new_values, refs = self._get_values_and_refs(using_cow, inplace)

replace_regex(new_values, rx, value, mask)

Expand All @@ -745,10 +762,7 @@ def replace_list(
if isinstance(values, Categorical):
# TODO: avoid special-casing
# GH49404
if using_cow and inplace:
blk = self.copy(deep=self.refs.has_reference())
else:
blk = self if inplace else self.copy()
blk = self._maybe_copy(using_cow, inplace)
values = cast(Categorical, blk.values)
values._replace(to_replace=src_list, value=dest_list, inplace=True)
return [blk]
Expand Down Expand Up @@ -1429,13 +1443,7 @@ def interpolate(
**kwargs,
)

refs = None
arr_inplace = inplace
if inplace:
if using_cow and self.refs.has_reference():
arr_inplace = False
else:
refs = self.refs
arr_inplace, refs = self._get_refs_and_copy(using_cow, inplace)

# Dispatch to the PandasArray method.
# We know self.array_values is a PandasArray bc EABlock overrides
Expand Down Expand Up @@ -2281,13 +2289,7 @@ def interpolate(
# "Literal['linear']") [comparison-overlap]
if method == "linear": # type: ignore[comparison-overlap]
# TODO: GH#50950 implement for arbitrary EAs
refs = None
arr_inplace = inplace
if using_cow:
if inplace and not self.refs.has_reference():
refs = self.refs
else:
arr_inplace = False
arr_inplace, refs = self._get_refs_and_copy(using_cow, inplace)

new_values = self.values.interpolate(
method=method,
Expand Down