Skip to content

Commit d744bdb

Browse files
authored
ENH: Add CoW optimization for fillna (#51279)
1 parent ddceb8e commit d744bdb

File tree

8 files changed: +124 -16 lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ Copy-on-Write improvements
223223
- :meth:`DataFrame.to_period` / :meth:`Series.to_period`
224224
- :meth:`DataFrame.truncate`
225225
- :meth:`DataFrame.tz_convert` / :meth:`Series.tz_localize`
226+
- :meth:`DataFrame.fillna` / :meth:`Series.fillna`
226227
- :meth:`DataFrame.interpolate` / :meth:`Series.interpolate`
227228
- :meth:`DataFrame.ffill` / :meth:`Series.ffill`
228229
- :meth:`DataFrame.bfill` / :meth:`Series.bfill`

pandas/core/internals/blocks.py

+27-7
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,7 @@ def putmask(self, mask, new, using_cow: bool = False) -> list[Block]:
10141014
----------
10151015
mask : np.ndarray[bool], SparseArray[bool], or BooleanArray
10161016
new : a ndarray/object
1017+
using_cow : bool, default False
10171018
10181019
Returns
10191020
-------
@@ -1188,7 +1189,12 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
11881189
return [self.make_block(result)]
11891190

11901191
def fillna(
1191-
self, value, limit: int | None = None, inplace: bool = False, downcast=None
1192+
self,
1193+
value,
1194+
limit: int | None = None,
1195+
inplace: bool = False,
1196+
downcast=None,
1197+
using_cow: bool = False,
11921198
) -> list[Block]:
11931199
"""
11941200
fillna on the block with the value. If we fail, then convert to
@@ -1207,20 +1213,22 @@ def fillna(
12071213
if noop:
12081214
# we can't process the value, but nothing to do
12091215
if inplace:
1216+
if using_cow:
1217+
return [self.copy(deep=False)]
12101218
# Arbitrarily imposing the convention that we ignore downcast
12111219
# on no-op when inplace=True
12121220
return [self]
12131221
else:
12141222
# GH#45423 consistent downcasting on no-ops.
1215-
nb = self.copy()
1216-
nbs = nb._maybe_downcast([nb], downcast=downcast)
1223+
nb = self.copy(deep=not using_cow)
1224+
nbs = nb._maybe_downcast([nb], downcast=downcast, using_cow=using_cow)
12171225
return nbs
12181226

12191227
if limit is not None:
12201228
mask[mask.cumsum(self.ndim - 1) > limit] = False
12211229

12221230
if inplace:
1223-
nbs = self.putmask(mask.T, value)
1231+
nbs = self.putmask(mask.T, value, using_cow=using_cow)
12241232
else:
12251233
# without _downcast, we would break
12261234
# test_fillna_dtype_conversion_equiv_replace
@@ -1230,7 +1238,10 @@ def fillna(
12301238
# makes a difference bc blk may have object dtype, which has
12311239
# different behavior in _maybe_downcast.
12321240
return extend_blocks(
1233-
[blk._maybe_downcast([blk], downcast=downcast) for blk in nbs]
1241+
[
1242+
blk._maybe_downcast([blk], downcast=downcast, using_cow=using_cow)
1243+
for blk in nbs
1244+
]
12341245
)
12351246

12361247
def interpolate(
@@ -1725,12 +1736,21 @@ class ExtensionBlock(libinternals.Block, EABackedBlock):
17251736
values: ExtensionArray
17261737

17271738
def fillna(
1728-
self, value, limit: int | None = None, inplace: bool = False, downcast=None
1739+
self,
1740+
value,
1741+
limit: int | None = None,
1742+
inplace: bool = False,
1743+
downcast=None,
1744+
using_cow: bool = False,
17291745
) -> list[Block]:
17301746
if is_interval_dtype(self.dtype):
17311747
# Block.fillna handles coercion (test_fillna_interval)
17321748
return super().fillna(
1733-
value=value, limit=limit, inplace=inplace, downcast=downcast
1749+
value=value,
1750+
limit=limit,
1751+
inplace=inplace,
1752+
downcast=downcast,
1753+
using_cow=using_cow,
17341754
)
17351755
new_values = self.values.fillna(value=value, method=None, limit=limit)
17361756
nb = self.make_block_same_class(new_values)

pandas/core/internals/managers.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -398,15 +398,14 @@ def fillna(self: T, value, limit, inplace: bool, downcast) -> T:
398398
if limit is not None:
399399
# Do this validation even if we go through one of the no-op paths
400400
limit = libalgos.validate_limit(None, limit=limit)
401-
if inplace:
402-
# TODO(CoW) can be optimized to only copy those blocks that have refs
403-
if using_copy_on_write() and any(
404-
not self._has_no_reference_block(i) for i in range(len(self.blocks))
405-
):
406-
self = self.copy()
407401

408402
return self.apply(
409-
"fillna", value=value, limit=limit, inplace=inplace, downcast=downcast
403+
"fillna",
404+
value=value,
405+
limit=limit,
406+
inplace=inplace,
407+
downcast=downcast,
408+
using_cow=using_copy_on_write(),
410409
)
411410

412411
def astype(self: T, dtype, copy: bool | None = False, errors: str = "raise") -> T:

pandas/tests/copy_view/test_interp_fillna.py

+70
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33

44
from pandas import (
55
DataFrame,
6+
Interval,
67
NaT,
78
Series,
89
Timestamp,
10+
interval_range,
911
)
1012
import pandas._testing as tm
1113
from pandas.tests.copy_view.util import get_array
@@ -162,3 +164,71 @@ def test_interpolate_downcast_reference_triggers_copy(using_copy_on_write):
162164
tm.assert_frame_equal(df_orig, view)
163165
else:
164166
tm.assert_frame_equal(df, view)
167+
168+
169+
def test_fillna(using_copy_on_write):
170+
df = DataFrame({"a": [1.5, np.nan], "b": 1})
171+
df_orig = df.copy()
172+
173+
df2 = df.fillna(5.5)
174+
if using_copy_on_write:
175+
assert np.shares_memory(get_array(df, "b"), get_array(df2, "b"))
176+
else:
177+
assert not np.shares_memory(get_array(df, "b"), get_array(df2, "b"))
178+
179+
df2.iloc[0, 1] = 100
180+
tm.assert_frame_equal(df_orig, df)
181+
182+
183+
@pytest.mark.parametrize("downcast", [None, False])
184+
def test_fillna_inplace(using_copy_on_write, downcast):
185+
df = DataFrame({"a": [1.5, np.nan], "b": 1})
186+
arr_a = get_array(df, "a")
187+
arr_b = get_array(df, "b")
188+
189+
df.fillna(5.5, inplace=True, downcast=downcast)
190+
assert np.shares_memory(get_array(df, "a"), arr_a)
191+
assert np.shares_memory(get_array(df, "b"), arr_b)
192+
if using_copy_on_write:
193+
assert df._mgr._has_no_reference(0)
194+
assert df._mgr._has_no_reference(1)
195+
196+
197+
def test_fillna_inplace_reference(using_copy_on_write):
198+
df = DataFrame({"a": [1.5, np.nan], "b": 1})
199+
df_orig = df.copy()
200+
arr_a = get_array(df, "a")
201+
arr_b = get_array(df, "b")
202+
view = df[:]
203+
204+
df.fillna(5.5, inplace=True)
205+
if using_copy_on_write:
206+
assert not np.shares_memory(get_array(df, "a"), arr_a)
207+
assert np.shares_memory(get_array(df, "b"), arr_b)
208+
assert view._mgr._has_no_reference(0)
209+
assert df._mgr._has_no_reference(0)
210+
tm.assert_frame_equal(view, df_orig)
211+
else:
212+
assert np.shares_memory(get_array(df, "a"), arr_a)
213+
assert np.shares_memory(get_array(df, "b"), arr_b)
214+
expected = DataFrame({"a": [1.5, 5.5], "b": 1})
215+
tm.assert_frame_equal(df, expected)
216+
217+
218+
def test_fillna_interval_inplace_reference(using_copy_on_write):
219+
ser = Series(interval_range(start=0, end=5), name="a")
220+
ser.iloc[1] = np.nan
221+
222+
ser_orig = ser.copy()
223+
view = ser[:]
224+
ser.fillna(value=Interval(left=0, right=5), inplace=True)
225+
226+
if using_copy_on_write:
227+
assert not np.shares_memory(
228+
get_array(ser, "a").left.values, get_array(view, "a").left.values
229+
)
230+
tm.assert_series_equal(view, ser_orig)
231+
else:
232+
assert np.shares_memory(
233+
get_array(ser, "a").left.values, get_array(view, "a").left.values
234+
)

pandas/tests/extension/base/methods.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -249,14 +249,17 @@ def test_fillna_copy_frame(self, data_missing):
249249

250250
assert df.A.values is not result.A.values
251251

252-
def test_fillna_copy_series(self, data_missing):
252+
def test_fillna_copy_series(self, data_missing, no_op_with_cow: bool = False):
253253
arr = data_missing.take([1, 1])
254254
ser = pd.Series(arr)
255255

256256
filled_val = ser[0]
257257
result = ser.fillna(filled_val)
258258

259-
assert ser._values is not result._values
259+
if no_op_with_cow:
260+
assert ser._values is result._values
261+
else:
262+
assert ser._values is not result._values
260263
assert ser._values is arr
261264

262265
def test_fillna_length_mismatch(self, data_missing):

pandas/tests/extension/test_datetime.py

+5
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ def test_combine_add(self, data_repeated):
116116
# Timestamp.__add__(Timestamp) not defined
117117
pass
118118

119+
def test_fillna_copy_series(self, data_missing, using_copy_on_write):
120+
super().test_fillna_copy_series(
121+
data_missing, no_op_with_cow=using_copy_on_write
122+
)
123+
119124

120125
class TestInterface(BaseDatetimeTests, base.BaseInterfaceTests):
121126
pass

pandas/tests/extension/test_interval.py

+5
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ def test_combine_add(self, data_repeated):
132132
def test_fillna_length_mismatch(self, data_missing):
133133
super().test_fillna_length_mismatch(data_missing)
134134

135+
def test_fillna_copy_series(self, data_missing, using_copy_on_write):
136+
super().test_fillna_copy_series(
137+
data_missing, no_op_with_cow=using_copy_on_write
138+
)
139+
135140

136141
class TestMissing(BaseInterval, base.BaseMissingTests):
137142
# Index.fillna only accepts scalar `value`, so we have to xfail all

pandas/tests/extension/test_period.py

+5
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ def test_diff(self, data, periods):
105105
else:
106106
super().test_diff(data, periods)
107107

108+
def test_fillna_copy_series(self, data_missing, using_copy_on_write):
109+
super().test_fillna_copy_series(
110+
data_missing, no_op_with_cow=using_copy_on_write
111+
)
112+
108113

109114
class TestInterface(BasePeriodTests, base.BaseInterfaceTests):
110115

0 commit comments

Comments
 (0)