Skip to content

Commit 608c2d8

Browse files
Backport PR #51411 on branch 2.0.x (ENH: Optimize CoW for fillna with ea dtypes) (#51639)
Backport PR #51411: ENH: Optimize CoW for fillna with ea dtypes Co-authored-by: Patrick Hoefler <[email protected]>
1 parent 7cc1791 commit 608c2d8

File tree

10 files changed

+98
-45
lines changed

10 files changed

+98
-45
lines changed

pandas/core/internals/blocks.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -1758,9 +1758,14 @@ def fillna(
17581758
downcast=downcast,
17591759
using_cow=using_cow,
17601760
)
1761-
new_values = self.values.fillna(value=value, method=None, limit=limit)
1762-
nb = self.make_block_same_class(new_values)
1763-
return nb._maybe_downcast([nb], downcast)
1761+
if using_cow and self._can_hold_na and not self.values._hasna:
1762+
refs = self.refs
1763+
new_values = self.values
1764+
else:
1765+
refs = None
1766+
new_values = self.values.fillna(value=value, method=None, limit=limit)
1767+
nb = self.make_block_same_class(new_values, refs=refs)
1768+
return nb._maybe_downcast([nb], downcast, using_cow=using_cow)
17641769

17651770
@cache_readonly
17661771
def shape(self) -> Shape:

pandas/tests/copy_view/test_interp_fillna.py

+46
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from pandas import (
5+
NA,
56
DataFrame,
67
Interval,
78
NaT,
@@ -271,3 +272,48 @@ def test_fillna_series_empty_arg_inplace(using_copy_on_write):
271272
assert np.shares_memory(get_array(ser), arr)
272273
if using_copy_on_write:
273274
assert ser._mgr._has_no_reference(0)
275+
276+
277+
def test_fillna_ea_noop_shares_memory(
278+
using_copy_on_write, any_numeric_ea_and_arrow_dtype
279+
):
280+
df = DataFrame({"a": [1, NA, 3], "b": 1}, dtype=any_numeric_ea_and_arrow_dtype)
281+
df_orig = df.copy()
282+
df2 = df.fillna(100)
283+
284+
assert not np.shares_memory(get_array(df, "a"), get_array(df2, "a"))
285+
286+
if using_copy_on_write:
287+
assert np.shares_memory(get_array(df, "b"), get_array(df2, "b"))
288+
assert not df2._mgr._has_no_reference(1)
289+
else:
290+
assert not np.shares_memory(get_array(df, "b"), get_array(df2, "b"))
291+
292+
tm.assert_frame_equal(df_orig, df)
293+
294+
df2.iloc[0, 1] = 100
295+
if using_copy_on_write:
296+
assert not np.shares_memory(get_array(df, "b"), get_array(df2, "b"))
297+
assert df2._mgr._has_no_reference(1)
298+
assert df._mgr._has_no_reference(1)
299+
tm.assert_frame_equal(df_orig, df)
300+
301+
302+
def test_fillna_inplace_ea_noop_shares_memory(
303+
using_copy_on_write, any_numeric_ea_and_arrow_dtype
304+
):
305+
df = DataFrame({"a": [1, NA, 3], "b": 1}, dtype=any_numeric_ea_and_arrow_dtype)
306+
df_orig = df.copy()
307+
view = df[:]
308+
df.fillna(100, inplace=True)
309+
310+
assert not np.shares_memory(get_array(df, "a"), get_array(view, "a"))
311+
312+
if using_copy_on_write:
313+
assert np.shares_memory(get_array(df, "b"), get_array(view, "b"))
314+
assert not df._mgr._has_no_reference(1)
315+
assert not view._mgr._has_no_reference(1)
316+
else:
317+
assert not np.shares_memory(get_array(df, "b"), get_array(view, "b"))
318+
df.iloc[0, 1] = 100
319+
tm.assert_frame_equal(df_orig, view)

pandas/tests/extension/base/methods.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -243,24 +243,25 @@ def test_factorize_empty(self, data):
243243
def test_fillna_copy_frame(self, data_missing):
244244
arr = data_missing.take([1, 1])
245245
df = pd.DataFrame({"A": arr})
246+
df_orig = df.copy()
246247

247248
filled_val = df.iloc[0, 0]
248249
result = df.fillna(filled_val)
249250

250-
assert df.A.values is not result.A.values
251+
result.iloc[0, 0] = filled_val
251252

252-
def test_fillna_copy_series(self, data_missing, no_op_with_cow: bool = False):
253+
self.assert_frame_equal(df, df_orig)
254+
255+
def test_fillna_copy_series(self, data_missing):
253256
arr = data_missing.take([1, 1])
254257
ser = pd.Series(arr)
258+
ser_orig = ser.copy()
255259

256260
filled_val = ser[0]
257261
result = ser.fillna(filled_val)
262+
result.iloc[0] = filled_val
258263

259-
if no_op_with_cow:
260-
assert ser._values is result._values
261-
else:
262-
assert ser._values is not result._values
263-
assert ser._values is arr
264+
self.assert_series_equal(ser, ser_orig)
264265

265266
def test_fillna_length_mismatch(self, data_missing):
266267
msg = "Length of 'value' does not match."

pandas/tests/extension/conftest.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
import pytest
44

5-
from pandas import Series
5+
from pandas import (
6+
Series,
7+
options,
8+
)
69

710

811
@pytest.fixture
@@ -193,3 +196,11 @@ def invalid_scalar(data):
193196
If the array can hold any item (i.e. object dtype), then use pytest.skip.
194197
"""
195198
return object.__new__(object)
199+
200+
201+
@pytest.fixture
202+
def using_copy_on_write() -> bool:
203+
"""
204+
Fixture to check if Copy-on-Write is enabled.
205+
"""
206+
return options.mode.copy_on_write and options.mode.data_manager == "block"

pandas/tests/extension/json/test_json.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -301,17 +301,9 @@ def test_searchsorted(self, data_for_sorting):
301301
def test_equals(self, data, na_value, as_series):
302302
super().test_equals(data, na_value, as_series)
303303

304-
def test_fillna_copy_frame(self, data_missing, using_copy_on_write):
305-
arr = data_missing.take([1, 1])
306-
df = pd.DataFrame({"A": arr})
307-
308-
filled_val = df.iloc[0, 0]
309-
result = df.fillna(filled_val)
310-
311-
if using_copy_on_write:
312-
assert df.A.values is result.A.values
313-
else:
314-
assert df.A.values is not result.A.values
304+
@pytest.mark.skip("fill-value is interpreted as a dict of values")
305+
def test_fillna_copy_frame(self, data_missing):
306+
super().test_fillna_copy_frame(data_missing)
315307

316308

317309
class TestCasting(BaseJSON, base.BaseCastingTests):

pandas/tests/extension/test_datetime.py

-5
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,6 @@ def test_combine_add(self, data_repeated):
116116
# Timestamp.__add__(Timestamp) not defined
117117
pass
118118

119-
def test_fillna_copy_series(self, data_missing, using_copy_on_write):
120-
super().test_fillna_copy_series(
121-
data_missing, no_op_with_cow=using_copy_on_write
122-
)
123-
124119

125120
class TestInterface(BaseDatetimeTests, base.BaseInterfaceTests):
126121
pass

pandas/tests/extension/test_interval.py

-5
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,6 @@ def test_combine_add(self, data_repeated):
132132
def test_fillna_length_mismatch(self, data_missing):
133133
super().test_fillna_length_mismatch(data_missing)
134134

135-
def test_fillna_copy_series(self, data_missing, using_copy_on_write):
136-
super().test_fillna_copy_series(
137-
data_missing, no_op_with_cow=using_copy_on_write
138-
)
139-
140135

141136
class TestMissing(BaseInterval, base.BaseMissingTests):
142137
# Index.fillna only accepts scalar `value`, so we have to xfail all

pandas/tests/extension/test_period.py

-5
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,6 @@ def test_diff(self, data, periods):
105105
else:
106106
super().test_diff(data, periods)
107107

108-
def test_fillna_copy_series(self, data_missing, using_copy_on_write):
109-
super().test_fillna_copy_series(
110-
data_missing, no_op_with_cow=using_copy_on_write
111-
)
112-
113108

114109
class TestInterface(BasePeriodTests, base.BaseInterfaceTests):
115110
pass

pandas/tests/extension/test_sparse.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -272,25 +272,32 @@ def test_fillna_frame(self, data_missing):
272272
class TestMethods(BaseSparseTests, base.BaseMethodsTests):
273273
_combine_le_expected_dtype = "Sparse[bool]"
274274

275-
def test_fillna_copy_frame(self, data_missing):
275+
def test_fillna_copy_frame(self, data_missing, using_copy_on_write):
276276
arr = data_missing.take([1, 1])
277277
df = pd.DataFrame({"A": arr}, copy=False)
278278

279279
filled_val = df.iloc[0, 0]
280280
result = df.fillna(filled_val)
281281

282282
if hasattr(df._mgr, "blocks"):
283-
assert df.values.base is not result.values.base
283+
if using_copy_on_write:
284+
assert df.values.base is result.values.base
285+
else:
286+
assert df.values.base is not result.values.base
284287
assert df.A._values.to_dense() is arr.to_dense()
285288

286-
def test_fillna_copy_series(self, data_missing):
289+
def test_fillna_copy_series(self, data_missing, using_copy_on_write):
287290
arr = data_missing.take([1, 1])
288291
ser = pd.Series(arr)
289292

290293
filled_val = ser[0]
291294
result = ser.fillna(filled_val)
292295

293-
assert ser._values is not result._values
296+
if using_copy_on_write:
297+
assert ser._values is result._values
298+
299+
else:
300+
assert ser._values is not result._values
294301
assert ser._values.to_dense() is arr.to_dense()
295302

296303
@pytest.mark.xfail(reason="Not Applicable")

pandas/tests/groupby/test_raises.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,9 @@ def test_groupby_raises_datetime_np(how, by, groupby_series, groupby_func_np):
303303

304304

305305
@pytest.mark.parametrize("how", ["method", "agg", "transform"])
306-
def test_groupby_raises_category(how, by, groupby_series, groupby_func):
306+
def test_groupby_raises_category(
307+
how, by, groupby_series, groupby_func, using_copy_on_write
308+
):
307309
# GH#50749
308310
df = DataFrame(
309311
{
@@ -370,7 +372,9 @@ def test_groupby_raises_category(how, by, groupby_series, groupby_func):
370372
TypeError,
371373
r"Cannot setitem on a Categorical with a new category \(0\), "
372374
+ "set the categories first",
373-
),
375+
)
376+
if not using_copy_on_write
377+
else (None, ""), # no-op with CoW
374378
"first": (None, ""),
375379
"idxmax": (None, ""),
376380
"idxmin": (None, ""),
@@ -491,7 +495,7 @@ def test_groupby_raises_category_np(how, by, groupby_series, groupby_func_np):
491495

492496
@pytest.mark.parametrize("how", ["method", "agg", "transform"])
493497
def test_groupby_raises_category_on_category(
494-
how, by, groupby_series, groupby_func, observed
498+
how, by, groupby_series, groupby_func, observed, using_copy_on_write
495499
):
496500
# GH#50749
497501
df = DataFrame(
@@ -562,7 +566,9 @@ def test_groupby_raises_category_on_category(
562566
TypeError,
563567
r"Cannot setitem on a Categorical with a new category \(0\), "
564568
+ "set the categories first",
565-
),
569+
)
570+
if not using_copy_on_write
571+
else (None, ""), # no-op with CoW
566572
"first": (None, ""),
567573
"idxmax": (ValueError, "attempt to get argmax of an empty sequence")
568574
if empty_groups

0 commit comments

Comments
 (0)