Skip to content

Commit 09593b2

Browse files
authored
CoW: __array__ not recognizing ea dtypes (#51966)
1 parent d54bd78 commit 09593b2

File tree

3 files changed

+77
-6
lines changed

3 files changed

+77
-6
lines changed

pandas/core/generic.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
validate_inclusive,
104104
)
105105

106+
from pandas.core.dtypes.astype import astype_is_view
106107
from pandas.core.dtypes.common import (
107108
ensure_object,
108109
ensure_platform_int,
@@ -2012,10 +2013,17 @@ def empty(self) -> bool_t:
20122013
def __array__(self, dtype: npt.DTypeLike | None = None) -> np.ndarray:
20132014
values = self._values
20142015
arr = np.asarray(values, dtype=dtype)
2015-
if arr is values and using_copy_on_write():
2016-
# TODO(CoW) also properly handle extension dtypes
2017-
arr = arr.view()
2018-
arr.flags.writeable = False
2016+
if (
2017+
astype_is_view(values.dtype, arr.dtype)
2018+
and using_copy_on_write()
2019+
and self._mgr.is_single_block
2020+
):
2021+
# Check if both conversions can be done without a copy
2022+
if astype_is_view(self.dtypes.iloc[0], values.dtype) and astype_is_view(
2023+
values.dtype, arr.dtype
2024+
):
2025+
arr = arr.view()
2026+
arr.flags.writeable = False
20192027
return arr
20202028

20212029
@final

pandas/core/series.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -921,8 +921,7 @@ def __array__(self, dtype: npt.DTypeLike | None = None) -> np.ndarray:
921921
"""
922922
values = self._values
923923
arr = np.asarray(values, dtype=dtype)
924-
if arr is values and using_copy_on_write():
925-
# TODO(CoW) also properly handle extension dtypes
924+
if using_copy_on_write() and astype_is_view(values.dtype, arr.dtype):
926925
arr = arr.view()
927926
arr.flags.writeable = False
928927
return arr

pandas/tests/copy_view/test_array.py

+64
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pandas import (
55
DataFrame,
66
Series,
7+
date_range,
78
)
89
import pandas._testing as tm
910
from pandas.tests.copy_view.util import get_array
@@ -119,3 +120,66 @@ def test_ravel_read_only(using_copy_on_write, order):
119120
if using_copy_on_write:
120121
assert arr.flags.writeable is False
121122
assert np.shares_memory(get_array(ser), arr)
123+
124+
125+
def test_series_array_ea_dtypes(using_copy_on_write):
126+
ser = Series([1, 2, 3], dtype="Int64")
127+
arr = np.asarray(ser, dtype="int64")
128+
assert np.shares_memory(arr, get_array(ser))
129+
if using_copy_on_write:
130+
assert arr.flags.writeable is False
131+
else:
132+
assert arr.flags.writeable is True
133+
134+
arr = np.asarray(ser)
135+
assert not np.shares_memory(arr, get_array(ser))
136+
assert arr.flags.writeable is True
137+
138+
139+
def test_dataframe_array_ea_dtypes(using_copy_on_write):
140+
df = DataFrame({"a": [1, 2, 3]}, dtype="Int64")
141+
arr = np.asarray(df, dtype="int64")
142+
# TODO: This should be able to share memory, but we are roundtripping
143+
# through object
144+
assert not np.shares_memory(arr, get_array(df, "a"))
145+
assert arr.flags.writeable is True
146+
147+
arr = np.asarray(df)
148+
if using_copy_on_write:
149+
# TODO(CoW): This should be True
150+
assert arr.flags.writeable is False
151+
else:
152+
assert arr.flags.writeable is True
153+
154+
155+
def test_dataframe_array_string_dtype(using_copy_on_write, using_array_manager):
156+
df = DataFrame({"a": ["a", "b"]}, dtype="string")
157+
arr = np.asarray(df)
158+
if not using_array_manager:
159+
assert np.shares_memory(arr, get_array(df, "a"))
160+
if using_copy_on_write:
161+
assert arr.flags.writeable is False
162+
else:
163+
assert arr.flags.writeable is True
164+
165+
166+
def test_dataframe_multiple_numpy_dtypes():
167+
df = DataFrame({"a": [1, 2, 3], "b": 1.5})
168+
arr = np.asarray(df)
169+
assert not np.shares_memory(arr, get_array(df, "a"))
170+
assert arr.flags.writeable is True
171+
172+
173+
def test_values_is_ea(using_copy_on_write):
174+
df = DataFrame({"a": date_range("2012-01-01", periods=3)})
175+
arr = np.asarray(df)
176+
if using_copy_on_write:
177+
assert arr.flags.writeable is False
178+
else:
179+
assert arr.flags.writeable is True
180+
181+
182+
def test_empty_dataframe():
183+
df = DataFrame()
184+
arr = np.asarray(df)
185+
assert arr.flags.writeable is True

0 commit comments

Comments
 (0)