diff --git a/pandas/_libs/internals.pyx b/pandas/_libs/internals.pyx index 277243d72c536..7323bdfc4c6d7 100644 --- a/pandas/_libs/internals.pyx +++ b/pandas/_libs/internals.pyx @@ -890,6 +890,16 @@ cdef class BlockValuesRefs: """ self.referenced_blocks.append(weakref.ref(blk)) + def add_index_reference(self, index: object) -> None: + """Adds a new reference to our reference collection when creating an index. + + Parameters + ---------- + index: object + The index that the new reference should point to. + """ + self.referenced_blocks.append(weakref.ref(index)) + def has_reference(self) -> bool: """Checks if block has foreign references. diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 786f7c7a11ed0..ac229c5f50d58 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5879,7 +5879,7 @@ def set_index( names.append(None) # from here, col can only be a column label else: - arrays.append(frame[col]._values) + arrays.append(frame[col]) names.append(col) if drop: to_remove.append(col) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index acebe8a498f03..ec97736090a01 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -73,7 +73,10 @@ rewrite_exception, ) -from pandas.core.dtypes.astype import astype_array +from pandas.core.dtypes.astype import ( + astype_array, + astype_is_view, +) from pandas.core.dtypes.cast import ( LossySetitemError, can_hold_element, @@ -457,6 +460,8 @@ def _engine_type( str = CachedAccessor("str", StringMethods) + _references = None + # -------------------------------------------------------------------- # Constructors @@ -477,6 +482,10 @@ def __new__( data_dtype = getattr(data, "dtype", None) + refs = None + if not copy and isinstance(data, (ABCSeries, Index)): + refs = data._references + # range if isinstance(data, (range, RangeIndex)): result = RangeIndex(start=data, copy=copy, name=name) @@ -550,7 +559,7 @@ def __new__( klass = cls._dtype_to_subclass(arr.dtype) arr = klass._ensure_array(arr, arr.dtype, copy=False) - return klass._simple_new(arr, name) + return klass._simple_new(arr, name, refs=refs) @classmethod def _ensure_array(cls, data, dtype, copy: bool): @@ -629,7 +638,7 @@ def _dtype_to_subclass(cls, dtype: DtypeObj): @classmethod def _simple_new( - cls: type[_IndexT], values: ArrayLike, name: Hashable = None + cls: type[_IndexT], values: ArrayLike, name: Hashable = None, refs=None ) -> _IndexT: """ We require that we have a dtype compat for the values. If we are passed @@ -644,6 +653,9 @@ def _simple_new( result._name = name result._cache = {} result._reset_identity() + result._references = refs + if refs is not None: + refs.add_index_reference(result) return result @@ -740,13 +752,13 @@ def _shallow_copy(self: _IndexT, values, name: Hashable = no_default) -> _IndexT """ name = self._name if name is no_default else name - return self._simple_new(values, name=name) + return self._simple_new(values, name=name, refs=self._references) def _view(self: _IndexT) -> _IndexT: """ fastpath to make a shallow copy, i.e. new object with same data. """ - result = self._simple_new(self._values, name=self._name) + result = self._simple_new(self._values, name=self._name, refs=self._references) result._cache = self._cache return result @@ -956,7 +968,7 @@ def view(self, cls=None): # of types. arr_cls = idx_cls._data_cls arr = arr_cls(self._data.view("i8"), dtype=dtype) - return idx_cls._simple_new(arr, name=self.name) + return idx_cls._simple_new(arr, name=self.name, refs=self._references) result = self._data.view(cls) else: @@ -1012,7 +1024,15 @@ def astype(self, dtype, copy: bool = True): new_values = astype_array(values, dtype=dtype, copy=copy) # pass copy=False because any copying will be done in the astype above - return Index(new_values, name=self.name, dtype=new_values.dtype, copy=False) + result = Index(new_values, name=self.name, dtype=new_values.dtype, copy=False) + if ( + not copy + and self._references is not None + and astype_is_view(self.dtype, dtype) + ): + result._references = self._references + result._references.add_index_reference(result) + return result _index_shared_docs[ "take" @@ -5155,7 +5175,9 @@ def __getitem__(self, key): # pessimization com.is_bool_indexer and ndim checks. result = getitem(key) # Going through simple_new for performance. - return type(self)._simple_new(result, name=self._name) + return type(self)._simple_new( + result, name=self._name, refs=self._references + ) if com.is_bool_indexer(key): # if we have list[bools, length=1e5] then doing this check+convert @@ -5181,7 +5203,7 @@ def _getitem_slice(self: _IndexT, slobj: slice) -> _IndexT: Fastpath for __getitem__ when we know we have a slice. """ res = self._data[slobj] - return type(self)._simple_new(res, name=self._name) + return type(self)._simple_new(res, name=self._name, refs=self._references) @final def _can_hold_identifiers_and_holds_name(self, name) -> bool: @@ -6700,7 +6722,11 @@ def infer_objects(self, copy: bool = True) -> Index: ) if copy and res_values is values: return self.copy() - return Index(res_values, name=self.name) + result = Index(res_values, name=self.name) + if not copy and res_values is values and self._references is not None: + result._references = self._references + result._references.add_index_reference(result) + return result # -------------------------------------------------------------------- # Generated Arithmetic, Comparison, and Unary Methods diff --git a/pandas/core/indexes/datetimes.py b/pandas/core/indexes/datetimes.py index 096e501c7bd6e..1d24af5293a9e 100644 --- a/pandas/core/indexes/datetimes.py +++ b/pandas/core/indexes/datetimes.py @@ -44,6 +44,7 @@ is_datetime64tz_dtype, is_scalar, ) +from pandas.core.dtypes.generic import ABCSeries from pandas.core.dtypes.missing import is_valid_na_for_dtype from pandas.core.arrays.datetimes import ( @@ -266,7 +267,7 @@ def strftime(self, date_format) -> Index: @doc(DatetimeArray.tz_convert) def tz_convert(self, tz) -> DatetimeIndex: arr = self._data.tz_convert(tz) - return type(self)._simple_new(arr, name=self.name) + return type(self)._simple_new(arr, name=self.name, refs=self._references) @doc(DatetimeArray.tz_localize) def tz_localize( @@ -345,8 +346,11 @@ def __new__( yearfirst=yearfirst, ambiguous=ambiguous, ) + refs = None + if not copy and isinstance(data, (Index, ABCSeries)): + refs = data._references - subarr = cls._simple_new(dtarr, name=name) + subarr = cls._simple_new(dtarr, name=name, refs=refs) return subarr # -------------------------------------------------------------------- diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index b381752818ba0..efb232eeeb22f 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -353,6 +353,7 @@ def __new__( result._codes = new_codes result._reset_identity() + result._references = None return result diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index 2d4f0736e30fa..eb898786e24c9 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -28,6 +28,7 @@ from pandas.core.dtypes.common import is_integer from pandas.core.dtypes.dtypes import PeriodDtype +from pandas.core.dtypes.generic import ABCSeries from pandas.core.dtypes.missing import is_valid_na_for_dtype from pandas.core.arrays.period import ( @@ -217,6 +218,10 @@ def __new__( "second", } + refs = None + if not copy and isinstance(data, (Index, ABCSeries)): + refs = data._references + if not set(fields).issubset(valid_field_set): argument = list(set(fields) - valid_field_set)[0] raise TypeError(f"__new__() got an unexpected keyword argument {argument}") @@ -257,7 +262,7 @@ def __new__( if copy: data = data.copy() - return cls._simple_new(data, name=name) + return cls._simple_new(data, name=name, refs=refs) # ------------------------------------------------------------------------ # Data diff --git a/pandas/core/indexes/range.py b/pandas/core/indexes/range.py index 0be539a9c3216..b6975a5848874 100644 --- a/pandas/core/indexes/range.py +++ b/pandas/core/indexes/range.py @@ -175,6 +175,7 @@ def _simple_new( # type: ignore[override] result._name = name result._cache = {} result._reset_identity() + result._references = None return result @classmethod diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py index e7ea54df62411..482c0da36f610 100644 --- a/pandas/core/indexes/timedeltas.py +++ b/pandas/core/indexes/timedeltas.py @@ -17,6 +17,7 @@ is_scalar, is_timedelta64_dtype, ) +from pandas.core.dtypes.generic import ABCSeries from pandas.core.arrays import datetimelike as dtl from pandas.core.arrays.timedeltas import TimedeltaArray @@ -168,7 +169,11 @@ def __new__( tdarr = TimedeltaArray._from_sequence_not_strict( data, freq=freq, unit=unit, dtype=dtype, copy=copy ) - return cls._simple_new(tdarr, name=name) + refs = None + if not copy and isinstance(data, (ABCSeries, Index)): + refs = data._references + + return cls._simple_new(tdarr, name=name, refs=refs) # ------------------------------------------------------------------- diff --git a/pandas/tests/copy_view/index/__init__.py b/pandas/tests/copy_view/index/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pandas/tests/copy_view/index/test_datetimeindex.py b/pandas/tests/copy_view/index/test_datetimeindex.py new file mode 100644 index 0000000000000..f691d5589f48c --- /dev/null +++ b/pandas/tests/copy_view/index/test_datetimeindex.py @@ -0,0 +1,56 @@ +import pytest + +from pandas import ( + DatetimeIndex, + Series, + Timestamp, + date_range, +) +import pandas._testing as tm + + +@pytest.mark.parametrize( + "cons", + [ + lambda x: DatetimeIndex(x), + lambda x: DatetimeIndex(DatetimeIndex(x)), + ], +) +def test_datetimeindex(using_copy_on_write, cons): + dt = date_range("2019-12-31", periods=3, freq="D") + ser = Series(dt) + idx = cons(ser) + expected = idx.copy(deep=True) + ser.iloc[0] = Timestamp("2020-12-31") + if using_copy_on_write: + tm.assert_index_equal(idx, expected) + + +def test_datetimeindex_tz_convert(using_copy_on_write): + dt = date_range("2019-12-31", periods=3, freq="D", tz="Europe/Berlin") + ser = Series(dt) + idx = DatetimeIndex(ser).tz_convert("US/Eastern") + expected = idx.copy(deep=True) + ser.iloc[0] = Timestamp("2020-12-31", tz="Europe/Berlin") + if using_copy_on_write: + tm.assert_index_equal(idx, expected) + + +def test_datetimeindex_tz_localize(using_copy_on_write): + dt = date_range("2019-12-31", periods=3, freq="D") + ser = Series(dt) + idx = DatetimeIndex(ser).tz_localize("Europe/Berlin") + expected = idx.copy(deep=True) + ser.iloc[0] = Timestamp("2020-12-31") + if using_copy_on_write: + tm.assert_index_equal(idx, expected) + + +def test_datetimeindex_isocalendar(using_copy_on_write): + dt = date_range("2019-12-31", periods=3, freq="D") + ser = Series(dt) + df = DatetimeIndex(ser).isocalendar() + expected = df.index.copy(deep=True) + ser.iloc[0] = Timestamp("2020-12-31") + if using_copy_on_write: + tm.assert_index_equal(df.index, expected) diff --git a/pandas/tests/copy_view/index/test_index.py b/pandas/tests/copy_view/index/test_index.py new file mode 100644 index 0000000000000..817be43475d0b --- /dev/null +++ b/pandas/tests/copy_view/index/test_index.py @@ -0,0 +1,155 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + Series, +) +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array + + +def index_view(index_data=[1, 2]): + df = DataFrame({"a": index_data, "b": 1.5}) + view = df[:] + df = df.set_index("a", drop=True) + idx = df.index + # df = None + return idx, view + + +def test_set_index_update_column(using_copy_on_write): + df = DataFrame({"a": [1, 2], "b": 1}) + df = df.set_index("a", drop=False) + expected = df.index.copy(deep=True) + df.iloc[0, 0] = 100 + if using_copy_on_write: + tm.assert_index_equal(df.index, expected) + else: + tm.assert_index_equal(df.index, Index([100, 2], name="a")) + + +def test_set_index_drop_update_column(using_copy_on_write): + df = DataFrame({"a": [1, 2], "b": 1.5}) + view = df[:] + df = df.set_index("a", drop=True) + expected = df.index.copy(deep=True) + view.iloc[0, 0] = 100 + tm.assert_index_equal(df.index, expected) + + +def test_set_index_series(using_copy_on_write): + df = DataFrame({"a": [1, 2], "b": 1.5}) + ser = Series([10, 11]) + df = df.set_index(ser) + expected = df.index.copy(deep=True) + ser.iloc[0] = 100 + if using_copy_on_write: + tm.assert_index_equal(df.index, expected) + else: + tm.assert_index_equal(df.index, Index([100, 11])) + + +def test_assign_index_as_series(using_copy_on_write): + df = DataFrame({"a": [1, 2], "b": 1.5}) + ser = Series([10, 11]) + df.index = ser + expected = df.index.copy(deep=True) + ser.iloc[0] = 100 + if using_copy_on_write: + tm.assert_index_equal(df.index, expected) + else: + tm.assert_index_equal(df.index, Index([100, 11])) + + +def test_assign_index_as_index(using_copy_on_write): + df = DataFrame({"a": [1, 2], "b": 1.5}) + ser = Series([10, 11]) + rhs_index = Index(ser) + df.index = rhs_index + rhs_index = None # overwrite to clear reference + expected = df.index.copy(deep=True) + ser.iloc[0] = 100 + if using_copy_on_write: + tm.assert_index_equal(df.index, expected) + else: + tm.assert_index_equal(df.index, Index([100, 11])) + + +def test_index_from_series(using_copy_on_write): + ser = Series([1, 2]) + idx = Index(ser) + expected = idx.copy(deep=True) + ser.iloc[0] = 100 + if using_copy_on_write: + tm.assert_index_equal(idx, expected) + else: + tm.assert_index_equal(idx, Index([100, 2])) + + +def test_index_from_series_copy(using_copy_on_write): + ser = Series([1, 2]) + idx = Index(ser, copy=True) # noqa + arr = get_array(ser) + ser.iloc[0] = 100 + assert np.shares_memory(get_array(ser), arr) + + +def test_index_from_index(using_copy_on_write): + ser = Series([1, 2]) + idx = Index(ser) + idx = Index(idx) + expected = idx.copy(deep=True) + ser.iloc[0] = 100 + if using_copy_on_write: + tm.assert_index_equal(idx, expected) + else: + tm.assert_index_equal(idx, Index([100, 2])) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: x._shallow_copy(x._values), + lambda x: x.view(), + lambda x: x.take([0, 1]), + lambda x: x.repeat([1, 1]), + lambda x: x[slice(0, 2)], + lambda x: x[[0, 1]], + lambda x: x._getitem_slice(slice(0, 2)), + lambda x: x.delete([]), + lambda x: x.rename("b"), + lambda x: x.astype("Int64", copy=False), + ], + ids=[ + "_shallow_copy", + "view", + "take", + "repeat", + "getitem_slice", + "getitem_list", + "_getitem_slice", + "delete", + "rename", + "astype", + ], +) +def test_index_ops(using_copy_on_write, func, request): + idx, view_ = index_view() + expected = idx.copy(deep=True) + if "astype" in request.node.callspec.id: + expected = expected.astype("Int64") + idx = func(idx) + view_.iloc[0, 0] = 100 + if using_copy_on_write: + tm.assert_index_equal(idx, expected, check_names=False) + + +def test_infer_objects(using_copy_on_write): + idx, view_ = index_view(["a", "b"]) + expected = idx.copy(deep=True) + idx = idx.infer_objects(copy=False) + view_.iloc[0, 0] = "aaaa" + if using_copy_on_write: + tm.assert_index_equal(idx, expected, check_names=False) diff --git a/pandas/tests/copy_view/index/test_periodindex.py b/pandas/tests/copy_view/index/test_periodindex.py new file mode 100644 index 0000000000000..94bc3a66f0e2b --- /dev/null +++ b/pandas/tests/copy_view/index/test_periodindex.py @@ -0,0 +1,26 @@ +import pytest + +from pandas import ( + Period, + PeriodIndex, + Series, + period_range, +) +import pandas._testing as tm + + +@pytest.mark.parametrize( + "cons", + [ + lambda x: PeriodIndex(x), + lambda x: PeriodIndex(PeriodIndex(x)), + ], +) +def test_periodindex(using_copy_on_write, cons): + dt = period_range("2019-12-31", periods=3, freq="D") + ser = Series(dt) + idx = cons(ser) + expected = idx.copy(deep=True) + ser.iloc[0] = Period("2020-12-31") + if using_copy_on_write: + tm.assert_index_equal(idx, expected) diff --git a/pandas/tests/copy_view/index/test_timedeltaindex.py b/pandas/tests/copy_view/index/test_timedeltaindex.py new file mode 100644 index 0000000000000..a543e06cea328 --- /dev/null +++ b/pandas/tests/copy_view/index/test_timedeltaindex.py @@ -0,0 +1,26 @@ +import pytest + +from pandas import ( + Series, + Timedelta, + TimedeltaIndex, + timedelta_range, +) +import pandas._testing as tm + + +@pytest.mark.parametrize( + "cons", + [ + lambda x: TimedeltaIndex(x), + lambda x: TimedeltaIndex(TimedeltaIndex(x)), + ], +) +def test_timedeltaindex(using_copy_on_write, cons): + dt = timedelta_range("1 day", periods=3) + ser = Series(dt) + idx = cons(ser) + expected = idx.copy(deep=True) + ser.iloc[0] = Timedelta("5 days") + if using_copy_on_write: + tm.assert_index_equal(idx, expected)