Skip to content

Commit 3395ca3

Browse files
CoW: Avoid warning in apply for mixed dtype frame (#56212)
Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent d44f6c1 commit 3395ca3

File tree

4 files changed

+45
-15
lines changed

4 files changed

+45
-15
lines changed

pandas/core/apply.py

+11
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pandas._config import option_context
2121

2222
from pandas._libs import lib
23+
from pandas._libs.internals import BlockValuesRefs
2324
from pandas._typing import (
2425
AggFuncType,
2526
AggFuncTypeBase,
@@ -1254,6 +1255,8 @@ def series_generator(self) -> Generator[Series, None, None]:
12541255
ser = self.obj._ixs(0, axis=0)
12551256
mgr = ser._mgr
12561257

1258+
is_view = mgr.blocks[0].refs.has_reference() # type: ignore[union-attr]
1259+
12571260
if isinstance(ser.dtype, ExtensionDtype):
12581261
# values will be incorrect for this block
12591262
# TODO(EA2D): special case would be unnecessary with 2D EAs
@@ -1267,6 +1270,14 @@ def series_generator(self) -> Generator[Series, None, None]:
12671270
ser._mgr = mgr
12681271
mgr.set_values(arr)
12691272
object.__setattr__(ser, "_name", name)
1273+
if not is_view:
1274+
# In apply_series_generator we store a shallow copy of the
1275+
# result, which potentially increases the ref count of this reused
1276+
# `ser` object (depending on the result of the applied function)
1277+
# -> if that happened and `ser` is already a copy, then we reset
1278+
# the refs here to avoid triggering an unnecessary CoW inside the
1279+
# applied function (https://github.com/pandas-dev/pandas/pull/56212)
1280+
mgr.blocks[0].refs = BlockValuesRefs(mgr.blocks[0]) # type: ignore[union-attr]
12701281
yield ser
12711282

12721283
@staticmethod

pandas/core/internals/managers.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -2079,11 +2079,11 @@ def set_values(self, values: ArrayLike) -> None:
20792079
Set the values of the single block in place.
20802080
20812081
Use at your own risk! This does not check if the passed values are
2082-
valid for the current Block/SingleBlockManager (length, dtype, etc).
2082+
valid for the current Block/SingleBlockManager (length, dtype, etc),
2083+
and this does not properly keep track of references.
20832084
"""
2084-
# TODO(CoW) do we need to handle copy on write here? Currently this is
2085-
# only used for FrameColumnApply.series_generator (what if apply is
2086-
# mutating inplace?)
2085+
# NOTE(CoW) Currently this is only used for FrameColumnApply.series_generator
2086+
# which handles CoW by setting the refs manually if necessary
20872087
self.blocks[0].values = values
20882088
self.blocks[0]._mgr_locs = BlockPlacement(slice(len(values)))
20892089

pandas/tests/apply/test_invalid_arg.py

+2-11
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
DataFrame,
1919
Series,
2020
date_range,
21-
notna,
2221
)
2322
import pandas._testing as tm
2423

@@ -150,9 +149,7 @@ def test_transform_axis_1_raises():
150149
Series([1]).transform("sum", axis=1)
151150

152151

153-
# TODO(CoW-warn) should not need to warn
154-
@pytest.mark.filterwarnings("ignore:Setting a value on a view:FutureWarning")
155-
def test_apply_modify_traceback(warn_copy_on_write):
152+
def test_apply_modify_traceback():
156153
data = DataFrame(
157154
{
158155
"A": [
@@ -207,15 +204,9 @@ def transform(row):
207204
row["D"] = 7
208205
return row
209206

210-
def transform2(row):
211-
if notna(row["C"]) and row["C"].startswith("shin") and row["A"] == "foo":
212-
row["D"] = 7
213-
return row
214-
215207
msg = "'float' object has no attribute 'startswith'"
216208
with pytest.raises(AttributeError, match=msg):
217-
with tm.assert_cow_warning(warn_copy_on_write):
218-
data.apply(transform, axis=1)
209+
data.apply(transform, axis=1)
219210

220211

221212
@pytest.mark.parametrize(

pandas/tests/copy_view/test_methods.py

+28
Original file line numberDiff line numberDiff line change
@@ -2013,3 +2013,31 @@ def test_eval_inplace(using_copy_on_write, warn_copy_on_write):
20132013
df.iloc[0, 0] = 100
20142014
if using_copy_on_write:
20152015
tm.assert_frame_equal(df_view, df_orig)
2016+
2017+
2018+
def test_apply_modify_row(using_copy_on_write, warn_copy_on_write):
2019+
# Case: applying a function on each row as a Series object, where the
2020+
# function mutates the row object (which needs to trigger CoW if row is a view)
2021+
df = DataFrame({"A": [1, 2], "B": [3, 4]})
2022+
df_orig = df.copy()
2023+
2024+
def transform(row):
2025+
row["B"] = 100
2026+
return row
2027+
2028+
with tm.assert_cow_warning(warn_copy_on_write):
2029+
df.apply(transform, axis=1)
2030+
2031+
if using_copy_on_write:
2032+
tm.assert_frame_equal(df, df_orig)
2033+
else:
2034+
assert df.loc[0, "B"] == 100
2035+
2036+
# row Series is a copy
2037+
df = DataFrame({"A": [1, 2], "B": ["b", "c"]})
2038+
df_orig = df.copy()
2039+
2040+
with tm.assert_produces_warning(None):
2041+
df.apply(transform, axis=1)
2042+
2043+
tm.assert_frame_equal(df, df_orig)

0 commit comments

Comments
 (0)