
Backport PR #51803 on branch 2.0.x (CoW: Add reference tracking to index when created from series) #52000

Merged: 2 commits into 2.0.x on Mar 16, 2023
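The change backported here makes an Index remember, via weak references, which block of data it was created from, so copy-on-write can detect that the data is shared. A minimal sketch of the intended behaviour, assuming pandas 2.0 with copy-on-write enabled (option name and output shown for illustration only):

import pandas as pd

pd.set_option("mode.copy_on_write", True)

ser = pd.Series([1, 2, 3])
idx = pd.Index(ser)   # no copy: the index views the Series' data
ser.iloc[0] = 100     # CoW sees the index reference and copies the data first
print(idx.tolist())   # [1, 2, 3] -- the index is unaffected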
10 changes: 10 additions & 0 deletions pandas/_libs/internals.pyx
@@ -890,6 +890,16 @@ cdef class BlockValuesRefs:
"""
self.referenced_blocks.append(weakref.ref(blk))

def add_index_reference(self, index: object) -> None:
"""Adds a new reference to our reference collection when creating an index.

Parameters
----------
index : object
The index that the new reference should point to.
"""
self.referenced_blocks.append(weakref.ref(index))

def has_reference(self) -> bool:
"""Checks if block has foreign references.

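For orientation, a rough pure-Python sketch (simplified names, not the actual cdef class in internals.pyx) of the weak-reference bookkeeping that BlockValuesRefs performs; the new add_index_reference stores an index in the same list as the blocks:

import weakref

class ValuesRefsSketch:
    """Track every object (block or index) that views the same values."""

    def __init__(self, blk=None):
        self.referenced_blocks = []
        if blk is not None:
            self.referenced_blocks.append(weakref.ref(blk))

    def add_reference(self, blk) -> None:
        self.referenced_blocks.append(weakref.ref(blk))

    def add_index_reference(self, index) -> None:
        # An index is tracked exactly like a block.
        self.referenced_blocks.append(weakref.ref(index))

    def has_reference(self) -> bool:
        # Drop dead weakrefs; any survivor besides ourselves means shared data.
        self.referenced_blocks = [r for r in self.referenced_blocks if r() is not None]
        return len(self.referenced_blocks) > 1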
2 changes: 1 addition & 1 deletion pandas/core/frame.py
@@ -5879,7 +5879,7 @@ def set_index(
names.append(None)
# from here, col can only be a column label
else:
arrays.append(frame[col]._values)
arrays.append(frame[col])
names.append(col)
if drop:
to_remove.append(col)
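The set_index change above hands the Series itself to the Index constructor instead of its bare ndarray, so the constructor can pick up the Series' reference tracker. A hedged example of the resulting behaviour (pandas 2.0, copy-on-write enabled):

import pandas as pd

pd.set_option("mode.copy_on_write", True)

df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
res = df.set_index("a")    # res.index is built from the Series df["a"]
df.iloc[0, 0] = 100        # copies first, because res.index references the block
print(res.index.tolist())  # [1, 2, 3]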
46 changes: 36 additions & 10 deletions pandas/core/indexes/base.py
@@ -73,7 +73,10 @@
rewrite_exception,
)

from pandas.core.dtypes.astype import astype_array
from pandas.core.dtypes.astype import (
astype_array,
astype_is_view,
)
from pandas.core.dtypes.cast import (
LossySetitemError,
can_hold_element,
@@ -457,6 +460,8 @@ def _engine_type(

str = CachedAccessor("str", StringMethods)

_references = None

# --------------------------------------------------------------------
# Constructors

@@ -477,6 +482,10 @@ def __new__(

data_dtype = getattr(data, "dtype", None)

refs = None
if not copy and isinstance(data, (ABCSeries, Index)):
refs = data._references

# range
if isinstance(data, (range, RangeIndex)):
result = RangeIndex(start=data, copy=copy, name=name)
@@ -550,7 +559,7 @@
klass = cls._dtype_to_subclass(arr.dtype)

arr = klass._ensure_array(arr, arr.dtype, copy=False)
return klass._simple_new(arr, name)
return klass._simple_new(arr, name, refs=refs)

@classmethod
def _ensure_array(cls, data, dtype, copy: bool):
@@ -629,7 +638,7 @@ def _dtype_to_subclass(cls, dtype: DtypeObj):

@classmethod
def _simple_new(
cls: type[_IndexT], values: ArrayLike, name: Hashable = None
cls: type[_IndexT], values: ArrayLike, name: Hashable = None, refs=None
) -> _IndexT:
"""
We require that we have a dtype compat for the values. If we are passed
@@ -644,6 +653,9 @@ def _simple_new(
result._name = name
result._cache = {}
result._reset_identity()
result._references = refs
if refs is not None:
refs.add_index_reference(result)

return result

@@ -740,13 +752,13 @@ def _shallow_copy(self: _IndexT, values, name: Hashable = no_default) -> _IndexT
"""
name = self._name if name is no_default else name

return self._simple_new(values, name=name)
return self._simple_new(values, name=name, refs=self._references)

def _view(self: _IndexT) -> _IndexT:
"""
fastpath to make a shallow copy, i.e. new object with same data.
"""
result = self._simple_new(self._values, name=self._name)
result = self._simple_new(self._values, name=self._name, refs=self._references)

result._cache = self._cache
return result
@@ -956,7 +968,7 @@ def view(self, cls=None):
# of types.
arr_cls = idx_cls._data_cls
arr = arr_cls(self._data.view("i8"), dtype=dtype)
return idx_cls._simple_new(arr, name=self.name)
return idx_cls._simple_new(arr, name=self.name, refs=self._references)

result = self._data.view(cls)
else:
@@ -1012,7 +1024,15 @@ def astype(self, dtype, copy: bool = True):
new_values = astype_array(values, dtype=dtype, copy=copy)

# pass copy=False because any copying will be done in the astype above
return Index(new_values, name=self.name, dtype=new_values.dtype, copy=False)
result = Index(new_values, name=self.name, dtype=new_values.dtype, copy=False)
if (
not copy
and self._references is not None
and astype_is_view(self.dtype, dtype)
):
result._references = self._references
result._references.add_index_reference(result)
return result

_index_shared_docs[
"take"
@@ -5155,7 +5175,9 @@ def __getitem__(self, key):
# pessimization com.is_bool_indexer and ndim checks.
result = getitem(key)
# Going through simple_new for performance.
return type(self)._simple_new(result, name=self._name)
return type(self)._simple_new(
result, name=self._name, refs=self._references
)

if com.is_bool_indexer(key):
# if we have list[bools, length=1e5] then doing this check+convert
@@ -5181,7 +5203,7 @@ def _getitem_slice(self: _IndexT, slobj: slice) -> _IndexT:
Fastpath for __getitem__ when we know we have a slice.
"""
res = self._data[slobj]
return type(self)._simple_new(res, name=self._name)
return type(self)._simple_new(res, name=self._name, refs=self._references)

@final
def _can_hold_identifiers_and_holds_name(self, name) -> bool:
@@ -6700,7 +6722,11 @@ def infer_objects(self, copy: bool = True) -> Index:
)
if copy and res_values is values:
return self.copy()
return Index(res_values, name=self.name)
result = Index(res_values, name=self.name)
if not copy and res_values is values and self._references is not None:
result._references = self._references
result._references.add_index_reference(result)
return result

# --------------------------------------------------------------------
# Generated Arithmetic, Comparison, and Unary Methods
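Taken together, the base.py changes forward the references through every operation that returns a view of the same data (the constructor, _shallow_copy, _view, view, astype when the cast is a view, __getitem__ with a slice, infer_objects). A rough illustration, assuming pandas 2.0 with copy-on-write enabled:

import pandas as pd

pd.set_option("mode.copy_on_write", True)

ser = pd.Series([1, 2, 3], dtype="int64")
idx = pd.Index(ser)                       # __new__ picks up ser's references
sliced = idx[1:]                          # _getitem_slice forwards them
recast = idx.astype("int64", copy=False)  # astype_is_view -> forwarded as well
ser.iloc[0] = 100                         # forces a copy of ser's data instead
print(idx[0], sliced[0], recast[0])       # 1 2 1 -- none of the indexes changed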
8 changes: 6 additions & 2 deletions pandas/core/indexes/datetimes.py
@@ -44,6 +44,7 @@
is_datetime64tz_dtype,
is_scalar,
)
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.missing import is_valid_na_for_dtype

from pandas.core.arrays.datetimes import (
@@ -266,7 +267,7 @@ def strftime(self, date_format) -> Index:
@doc(DatetimeArray.tz_convert)
def tz_convert(self, tz) -> DatetimeIndex:
arr = self._data.tz_convert(tz)
return type(self)._simple_new(arr, name=self.name)
return type(self)._simple_new(arr, name=self.name, refs=self._references)

@doc(DatetimeArray.tz_localize)
def tz_localize(
@@ -345,8 +346,11 @@ def __new__(
yearfirst=yearfirst,
ambiguous=ambiguous,
)
refs = None
if not copy and isinstance(data, (Index, ABCSeries)):
refs = data._references

subarr = cls._simple_new(dtarr, name=name)
subarr = cls._simple_new(dtarr, name=name, refs=refs)
return subarr

# --------------------------------------------------------------------
1 change: 1 addition & 0 deletions pandas/core/indexes/multi.py
@@ -353,6 +353,7 @@ def __new__(
result._codes = new_codes

result._reset_identity()
result._references = None

return result

7 changes: 6 additions & 1 deletion pandas/core/indexes/period.py
@@ -28,6 +28,7 @@

from pandas.core.dtypes.common import is_integer
from pandas.core.dtypes.dtypes import PeriodDtype
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.missing import is_valid_na_for_dtype

from pandas.core.arrays.period import (
@@ -217,6 +218,10 @@ def __new__(
"second",
}

refs = None
if not copy and isinstance(data, (Index, ABCSeries)):
refs = data._references

if not set(fields).issubset(valid_field_set):
argument = list(set(fields) - valid_field_set)[0]
raise TypeError(f"__new__() got an unexpected keyword argument {argument}")
@@ -257,7 +262,7 @@
if copy:
data = data.copy()

return cls._simple_new(data, name=name)
return cls._simple_new(data, name=name, refs=refs)

# ------------------------------------------------------------------------
# Data
1 change: 1 addition & 0 deletions pandas/core/indexes/range.py
@@ -175,6 +175,7 @@ def _simple_new( # type: ignore[override]
result._name = name
result._cache = {}
result._reset_identity()
result._references = None
return result

@classmethod
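MultiIndex (above) and RangeIndex build their own backing data rather than viewing a Series' block, so their constructors only need to initialise the new _references attribute to None. A quick check of that, using the private attribute purely for illustration:

import pandas as pd

ri = pd.RangeIndex(5)
mi = pd.MultiIndex.from_tuples([("a", 1), ("b", 2)])
print(ri._references, mi._references)  # None None -- nothing to track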
7 changes: 6 additions & 1 deletion pandas/core/indexes/timedeltas.py
@@ -17,6 +17,7 @@
is_scalar,
is_timedelta64_dtype,
)
from pandas.core.dtypes.generic import ABCSeries

from pandas.core.arrays import datetimelike as dtl
from pandas.core.arrays.timedeltas import TimedeltaArray
@@ -168,7 +169,11 @@ def __new__(
tdarr = TimedeltaArray._from_sequence_not_strict(
data, freq=freq, unit=unit, dtype=dtype, copy=copy
)
return cls._simple_new(tdarr, name=name)
refs = None
if not copy and isinstance(data, (ABCSeries, Index)):
refs = data._references

return cls._simple_new(tdarr, name=name, refs=refs)

# -------------------------------------------------------------------

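DatetimeIndex, PeriodIndex and TimedeltaIndex all follow the same pattern in __new__: when the input is a Series or Index and no copy was requested, its _references are handed to _simple_new. A hedged sketch of what that enables (pandas 2.0, copy-on-write enabled; assumes the construction itself is zero-copy):

import pandas as pd

pd.set_option("mode.copy_on_write", True)

ser = pd.Series(pd.timedelta_range("1 day", periods=3))
idx = pd.TimedeltaIndex(ser)           # views ser's timedelta64 data, reference tracked
ser.iloc[0] = pd.Timedelta("10 days")  # copies, because idx references the block
print(idx[0])                          # 1 days 00:00:00 -- unchanged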
Empty file.
56 changes: 56 additions & 0 deletions pandas/tests/copy_view/index/test_datetimeindex.py
@@ -0,0 +1,56 @@
import pytest

from pandas import (
DatetimeIndex,
Series,
Timestamp,
date_range,
)
import pandas._testing as tm


@pytest.mark.parametrize(
"cons",
[
lambda x: DatetimeIndex(x),
lambda x: DatetimeIndex(DatetimeIndex(x)),
],
)
def test_datetimeindex(using_copy_on_write, cons):
dt = date_range("2019-12-31", periods=3, freq="D")
ser = Series(dt)
idx = cons(ser)
expected = idx.copy(deep=True)
ser.iloc[0] = Timestamp("2020-12-31")
if using_copy_on_write:
tm.assert_index_equal(idx, expected)


def test_datetimeindex_tz_convert(using_copy_on_write):
dt = date_range("2019-12-31", periods=3, freq="D", tz="Europe/Berlin")
ser = Series(dt)
idx = DatetimeIndex(ser).tz_convert("US/Eastern")
expected = idx.copy(deep=True)
ser.iloc[0] = Timestamp("2020-12-31", tz="Europe/Berlin")
if using_copy_on_write:
tm.assert_index_equal(idx, expected)


def test_datetimeindex_tz_localize(using_copy_on_write):
dt = date_range("2019-12-31", periods=3, freq="D")
ser = Series(dt)
idx = DatetimeIndex(ser).tz_localize("Europe/Berlin")
expected = idx.copy(deep=True)
ser.iloc[0] = Timestamp("2020-12-31")
if using_copy_on_write:
tm.assert_index_equal(idx, expected)


def test_datetimeindex_isocalendar(using_copy_on_write):
dt = date_range("2019-12-31", periods=3, freq="D")
ser = Series(dt)
df = DatetimeIndex(ser).isocalendar()
expected = df.index.copy(deep=True)
ser.iloc[0] = Timestamp("2020-12-31")
if using_copy_on_write:
tm.assert_index_equal(df.index, expected)