Skip to content

CoW: Add reference tracking to index when created from series #51803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 15, 2023
10 changes: 10 additions & 0 deletions pandas/_libs/internals.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,16 @@ cdef class BlockValuesRefs:
"""
self.referenced_blocks.append(weakref.ref(blk))

def add_index_reference(self, index: object) -> None:
"""Adds a new reference to our reference collection when creating an index.

Parameters
----------
index: object
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an Index object right? (even if we can't put it in the annotation)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

The index that the new reference should point to.
"""
self.referenced_blocks.append(weakref.ref(index))

def has_reference(self) -> bool:
"""Checks if block has foreign references.

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5935,7 +5935,7 @@ def set_index(
names.append(None)
# from here, col can only be a column label
else:
arrays.append(frame[col]._values)
arrays.append(frame[col])
names.append(col)
if drop:
to_remove.append(col)
Expand Down
42 changes: 33 additions & 9 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@
rewrite_exception,
)

from pandas.core.dtypes.astype import astype_array
from pandas.core.dtypes.astype import (
astype_array,
astype_is_view,
)
from pandas.core.dtypes.cast import (
LossySetitemError,
can_hold_element,
Expand Down Expand Up @@ -458,6 +461,8 @@ def _engine_type(

str = CachedAccessor("str", StringMethods)

_references = None

# --------------------------------------------------------------------
# Constructors

Expand All @@ -478,6 +483,10 @@ def __new__(

data_dtype = getattr(data, "dtype", None)

refs = None
if not copy and isinstance(data, (ABCSeries, Index)):
refs = data._references
Comment on lines +487 to +488
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something I was also wondering if my related PR: this tackles it for Series and Index, but in theory we also have the problem with arrays:

arr = np.array([1, 2, 3])
ser = pd.Series(arr, index=arr)

And with then mutating ser, you can also trigger faulty behaviour / crashes.

So while for Index/Series this avoids a copy, we might still want to copy anyway for other array-likes (like we are going to do in the DataFrame/Series constructors)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep exactly. But I wanted to wait for the pr that tackles this for the DataFrame case to get merged before adding this for index


# range
if isinstance(data, (range, RangeIndex)):
result = RangeIndex(start=data, copy=copy, name=name)
Expand Down Expand Up @@ -551,7 +560,7 @@ def __new__(
klass = cls._dtype_to_subclass(arr.dtype)

arr = klass._ensure_array(arr, arr.dtype, copy=False)
return klass._simple_new(arr, name)
return klass._simple_new(arr, name, refs=refs)

@classmethod
def _ensure_array(cls, data, dtype, copy: bool):
Expand Down Expand Up @@ -629,7 +638,7 @@ def _dtype_to_subclass(cls, dtype: DtypeObj):
# See each method's docstring.

@classmethod
def _simple_new(cls, values: ArrayLike, name: Hashable = None) -> Self:
def _simple_new(cls, values: ArrayLike, name: Hashable = None, refs=None) -> Self:
"""
We require that we have a dtype compat for the values. If we are passed
a non-dtype compat, then coerce using the constructor.
Expand All @@ -643,6 +652,9 @@ def _simple_new(cls, values: ArrayLike, name: Hashable = None) -> Self:
result._name = name
result._cache = {}
result._reset_identity()
result._references = refs
if refs is not None:
refs.add_index_reference(result)

return result

Expand Down Expand Up @@ -739,13 +751,13 @@ def _shallow_copy(self, values, name: Hashable = no_default) -> Self:
"""
name = self._name if name is no_default else name

return self._simple_new(values, name=name)
return self._simple_new(values, name=name, refs=self._references)

def _view(self) -> Self:
"""
fastpath to make a shallow copy, i.e. new object with same data.
"""
result = self._simple_new(self._values, name=self._name)
result = self._simple_new(self._values, name=self._name, refs=self._references)

result._cache = self._cache
return result
Expand Down Expand Up @@ -955,7 +967,7 @@ def view(self, cls=None):
# of types.
arr_cls = idx_cls._data_cls
arr = arr_cls(self._data.view("i8"), dtype=dtype)
return idx_cls._simple_new(arr, name=self.name)
return idx_cls._simple_new(arr, name=self.name, refs=self._references)

result = self._data.view(cls)
else:
Expand Down Expand Up @@ -1011,7 +1023,15 @@ def astype(self, dtype, copy: bool = True):
new_values = astype_array(values, dtype=dtype, copy=copy)

# pass copy=False because any copying will be done in the astype above
return Index(new_values, name=self.name, dtype=new_values.dtype, copy=False)
result = Index(new_values, name=self.name, dtype=new_values.dtype, copy=False)
if (
not copy
and self._references is not None
and astype_is_view(self.dtype, dtype)
):
result._references = self._references
result._references.add_index_reference(result)
return result

_index_shared_docs[
"take"
Expand Down Expand Up @@ -5183,7 +5203,7 @@ def _getitem_slice(self, slobj: slice) -> Self:
Fastpath for __getitem__ when we know we have a slice.
"""
res = self._data[slobj]
result = type(self)._simple_new(res, name=self._name)
result = type(self)._simple_new(res, name=self._name, refs=self._references)
if "_engine" in self._cache:
reverse = slobj.step is not None and slobj.step < 0
result._engine._update_from_sliced(self._engine, reverse=reverse) # type: ignore[union-attr] # noqa: E501
Expand Down Expand Up @@ -6707,7 +6727,11 @@ def infer_objects(self, copy: bool = True) -> Index:
)
if copy and res_values is values:
return self.copy()
return Index(res_values, name=self.name)
result = Index(res_values, name=self.name)
if not copy and res_values is values and self._references is not None:
result._references = self._references
result._references.add_index_reference(result)
return result

# --------------------------------------------------------------------
# Generated Arithmetic, Comparison, and Unary Methods
Expand Down
8 changes: 6 additions & 2 deletions pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
is_datetime64tz_dtype,
is_scalar,
)
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.missing import is_valid_na_for_dtype

from pandas.core.arrays.datetimes import (
Expand Down Expand Up @@ -267,7 +268,7 @@ def strftime(self, date_format) -> Index:
@doc(DatetimeArray.tz_convert)
def tz_convert(self, tz) -> DatetimeIndex:
arr = self._data.tz_convert(tz)
return type(self)._simple_new(arr, name=self.name)
return type(self)._simple_new(arr, name=self.name, refs=self._references)

@doc(DatetimeArray.tz_localize)
def tz_localize(
Expand Down Expand Up @@ -346,8 +347,11 @@ def __new__(
yearfirst=yearfirst,
ambiguous=ambiguous,
)
refs = None
if not copy and isinstance(data, (Index, ABCSeries)):
refs = data._references

subarr = cls._simple_new(dtarr, name=name)
subarr = cls._simple_new(dtarr, name=name, refs=refs)
return subarr

# --------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def __new__(
result._codes = new_codes

result._reset_identity()
result._references = None

return result

Expand Down
7 changes: 6 additions & 1 deletion pandas/core/indexes/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from pandas.core.dtypes.common import is_integer
from pandas.core.dtypes.dtypes import PeriodDtype
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.missing import is_valid_na_for_dtype

from pandas.core.arrays.period import (
Expand Down Expand Up @@ -221,6 +222,10 @@ def __new__(
"second",
}

refs = None
if not copy and isinstance(data, (Index, ABCSeries)):
refs = data._references

if not set(fields).issubset(valid_field_set):
argument = list(set(fields) - valid_field_set)[0]
raise TypeError(f"__new__() got an unexpected keyword argument {argument}")
Expand Down Expand Up @@ -261,7 +266,7 @@ def __new__(
if copy:
data = data.copy()

return cls._simple_new(data, name=name)
return cls._simple_new(data, name=name, refs=refs)

# ------------------------------------------------------------------------
# Data
Expand Down
1 change: 1 addition & 0 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def _simple_new( # type: ignore[override]
result._name = name
result._cache = {}
result._reset_identity()
result._references = None
return result

@classmethod
Expand Down
7 changes: 6 additions & 1 deletion pandas/core/indexes/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
is_scalar,
is_timedelta64_dtype,
)
from pandas.core.dtypes.generic import ABCSeries

from pandas.core.arrays import datetimelike as dtl
from pandas.core.arrays.timedeltas import TimedeltaArray
Expand Down Expand Up @@ -172,7 +173,11 @@ def __new__(
tdarr = TimedeltaArray._from_sequence_not_strict(
data, freq=freq, unit=unit, dtype=dtype, copy=copy
)
return cls._simple_new(tdarr, name=name)
refs = None
if not copy and isinstance(data, (ABCSeries, Index)):
refs = data._references

return cls._simple_new(tdarr, name=name, refs=refs)

# -------------------------------------------------------------------

Expand Down
Empty file.
56 changes: 56 additions & 0 deletions pandas/tests/copy_view/index/test_datetimeindex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest

from pandas import (
DatetimeIndex,
Series,
Timestamp,
date_range,
)
import pandas._testing as tm


@pytest.mark.parametrize(
"cons",
[
lambda x: DatetimeIndex(x),
lambda x: DatetimeIndex(DatetimeIndex(x)),
],
)
def test_datetimeindex(using_copy_on_write, cons):
dt = date_range("2019-12-31", periods=3, freq="D")
ser = Series(dt)
idx = cons(ser)
expected = idx.copy(deep=True)
ser.iloc[0] = Timestamp("2020-12-31")
if using_copy_on_write:
tm.assert_index_equal(idx, expected)


def test_datetimeindex_tz_convert(using_copy_on_write):
dt = date_range("2019-12-31", periods=3, freq="D", tz="Europe/Berlin")
ser = Series(dt)
idx = DatetimeIndex(ser).tz_convert("US/Eastern")
expected = idx.copy(deep=True)
ser.iloc[0] = Timestamp("2020-12-31", tz="Europe/Berlin")
if using_copy_on_write:
tm.assert_index_equal(idx, expected)


def test_datetimeindex_tz_localize(using_copy_on_write):
dt = date_range("2019-12-31", periods=3, freq="D")
ser = Series(dt)
idx = DatetimeIndex(ser).tz_localize("Europe/Berlin")
expected = idx.copy(deep=True)
ser.iloc[0] = Timestamp("2020-12-31")
if using_copy_on_write:
tm.assert_index_equal(idx, expected)


def test_datetimeindex_isocalendar(using_copy_on_write):
dt = date_range("2019-12-31", periods=3, freq="D")
ser = Series(dt)
df = DatetimeIndex(ser).isocalendar()
expected = df.index.copy(deep=True)
ser.iloc[0] = Timestamp("2020-12-31")
if using_copy_on_write:
tm.assert_index_equal(df.index, expected)
Loading