Skip to content

Commit cefc6f8

Browse files
authored
ENH: Add lazy copy to where (#51336)
* ENH: Add lazy copy to where * Move tests
1 parent 82a595b commit cefc6f8

File tree

5 files changed

+80
-11
lines changed

5 files changed

+80
-11
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ Copy-on-Write improvements
227227
- :meth:`DataFrame.interpolate` / :meth:`Series.interpolate`
228228
- :meth:`DataFrame.ffill` / :meth:`Series.ffill`
229229
- :meth:`DataFrame.bfill` / :meth:`Series.bfill`
230+
- :meth:`DataFrame.where` / :meth:`Series.where`
230231
- :meth:`DataFrame.infer_objects` / :meth:`Series.infer_objects`
231232
- :meth:`DataFrame.astype` / :meth:`Series.astype`
232233
- :meth:`DataFrame.convert_dtypes` / :meth:`Series.convert_dtypes`

pandas/core/internals/blocks.py

+26-10
Original file line numberDiff line numberDiff line change
@@ -1071,7 +1071,9 @@ def putmask(self, mask, new, using_cow: bool = False) -> list[Block]:
10711071
res_blocks.extend(rbs)
10721072
return res_blocks
10731073

1074-
def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
1074+
def where(
1075+
self, other, cond, _downcast: str | bool = "infer", using_cow: bool = False
1076+
) -> list[Block]:
10751077
"""
10761078
evaluate the block; return result block(s) from the result
10771079
@@ -1102,6 +1104,8 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
11021104
icond, noop = validate_putmask(values, ~cond)
11031105
if noop:
11041106
# GH-39595: Always return a copy; short-circuit up/downcasting
1107+
if using_cow:
1108+
return [self.copy(deep=False)]
11051109
return [self.copy()]
11061110

11071111
if other is lib.no_default:
@@ -1121,8 +1125,10 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
11211125
# no need to split columns
11221126

11231127
block = self.coerce_to_target_dtype(other)
1124-
blocks = block.where(orig_other, cond)
1125-
return self._maybe_downcast(blocks, downcast=_downcast)
1128+
blocks = block.where(orig_other, cond, using_cow=using_cow)
1129+
return self._maybe_downcast(
1130+
blocks, downcast=_downcast, using_cow=using_cow
1131+
)
11261132

11271133
else:
11281134
# since _maybe_downcast would split blocks anyway, we
@@ -1139,7 +1145,9 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
11391145
oth = other[:, i : i + 1]
11401146

11411147
submask = cond[:, i : i + 1]
1142-
rbs = nb.where(oth, submask, _downcast=_downcast)
1148+
rbs = nb.where(
1149+
oth, submask, _downcast=_downcast, using_cow=using_cow
1150+
)
11431151
res_blocks.extend(rbs)
11441152
return res_blocks
11451153

@@ -1528,7 +1536,9 @@ def setitem(self, indexer, value, using_cow: bool = False):
15281536
else:
15291537
return self
15301538

1531-
def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
1539+
def where(
1540+
self, other, cond, _downcast: str | bool = "infer", using_cow: bool = False
1541+
) -> list[Block]:
15321542
# _downcast private bc we only specify it when calling from fillna
15331543
arr = self.values.T
15341544

@@ -1546,6 +1556,8 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
15461556
if noop:
15471557
# GH#44181, GH#45135
15481558
# Avoid a) raising for Interval/PeriodDtype and b) unnecessary object upcast
1559+
if using_cow:
1560+
return [self.copy(deep=False)]
15491561
return [self.copy()]
15501562

15511563
try:
@@ -1557,15 +1569,19 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
15571569
if is_interval_dtype(self.dtype):
15581570
# TestSetitemFloatIntervalWithIntIntervalValues
15591571
blk = self.coerce_to_target_dtype(orig_other)
1560-
nbs = blk.where(orig_other, orig_cond)
1561-
return self._maybe_downcast(nbs, downcast=_downcast)
1572+
nbs = blk.where(orig_other, orig_cond, using_cow=using_cow)
1573+
return self._maybe_downcast(
1574+
nbs, downcast=_downcast, using_cow=using_cow
1575+
)
15621576

15631577
elif isinstance(self, NDArrayBackedExtensionBlock):
15641578
# NB: not (yet) the same as
15651579
# isinstance(values, NDArrayBackedExtensionArray)
15661580
blk = self.coerce_to_target_dtype(orig_other)
1567-
nbs = blk.where(orig_other, orig_cond)
1568-
return self._maybe_downcast(nbs, downcast=_downcast)
1581+
nbs = blk.where(orig_other, orig_cond, using_cow=using_cow)
1582+
return self._maybe_downcast(
1583+
nbs, downcast=_downcast, using_cow=using_cow
1584+
)
15691585

15701586
else:
15711587
raise
@@ -1583,7 +1599,7 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
15831599
n = orig_other[:, i : i + 1]
15841600

15851601
submask = orig_cond[:, i : i + 1]
1586-
rbs = nb.where(n, submask)
1602+
rbs = nb.where(n, submask, using_cow=using_cow)
15871603
res_blocks.extend(rbs)
15881604
return res_blocks
15891605

pandas/core/internals/managers.py

+1
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ def where(self: T, other, cond, align: bool) -> T:
341341
align_keys=align_keys,
342342
other=other,
343343
cond=cond,
344+
using_cow=using_copy_on_write(),
344345
)
345346

346347
def setitem(self: T, indexer, value) -> T:

pandas/tests/copy_view/test_methods.py

+48
Original file line numberDiff line numberDiff line change
@@ -1309,6 +1309,54 @@ def test_putmask_dont_copy_some_blocks(using_copy_on_write, val, exp):
13091309
assert view.iloc[0, 0] == 5
13101310

13111311

1312+
@pytest.mark.parametrize("dtype", ["int64", "Int64"])
1313+
def test_where_noop(using_copy_on_write, dtype):
1314+
ser = Series([1, 2, 3], dtype=dtype)
1315+
ser_orig = ser.copy()
1316+
1317+
result = ser.where(ser > 0, 10)
1318+
1319+
if using_copy_on_write:
1320+
assert np.shares_memory(get_array(ser), get_array(result))
1321+
else:
1322+
assert not np.shares_memory(get_array(ser), get_array(result))
1323+
1324+
result.iloc[0] = 10
1325+
if using_copy_on_write:
1326+
assert not np.shares_memory(get_array(ser), get_array(result))
1327+
tm.assert_series_equal(ser, ser_orig)
1328+
1329+
1330+
@pytest.mark.parametrize("dtype", ["int64", "Int64"])
1331+
def test_where(using_copy_on_write, dtype):
1332+
ser = Series([1, 2, 3], dtype=dtype)
1333+
ser_orig = ser.copy()
1334+
1335+
result = ser.where(ser < 0, 10)
1336+
1337+
assert not np.shares_memory(get_array(ser), get_array(result))
1338+
tm.assert_series_equal(ser, ser_orig)
1339+
1340+
1341+
@pytest.mark.parametrize("dtype, val", [("int64", 10.5), ("Int64", 10)])
1342+
def test_where_noop_on_single_column(using_copy_on_write, dtype, val):
1343+
df = DataFrame({"a": [1, 2, 3], "b": [-4, -5, -6]}, dtype=dtype)
1344+
df_orig = df.copy()
1345+
1346+
result = df.where(df < 0, val)
1347+
1348+
if using_copy_on_write:
1349+
assert np.shares_memory(get_array(df, "b"), get_array(result, "b"))
1350+
assert not np.shares_memory(get_array(df, "a"), get_array(result, "a"))
1351+
else:
1352+
assert not np.shares_memory(get_array(df, "b"), get_array(result, "b"))
1353+
1354+
result.iloc[0, 1] = 10
1355+
if using_copy_on_write:
1356+
assert not np.shares_memory(get_array(df, "b"), get_array(result, "b"))
1357+
tm.assert_frame_equal(df, df_orig)
1358+
1359+
13121360
def test_asfreq_noop(using_copy_on_write):
13131361
df = DataFrame(
13141362
{"a": [0.0, None, 2.0, 3.0]},

pandas/tests/copy_view/util.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ def get_array(obj, col=None):
1111
this is done by some other operation).
1212
"""
1313
if isinstance(obj, Series) and (col is None or obj.name == col):
14-
return obj._values
14+
arr = obj._values
15+
if isinstance(arr, BaseMaskedArray):
16+
return arr._data
17+
return arr
1518
assert col is not None
1619
icol = obj.columns.get_loc(col)
1720
assert isinstance(icol, int)

0 commit comments

Comments
 (0)