Skip to content

ENH: Add lazy copy to where #51336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ Copy-on-Write improvements
- :meth:`DataFrame.interpolate` / :meth:`Series.interpolate`
- :meth:`DataFrame.ffill` / :meth:`Series.ffill`
- :meth:`DataFrame.bfill` / :meth:`Series.bfill`
- :meth:`DataFrame.where` / :meth:`Series.where`
- :meth:`DataFrame.infer_objects` / :meth:`Series.infer_objects`
- :meth:`DataFrame.astype` / :meth:`Series.astype`
- :meth:`DataFrame.convert_dtypes` / :meth:`Series.convert_dtypes`
Expand Down
36 changes: 26 additions & 10 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,7 +1070,9 @@ def putmask(self, mask, new, using_cow: bool = False) -> list[Block]:
res_blocks.extend(rbs)
return res_blocks

def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
def where(
self, other, cond, _downcast: str | bool = "infer", using_cow: bool = False
) -> list[Block]:
"""
evaluate the block; return result block(s) from the result

Expand Down Expand Up @@ -1101,6 +1103,8 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
icond, noop = validate_putmask(values, ~cond)
if noop:
# GH-39595: Always return a copy; short-circuit up/downcasting
if using_cow:
return [self.copy(deep=False)]
return [self.copy()]

if other is lib.no_default:
Expand All @@ -1120,8 +1124,10 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
# no need to split columns

block = self.coerce_to_target_dtype(other)
blocks = block.where(orig_other, cond)
return self._maybe_downcast(blocks, downcast=_downcast)
blocks = block.where(orig_other, cond, using_cow=using_cow)
return self._maybe_downcast(
blocks, downcast=_downcast, using_cow=using_cow
)

else:
# since _maybe_downcast would split blocks anyway, we
Expand All @@ -1138,7 +1144,9 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
oth = other[:, i : i + 1]

submask = cond[:, i : i + 1]
rbs = nb.where(oth, submask, _downcast=_downcast)
rbs = nb.where(
oth, submask, _downcast=_downcast, using_cow=using_cow
)
res_blocks.extend(rbs)
return res_blocks

Expand Down Expand Up @@ -1527,7 +1535,9 @@ def setitem(self, indexer, value, using_cow: bool = False):
else:
return self

def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
def where(
self, other, cond, _downcast: str | bool = "infer", using_cow: bool = False
) -> list[Block]:
# _downcast private bc we only specify it when calling from fillna
arr = self.values.T

Expand All @@ -1545,6 +1555,8 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
if noop:
# GH#44181, GH#45135
# Avoid a) raising for Interval/PeriodDtype and b) unnecessary object upcast
if using_cow:
return [self.copy(deep=False)]
return [self.copy()]

try:
Expand All @@ -1556,15 +1568,19 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
if is_interval_dtype(self.dtype):
# TestSetitemFloatIntervalWithIntIntervalValues
blk = self.coerce_to_target_dtype(orig_other)
nbs = blk.where(orig_other, orig_cond)
return self._maybe_downcast(nbs, downcast=_downcast)
nbs = blk.where(orig_other, orig_cond, using_cow=using_cow)
return self._maybe_downcast(
nbs, downcast=_downcast, using_cow=using_cow
)

elif isinstance(self, NDArrayBackedExtensionBlock):
# NB: not (yet) the same as
# isinstance(values, NDArrayBackedExtensionArray)
blk = self.coerce_to_target_dtype(orig_other)
nbs = blk.where(orig_other, orig_cond)
return self._maybe_downcast(nbs, downcast=_downcast)
nbs = blk.where(orig_other, orig_cond, using_cow=using_cow)
return self._maybe_downcast(
nbs, downcast=_downcast, using_cow=using_cow
)

else:
raise
Expand All @@ -1582,7 +1598,7 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
n = orig_other[:, i : i + 1]

submask = orig_cond[:, i : i + 1]
rbs = nb.where(n, submask)
rbs = nb.where(n, submask, using_cow=using_cow)
res_blocks.extend(rbs)
return res_blocks

Expand Down
1 change: 1 addition & 0 deletions pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def where(self: T, other, cond, align: bool) -> T:
align_keys=align_keys,
other=other,
cond=cond,
using_cow=using_copy_on_write(),
)

def setitem(self: T, indexer, value) -> T:
Expand Down
48 changes: 48 additions & 0 deletions pandas/tests/copy_view/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,6 +1309,54 @@ def test_putmask_dont_copy_some_blocks(using_copy_on_write, val, exp):
assert view.iloc[0, 0] == 5


@pytest.mark.parametrize("dtype", ["int64", "Int64"])
def test_where_noop(using_copy_on_write, dtype):
    # Every value in the Series is > 0, so the condition is all-True and
    # ``where`` has nothing to replace.
    original = Series([1, 2, 3], dtype=dtype)
    snapshot = original.copy()

    unchanged = original.where(original > 0, 10)

    # The result is expected to share its buffer with the input only when
    # the copy-on-write fixture is enabled.
    shared = np.shares_memory(get_array(original), get_array(unchanged))
    if using_copy_on_write:
        assert shared
    else:
        assert not shared

    # Writing into the result must never leak back into the input.
    unchanged.iloc[0] = 10
    if using_copy_on_write:
        assert not np.shares_memory(get_array(original), get_array(unchanged))
    tm.assert_series_equal(original, snapshot)


@pytest.mark.parametrize("dtype", ["int64", "Int64"])
def test_where(using_copy_on_write, dtype):
    # No value in the Series is < 0, so the condition is all-False and every
    # element is replaced by 10.
    original = Series([1, 2, 3], dtype=dtype)
    snapshot = original.copy()

    replaced = original.where(original < 0, 10)

    # The same expectations apply with and without copy-on-write: the result
    # has its own buffer and the input is left untouched.
    source_buf = get_array(original)
    result_buf = get_array(replaced)
    assert not np.shares_memory(source_buf, result_buf)
    tm.assert_series_equal(original, snapshot)


@pytest.mark.parametrize("dtype, val", [("int64", 10.5), ("Int64", 10)])
def test_where_noop_on_single_column(using_copy_on_write, dtype, val):
    # Column "a" is all positive (every value fails ``< 0`` and is replaced),
    # column "b" is all negative (the condition holds everywhere, so nothing
    # in it is replaced).
    frame = DataFrame({"a": [1, 2, 3], "b": [-4, -5, -6]}, dtype=dtype)
    snapshot = frame.copy()

    masked = frame.where(frame < 0, val)

    b_shared = np.shares_memory(get_array(frame, "b"), get_array(masked, "b"))
    if not using_copy_on_write:
        assert not b_shared
    else:
        # Only the untouched column is expected to keep sharing memory.
        assert b_shared
        assert not np.shares_memory(get_array(frame, "a"), get_array(masked, "a"))

    # Writing into column "b" of the result must not modify the input frame.
    masked.iloc[0, 1] = 10
    if using_copy_on_write:
        assert not np.shares_memory(get_array(frame, "b"), get_array(masked, "b"))
    tm.assert_frame_equal(frame, snapshot)


def test_asfreq_noop(using_copy_on_write):
df = DataFrame(
{"a": [0.0, None, 2.0, 3.0]},
Expand Down
5 changes: 4 additions & 1 deletion pandas/tests/copy_view/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ def get_array(obj, col=None):
this is done by some other operation).
"""
if isinstance(obj, Series) and (col is None or obj.name == col):
return obj._values
arr = obj._values
if isinstance(arr, BaseMaskedArray):
return arr._data
return arr
assert col is not None
icol = obj.columns.get_loc(col)
assert isinstance(icol, int)
Expand Down