Skip to content

Commit c74d057

Browse files
authored
ENH: Avoid copying unnecessary columns in setitem by splitting blocks for CoW (#51031)
1 parent 2d61ec3 commit c74d057

File tree

3 files changed

+60
-11
lines changed

3 files changed

+60
-11
lines changed

pandas/core/internals/blocks.py

+3
Original file line numberDiff line numberDiff line change
@@ -1635,6 +1635,9 @@ def delete(self, loc) -> list[Block]:
16351635
values = self.values.delete(loc)
16361636
mgr_locs = self._mgr_locs.delete(loc)
16371637
return [type(self)(values, placement=mgr_locs, ndim=self.ndim)]
1638+
elif self.values.ndim == 1:
1639+
# We get here through to_stata
1640+
return []
16381641
return super().delete(loc)
16391642

16401643
@cache_readonly

pandas/core/internals/managers.py

+21-11
Original file line numberDiff line numberDiff line change
@@ -1246,7 +1246,10 @@ def value_getitem(placement):
12461246
self._known_consolidated = False
12471247

12481248
def _iset_split_block(
1249-
self, blkno_l: int, blk_locs: np.ndarray, value: ArrayLike | None = None
1249+
self,
1250+
blkno_l: int,
1251+
blk_locs: np.ndarray | list[int],
1252+
value: ArrayLike | None = None,
12501253
) -> None:
12511254
"""Removes columns from a block by splitting the block.
12521255
@@ -1267,12 +1270,8 @@ def _iset_split_block(
12671270

12681271
nbs_tup = tuple(blk.delete(blk_locs))
12691272
if value is not None:
1270-
# Argument 1 to "BlockPlacement" has incompatible type "BlockPlacement";
1271-
# expected "Union[int, slice, ndarray[Any, Any]]"
1272-
first_nb = new_block_2d(
1273-
value,
1274-
BlockPlacement(blk.mgr_locs[blk_locs]), # type: ignore[arg-type]
1275-
)
1273+
locs = blk.mgr_locs.as_array[blk_locs]
1274+
first_nb = new_block_2d(value, BlockPlacement(locs))
12761275
else:
12771276
first_nb = nbs_tup[0]
12781277
nbs_tup = tuple(nbs_tup[1:])
@@ -1283,6 +1282,10 @@ def _iset_split_block(
12831282
)
12841283
self.blocks = blocks_tup
12851284

1285+
if not nbs_tup and value is not None:
1286+
# No need to update anything if split did not happen
1287+
return
1288+
12861289
self._blklocs[first_nb.mgr_locs.indexer] = np.arange(len(first_nb))
12871290

12881291
for i, nb in enumerate(nbs_tup):
@@ -1326,11 +1329,18 @@ def column_setitem(
13261329
intermediate Series at the DataFrame level (`s = df[loc]; s[idx] = value`)
13271330
"""
13281331
if using_copy_on_write() and not self._has_no_reference(loc):
1329-
# otherwise perform Copy-on-Write and clear the reference
13301332
blkno = self.blknos[loc]
1331-
blocks = list(self.blocks)
1332-
blocks[blkno] = blocks[blkno].copy()
1333-
self.blocks = tuple(blocks)
1333+
# Split blocks to only copy the column we want to modify
1334+
blk_loc = self.blklocs[loc]
1335+
# Copy our values
1336+
values = self.blocks[blkno].values
1337+
if values.ndim == 1:
1338+
values = values.copy()
1339+
else:
1340+
# Use [blk_loc] as indexer to keep ndim=2, this already results in a
1341+
# copy
1342+
values = values[[blk_loc]]
1343+
self._iset_split_block(blkno, [blk_loc], values)
13341344

13351345
# this manager is only created temporarily to mutate the values in place
13361346
# so don't track references, otherwise the `setitem` would perform CoW again

pandas/tests/copy_view/test_indexing.py

+36
Original file line numberDiff line numberDiff line change
@@ -883,3 +883,39 @@ def test_dataframe_add_column_from_series():
883883
df.loc[2, "new"] = 100
884884
expected_s = Series([0, 11, 12])
885885
tm.assert_series_equal(s, expected_s)
886+
887+
888+
@pytest.mark.parametrize("val", [100, "a"])
889+
@pytest.mark.parametrize(
890+
"indexer_func, indexer",
891+
[
892+
(tm.loc, (0, "a")),
893+
(tm.iloc, (0, 0)),
894+
(tm.loc, ([0], "a")),
895+
(tm.iloc, ([0], 0)),
896+
(tm.loc, (slice(None), "a")),
897+
(tm.iloc, (slice(None), 0)),
898+
],
899+
)
900+
def test_set_value_copy_only_necessary_column(
901+
using_copy_on_write, indexer_func, indexer, val
902+
):
903+
# When setting inplace, only copy column that is modified instead of the whole
904+
# block (by splitting the block)
905+
# TODO multi-block only for now
906+
df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]})
907+
df_orig = df.copy()
908+
view = df[:]
909+
910+
indexer_func(df)[indexer] = val
911+
912+
if using_copy_on_write:
913+
assert np.shares_memory(get_array(df, "b"), get_array(view, "b"))
914+
assert not np.shares_memory(get_array(df, "a"), get_array(view, "a"))
915+
tm.assert_frame_equal(view, df_orig)
916+
else:
917+
assert np.shares_memory(get_array(df, "c"), get_array(view, "c"))
918+
if val == "a":
919+
assert not np.shares_memory(get_array(df, "a"), get_array(view, "a"))
920+
else:
921+
assert np.shares_memory(get_array(df, "a"), get_array(view, "a"))

0 commit comments

Comments
 (0)