Skip to content

Commit 07f6c4d

Browse files
phoflmroeschke
andauthored
CoW: Track references in unstack if there is no copy (#57487)
* CoW: Track references in unstack if there is no copy * Update * Update * Update --------- Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 95f911d commit 07f6c4d

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

pandas/core/reshape/reshape.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
factorize,
3636
unique,
3737
)
38+
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
3839
from pandas.core.arrays.categorical import factorize_from_iterable
3940
from pandas.core.construction import ensure_wrapped_if_datetimelike
4041
from pandas.core.frame import DataFrame
@@ -231,20 +232,31 @@ def arange_result(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.bool_]]:
231232
return new_values, mask.any(0)
232233
# TODO: in all tests we have mask.any(0).all(); can we rely on that?
233234

234-
def get_result(self, values, value_columns, fill_value) -> DataFrame:
235+
def get_result(self, obj, value_columns, fill_value) -> DataFrame:
236+
values = obj._values
235237
if values.ndim == 1:
236238
values = values[:, np.newaxis]
237239

238240
if value_columns is None and values.shape[1] != 1: # pragma: no cover
239241
raise ValueError("must pass column labels for multi-column data")
240242

241-
values, _ = self.get_new_values(values, fill_value)
243+
new_values, _ = self.get_new_values(values, fill_value)
242244
columns = self.get_new_columns(value_columns)
243245
index = self.new_index
244246

245-
return self.constructor(
246-
values, index=index, columns=columns, dtype=values.dtype
247+
result = self.constructor(
248+
new_values, index=index, columns=columns, dtype=new_values.dtype, copy=False
247249
)
250+
if isinstance(values, np.ndarray):
251+
base, new_base = values.base, new_values.base
252+
elif isinstance(values, NDArrayBackedExtensionArray):
253+
base, new_base = values._ndarray.base, new_values._ndarray.base
254+
else:
255+
base, new_base = 1, 2 # type: ignore[assignment]
256+
if base is new_base:
257+
# We can only get here if one of the dimensions is size 1
258+
result._mgr.add_references(obj._mgr)
259+
return result
248260

249261
def get_new_values(self, values, fill_value=None):
250262
if values.ndim == 1:
@@ -532,9 +544,7 @@ def unstack(
532544
unstacker = _Unstacker(
533545
obj.index, level=level, constructor=obj._constructor_expanddim, sort=sort
534546
)
535-
return unstacker.get_result(
536-
obj._values, value_columns=None, fill_value=fill_value
537-
)
547+
return unstacker.get_result(obj, value_columns=None, fill_value=fill_value)
538548

539549

540550
def _unstack_frame(
@@ -550,7 +560,7 @@ def _unstack_frame(
550560
return obj._constructor_from_mgr(mgr, axes=mgr.axes)
551561
else:
552562
return unstacker.get_result(
553-
obj._values, value_columns=obj.columns, fill_value=fill_value
563+
obj, value_columns=obj.columns, fill_value=fill_value
554564
)
555565

556566

pandas/tests/reshape/test_pivot.py

+13
Original file line numberDiff line numberDiff line change
@@ -2703,3 +2703,16 @@ def test_pivot_table_with_margins_and_numeric_column_names(self):
27032703
index=Index(["a", "b", "All"], name=0),
27042704
)
27052705
tm.assert_frame_equal(result, expected)
2706+
2707+
@pytest.mark.parametrize("m", [1, 10])
2708+
def test_unstack_shares_memory(self, m):
2709+
# GH#56633
2710+
levels = np.arange(m)
2711+
index = MultiIndex.from_product([levels] * 2)
2712+
values = np.arange(m * m * 100).reshape(m * m, 100)
2713+
df = DataFrame(values, index, np.arange(100))
2714+
df_orig = df.copy()
2715+
result = df.unstack(sort=False)
2716+
assert np.shares_memory(df._values, result._values) is (m == 1)
2717+
result.iloc[0, 0] = -1
2718+
tm.assert_frame_equal(df, df_orig)

0 commit comments

Comments
 (0)