Skip to content

Commit f07e98b

Browse files
authored
ENH: Implement CoW for convert_dtypes (#51265)
1 parent ead9ced commit f07e98b

File tree

5 files changed

+46
-5
lines changed

5 files changed

+46
-5
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ Copy-on-Write improvements
225225
- :meth:`DataFrame.tz_convert` / :meth:`Series.tz_localize`
226226
- :meth:`DataFrame.infer_objects` / :meth:`Series.infer_objects`
227227
- :meth:`DataFrame.astype` / :meth:`Series.astype`
228+
- :meth:`DataFrame.convert_dtypes` / :meth:`Series.convert_dtypes`
228229
- :func:`concat`
229230

230231
These methods return views when Copy-on-Write is enabled, which provides a significant

pandas/core/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6673,7 +6673,7 @@ def convert_dtypes(
66736673
# https://github.com/python/mypy/issues/8354
66746674
return cast(NDFrameT, result)
66756675
else:
6676-
return self.copy()
6676+
return self.copy(deep=None)
66776677

66786678
# ----------------------------------------------------------------------
66796679
# Filling NA's

pandas/core/series.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5468,7 +5468,7 @@ def _convert_dtypes(
54685468
if infer_objects:
54695469
input_series = input_series.infer_objects()
54705470
if is_object_dtype(input_series):
5471-
input_series = input_series.copy()
5471+
input_series = input_series.copy(deep=None)
54725472

54735473
if convert_string or convert_integer or convert_boolean or convert_floating:
54745474
dtype_backend = get_option("mode.dtype_backend")
@@ -5483,7 +5483,7 @@ def _convert_dtypes(
54835483
)
54845484
result = input_series.astype(inferred_dtype)
54855485
else:
5486-
result = input_series.copy()
5486+
result = input_series.copy(deep=None)
54875487
return result
54885488

54895489
# error: Cannot determine type of 'isna'

pandas/tests/copy_view/test_astype.py

+39
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,42 @@ def test_astype_arrow_timestamp(using_copy_on_write):
193193
if using_copy_on_write:
194194
assert not result._mgr._has_no_reference(0)
195195
assert np.shares_memory(get_array(df, "a").asi8, get_array(result, "a")._data)
196+
197+
198+
def test_convert_dtypes_infer_objects(using_copy_on_write):
199+
ser = Series(["a", "b", "c"])
200+
ser_orig = ser.copy()
201+
result = ser.convert_dtypes(
202+
convert_integer=False,
203+
convert_boolean=False,
204+
convert_floating=False,
205+
convert_string=False,
206+
)
207+
208+
if using_copy_on_write:
209+
assert np.shares_memory(get_array(ser), get_array(result))
210+
else:
211+
assert not np.shares_memory(get_array(ser), get_array(result))
212+
213+
result.iloc[0] = "x"
214+
tm.assert_series_equal(ser, ser_orig)
215+
216+
217+
def test_convert_dtypes(using_copy_on_write):
218+
df = DataFrame({"a": ["a", "b"], "b": [1, 2], "c": [1.5, 2.5], "d": [True, False]})
219+
df_orig = df.copy()
220+
df2 = df.convert_dtypes()
221+
222+
if using_copy_on_write:
223+
assert np.shares_memory(get_array(df2, "a"), get_array(df, "a"))
224+
assert np.shares_memory(get_array(df2, "d"), get_array(df, "d"))
225+
assert np.shares_memory(get_array(df2, "b"), get_array(df, "b"))
226+
assert np.shares_memory(get_array(df2, "c"), get_array(df, "c"))
227+
else:
228+
assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a"))
229+
assert not np.shares_memory(get_array(df2, "b"), get_array(df, "b"))
230+
assert not np.shares_memory(get_array(df2, "c"), get_array(df, "c"))
231+
assert not np.shares_memory(get_array(df2, "d"), get_array(df, "d"))
232+
233+
df2.iloc[0, 0] = "x"
234+
tm.assert_frame_equal(df, df_orig)

pandas/tests/copy_view/util.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@
22
from pandas.core.arrays import BaseMaskedArray
33

44

5-
def get_array(obj, col):
5+
def get_array(obj, col=None):
66
"""
77
Helper method to get array for a DataFrame column or a Series.
88
99
Equivalent of df[col].values, but without going through normal getitem,
1010
which triggers tracking references / CoW (and we might be testing that
1111
this is done by some other operation).
1212
"""
13-
if isinstance(obj, Series) and obj.name == col:
13+
if isinstance(obj, Series) and (obj is None or obj.name == col):
1414
return obj._values
15+
assert col is not None
1516
icol = obj.columns.get_loc(col)
1617
assert isinstance(icol, int)
1718
arr = obj._get_column_array(icol)

0 commit comments

Comments
 (0)