Skip to content

Commit 5793ed2

Browse files
committed
ENH: Optimize putmask implementation for CoW (#51268)
1 parent da92024 commit 5793ed2

File tree

4 files changed

+89
-19
lines changed

4 files changed

+89
-19
lines changed

pandas/core/generic.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -9612,7 +9612,8 @@ def _where(
96129612
# align the cond to same shape as myself
96139613
cond = common.apply_if_callable(cond, self)
96149614
if isinstance(cond, NDFrame):
9615-
cond, _ = cond.align(self, join="right", broadcast_axis=1, copy=False)
9615+
# CoW: Make sure reference is not kept alive
9616+
cond = cond.align(self, join="right", broadcast_axis=1, copy=False)[0]
96169617
else:
96179618
if not hasattr(cond, "shape"):
96189619
cond = np.asanyarray(cond)
@@ -9648,14 +9649,15 @@ def _where(
96489649
# align with me
96499650
if other.ndim <= self.ndim:
96509651

9651-
_, other = self.align(
9652+
# CoW: Make sure reference is not kept alive
9653+
other = self.align(
96529654
other,
96539655
join="left",
96549656
axis=axis,
96559657
level=level,
96569658
fill_value=None,
96579659
copy=False,
9658-
)
9660+
)[1]
96599661

96609662
# if we are NOT aligned, raise as we cannot where index
96619663
if axis is None and not other._indexed_same(self):

pandas/core/internals/blocks.py

+35-6
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,7 @@ def _unstack(
946946

947947
# ---------------------------------------------------------------------
948948

949-
def setitem(self, indexer, value) -> Block:
949+
def setitem(self, indexer, value, using_cow: bool = False) -> Block:
950950
"""
951951
Attempt self.values[indexer] = value, possibly creating a new array.
952952
@@ -956,6 +956,8 @@ def setitem(self, indexer, value) -> Block:
956956
The subset of self.values to set
957957
value : object
958958
The value being set
959+
using_cow: bool, default False
960+
Signaling if CoW is used.
959961
960962
Returns
961963
-------
@@ -991,10 +993,17 @@ def setitem(self, indexer, value) -> Block:
991993
# checking lib.is_scalar here fails on
992994
# test_iloc_setitem_custom_object
993995
casted = setitem_datetimelike_compat(values, len(vi), casted)
996+
997+
if using_cow and self.refs.has_reference():
998+
values = values.copy()
999+
self = self.make_block_same_class(
1000+
values.T if values.ndim == 2 else values
1001+
)
1002+
9941003
values[indexer] = casted
9951004
return self
9961005

997-
def putmask(self, mask, new) -> list[Block]:
1006+
def putmask(self, mask, new, using_cow: bool = False) -> list[Block]:
9981007
"""
9991008
putmask the data to the block; it is possible that we may create a
10001009
new dtype of block
@@ -1022,11 +1031,21 @@ def putmask(self, mask, new) -> list[Block]:
10221031
new = extract_array(new, extract_numpy=True)
10231032

10241033
if noop:
1034+
if using_cow:
1035+
return [self.copy(deep=False)]
10251036
return [self]
10261037

10271038
try:
10281039
casted = np_can_hold_element(values.dtype, new)
1040+
1041+
if using_cow and self.refs.has_reference():
1042+
# Do this here to avoid copying twice
1043+
values = values.copy()
1044+
self = self.make_block_same_class(values)
1045+
10291046
putmask_without_repeat(values.T, mask, casted)
1047+
if using_cow:
1048+
return [self.copy(deep=False)]
10301049
return [self]
10311050
except LossySetitemError:
10321051

@@ -1038,7 +1057,7 @@ def putmask(self, mask, new) -> list[Block]:
10381057
return self.coerce_to_target_dtype(new).putmask(mask, new)
10391058
else:
10401059
indexer = mask.nonzero()[0]
1041-
nb = self.setitem(indexer, new[indexer])
1060+
nb = self.setitem(indexer, new[indexer], using_cow=using_cow)
10421061
return [nb]
10431062

10441063
else:
@@ -1053,7 +1072,7 @@ def putmask(self, mask, new) -> list[Block]:
10531072
n = new[:, i : i + 1]
10541073

10551074
submask = orig_mask[:, i : i + 1]
1056-
rbs = nb.putmask(submask, n)
1075+
rbs = nb.putmask(submask, n, using_cow=using_cow)
10571076
res_blocks.extend(rbs)
10581077
return res_blocks
10591078

@@ -1462,7 +1481,7 @@ class EABackedBlock(Block):
14621481

14631482
values: ExtensionArray
14641483

1465-
def setitem(self, indexer, value):
1484+
def setitem(self, indexer, value, using_cow: bool = False):
14661485
"""
14671486
Attempt self.values[indexer] = value, possibly creating a new array.
14681487
@@ -1475,6 +1494,8 @@ def setitem(self, indexer, value):
14751494
The subset of self.values to set
14761495
value : object
14771496
The value being set
1497+
using_cow: bool, default False
1498+
Signaling if CoW is used.
14781499
14791500
Returns
14801501
-------
@@ -1581,7 +1602,7 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
15811602
nb = self.make_block_same_class(res_values)
15821603
return [nb]
15831604

1584-
def putmask(self, mask, new) -> list[Block]:
1605+
def putmask(self, mask, new, using_cow: bool = False) -> list[Block]:
15851606
"""
15861607
See Block.putmask.__doc__
15871608
"""
@@ -1599,8 +1620,16 @@ def putmask(self, mask, new) -> list[Block]:
15991620
mask = self._maybe_squeeze_arg(mask)
16001621

16011622
if not mask.any():
1623+
if using_cow:
1624+
return [self.copy(deep=False)]
16021625
return [self]
16031626

1627+
if using_cow and self.refs.has_reference():
1628+
values = values.copy()
1629+
self = self.make_block_same_class( # type: ignore[assignment]
1630+
values.T if values.ndim == 2 else values
1631+
)
1632+
16041633
try:
16051634
# Caller is responsible for ensuring matching lengths
16061635
values._putmask(mask, new)

pandas/core/internals/managers.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -362,14 +362,6 @@ def setitem(self: T, indexer, value) -> T:
362362
return self.apply("setitem", indexer=indexer, value=value)
363363

364364
def putmask(self, mask, new, align: bool = True):
365-
if using_copy_on_write() and any(
366-
not self._has_no_reference_block(i) for i in range(len(self.blocks))
367-
):
368-
# some reference -> copy full dataframe
369-
# TODO(CoW) this could be optimized to only copy the blocks that would
370-
# get modified
371-
self = self.copy()
372-
373365
if align:
374366
align_keys = ["new", "mask"]
375367
else:
@@ -381,6 +373,7 @@ def putmask(self, mask, new, align: bool = True):
381373
align_keys=align_keys,
382374
mask=mask,
383375
new=new,
376+
using_cow=using_copy_on_write(),
384377
)
385378

386379
def diff(self: T, n: int, axis: AxisInt) -> T:

pandas/tests/copy_view/test_methods.py

+48-2
Original file line numberDiff line numberDiff line change
@@ -1240,8 +1240,9 @@ def test_replace(using_copy_on_write, replace_kwargs):
12401240
tm.assert_frame_equal(df, df_orig)
12411241

12421242

1243-
def test_putmask(using_copy_on_write):
1244-
df = DataFrame({"a": [1, 2], "b": 1, "c": 2})
1243+
@pytest.mark.parametrize("dtype", ["int64", "Int64"])
1244+
def test_putmask(using_copy_on_write, dtype):
1245+
df = DataFrame({"a": [1, 2], "b": 1, "c": 2}, dtype=dtype)
12451246
view = df[:]
12461247
df_orig = df.copy()
12471248
df[df == df] = 5
@@ -1255,6 +1256,51 @@ def test_putmask(using_copy_on_write):
12551256
assert view.iloc[0, 0] == 5
12561257

12571258

1259+
@pytest.mark.parametrize("dtype", ["int64", "Int64"])
1260+
def test_putmask_no_reference(using_copy_on_write, dtype):
1261+
df = DataFrame({"a": [1, 2], "b": 1, "c": 2}, dtype=dtype)
1262+
arr_a = get_array(df, "a")
1263+
df[df == df] = 5
1264+
1265+
if using_copy_on_write:
1266+
assert np.shares_memory(arr_a, get_array(df, "a"))
1267+
1268+
1269+
@pytest.mark.parametrize("dtype", ["float64", "Float64"])
1270+
def test_putmask_aligns_rhs_no_reference(using_copy_on_write, dtype):
1271+
df = DataFrame({"a": [1.5, 2], "b": 1.5}, dtype=dtype)
1272+
arr_a = get_array(df, "a")
1273+
df[df == df] = DataFrame({"a": [5.5, 5]})
1274+
1275+
if using_copy_on_write:
1276+
assert np.shares_memory(arr_a, get_array(df, "a"))
1277+
1278+
1279+
@pytest.mark.parametrize("val, exp", [(5.5, True), (5, False)])
1280+
def test_putmask_dont_copy_some_blocks(using_copy_on_write, val, exp):
1281+
df = DataFrame({"a": [1, 2], "b": 1, "c": 1.5})
1282+
view = df[:]
1283+
df_orig = df.copy()
1284+
indexer = DataFrame(
1285+
[[True, False, False], [True, False, False]], columns=list("abc")
1286+
)
1287+
df[indexer] = val
1288+
1289+
if using_copy_on_write:
1290+
assert not np.shares_memory(get_array(view, "a"), get_array(df, "a"))
1291+
# TODO(CoW): Could split blocks to avoid copying the whole block
1292+
assert np.shares_memory(get_array(view, "b"), get_array(df, "b")) is exp
1293+
assert np.shares_memory(get_array(view, "c"), get_array(df, "c"))
1294+
assert df._mgr._has_no_reference(1) is not exp
1295+
assert not df._mgr._has_no_reference(2)
1296+
tm.assert_frame_equal(view, df_orig)
1297+
elif val == 5:
1298+
# Without CoW the original will be modified, the other case upcasts, e.g. copy
1299+
assert np.shares_memory(get_array(view, "a"), get_array(df, "a"))
1300+
assert np.shares_memory(get_array(view, "c"), get_array(df, "c"))
1301+
assert view.iloc[0, 0] == 5
1302+
1303+
12581304
def test_asfreq_noop(using_copy_on_write):
12591305
df = DataFrame(
12601306
{"a": [0.0, None, 2.0, 3.0]},

0 commit comments

Comments
 (0)