diff --git a/pandas/core/apply.py b/pandas/core/apply.py index bb3cc3a03760f..169a44accf066 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -20,6 +20,7 @@ from pandas._config import option_context from pandas._libs import lib +from pandas._libs.internals import BlockValuesRefs from pandas._typing import ( AggFuncType, AggFuncTypeBase, @@ -1254,6 +1255,8 @@ def series_generator(self) -> Generator[Series, None, None]: ser = self.obj._ixs(0, axis=0) mgr = ser._mgr + is_view = mgr.blocks[0].refs.has_reference() # type: ignore[union-attr] + if isinstance(ser.dtype, ExtensionDtype): # values will be incorrect for this block # TODO(EA2D): special case would be unnecessary with 2D EAs @@ -1267,6 +1270,14 @@ def series_generator(self) -> Generator[Series, None, None]: ser._mgr = mgr mgr.set_values(arr) object.__setattr__(ser, "_name", name) + if not is_view: + # In apply_series_generator we store the a shallow copy of the + # result, which potentially increases the ref count of this reused + # `ser` object (depending on the result of the applied function) + # -> if that happened and `ser` is already a copy, then we reset + # the refs here to avoid triggering a unnecessary CoW inside the + # applied function (https://github.com/pandas-dev/pandas/pull/56212) + mgr.blocks[0].refs = BlockValuesRefs(mgr.blocks[0]) # type: ignore[union-attr] yield ser @staticmethod diff --git a/pandas/core/internals/managers.py b/pandas/core/internals/managers.py index a02f31d4483b2..9d2a1e634eea2 100644 --- a/pandas/core/internals/managers.py +++ b/pandas/core/internals/managers.py @@ -2071,11 +2071,11 @@ def set_values(self, values: ArrayLike) -> None: Set the values of the single block in place. Use at your own risk! This does not check if the passed values are - valid for the current Block/SingleBlockManager (length, dtype, etc). + valid for the current Block/SingleBlockManager (length, dtype, etc), + and this does not properly keep track of references. """ - # TODO(CoW) do we need to handle copy on write here? Currently this is - # only used for FrameColumnApply.series_generator (what if apply is - # mutating inplace?) + # NOTE(CoW) Currently this is only used for FrameColumnApply.series_generator + # which handles CoW by setting the refs manually if necessary self.blocks[0].values = values self.blocks[0]._mgr_locs = BlockPlacement(slice(len(values))) diff --git a/pandas/tests/apply/test_invalid_arg.py b/pandas/tests/apply/test_invalid_arg.py index 48dde6d42f743..b5ad1094f5bf5 100644 --- a/pandas/tests/apply/test_invalid_arg.py +++ b/pandas/tests/apply/test_invalid_arg.py @@ -18,7 +18,6 @@ DataFrame, Series, date_range, - notna, ) import pandas._testing as tm @@ -150,9 +149,7 @@ def test_transform_axis_1_raises(): Series([1]).transform("sum", axis=1) -# TODO(CoW-warn) should not need to warn -@pytest.mark.filterwarnings("ignore:Setting a value on a view:FutureWarning") -def test_apply_modify_traceback(warn_copy_on_write): +def test_apply_modify_traceback(): data = DataFrame( { "A": [ @@ -207,15 +204,9 @@ def transform(row): row["D"] = 7 return row - def transform2(row): - if notna(row["C"]) and row["C"].startswith("shin") and row["A"] == "foo": - row["D"] = 7 - return row - msg = "'float' object has no attribute 'startswith'" with pytest.raises(AttributeError, match=msg): - with tm.assert_cow_warning(warn_copy_on_write): - data.apply(transform, axis=1) + data.apply(transform, axis=1) @pytest.mark.parametrize( diff --git a/pandas/tests/copy_view/test_methods.py b/pandas/tests/copy_view/test_methods.py index ba8e4bd684198..84b2c59bda5d9 100644 --- a/pandas/tests/copy_view/test_methods.py +++ b/pandas/tests/copy_view/test_methods.py @@ -2013,3 +2013,31 @@ def test_eval_inplace(using_copy_on_write, warn_copy_on_write): df.iloc[0, 0] = 100 if using_copy_on_write: tm.assert_frame_equal(df_view, df_orig) + + +def test_apply_modify_row(using_copy_on_write, warn_copy_on_write): + # Case: applying a function on each row as a Series object, where the + # function mutates the row object (which needs to trigger CoW if row is a view) + df = DataFrame({"A": [1, 2], "B": [3, 4]}) + df_orig = df.copy() + + def transform(row): + row["B"] = 100 + return row + + with tm.assert_cow_warning(warn_copy_on_write): + df.apply(transform, axis=1) + + if using_copy_on_write: + tm.assert_frame_equal(df, df_orig) + else: + assert df.loc[0, "B"] == 100 + + # row Series is a copy + df = DataFrame({"A": [1, 2], "B": ["b", "c"]}) + df_orig = df.copy() + + with tm.assert_produces_warning(None): + df.apply(transform, axis=1) + + tm.assert_frame_equal(df, df_orig)