Skip to content

Commit fbc660b

Browse files
authored
Backport PR #51803 on branch 2.0.x (CoW: Add reference tracking to index when created from series) (#52000)
1 parent 6b6c336 commit fbc660b

13 files changed

+330
-15
lines changed

pandas/_libs/internals.pyx

+10
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,16 @@ cdef class BlockValuesRefs:
890890
"""
891891
self.referenced_blocks.append(weakref.ref(blk))
892892

893+
def add_index_reference(self, index: object) -> None:
894+
"""Adds a new reference to our reference collection when creating an index.
895+
896+
Parameters
897+
----------
898+
index: object
899+
The index that the new reference should point to.
900+
"""
901+
self.referenced_blocks.append(weakref.ref(index))
902+
893903
def has_reference(self) -> bool:
894904
"""Checks if block has foreign references.
895905

pandas/core/frame.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5879,7 +5879,7 @@ def set_index(
58795879
names.append(None)
58805880
# from here, col can only be a column label
58815881
else:
5882-
arrays.append(frame[col]._values)
5882+
arrays.append(frame[col])
58835883
names.append(col)
58845884
if drop:
58855885
to_remove.append(col)

pandas/core/indexes/base.py

+36-10
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@
7373
rewrite_exception,
7474
)
7575

76-
from pandas.core.dtypes.astype import astype_array
76+
from pandas.core.dtypes.astype import (
77+
astype_array,
78+
astype_is_view,
79+
)
7780
from pandas.core.dtypes.cast import (
7881
LossySetitemError,
7982
can_hold_element,
@@ -457,6 +460,8 @@ def _engine_type(
457460

458461
str = CachedAccessor("str", StringMethods)
459462

463+
_references = None
464+
460465
# --------------------------------------------------------------------
461466
# Constructors
462467

@@ -477,6 +482,10 @@ def __new__(
477482

478483
data_dtype = getattr(data, "dtype", None)
479484

485+
refs = None
486+
if not copy and isinstance(data, (ABCSeries, Index)):
487+
refs = data._references
488+
480489
# range
481490
if isinstance(data, (range, RangeIndex)):
482491
result = RangeIndex(start=data, copy=copy, name=name)
@@ -550,7 +559,7 @@ def __new__(
550559
klass = cls._dtype_to_subclass(arr.dtype)
551560

552561
arr = klass._ensure_array(arr, arr.dtype, copy=False)
553-
return klass._simple_new(arr, name)
562+
return klass._simple_new(arr, name, refs=refs)
554563

555564
@classmethod
556565
def _ensure_array(cls, data, dtype, copy: bool):
@@ -629,7 +638,7 @@ def _dtype_to_subclass(cls, dtype: DtypeObj):
629638

630639
@classmethod
631640
def _simple_new(
632-
cls: type[_IndexT], values: ArrayLike, name: Hashable = None
641+
cls: type[_IndexT], values: ArrayLike, name: Hashable = None, refs=None
633642
) -> _IndexT:
634643
"""
635644
We require that we have a dtype compat for the values. If we are passed
@@ -644,6 +653,9 @@ def _simple_new(
644653
result._name = name
645654
result._cache = {}
646655
result._reset_identity()
656+
result._references = refs
657+
if refs is not None:
658+
refs.add_index_reference(result)
647659

648660
return result
649661

@@ -740,13 +752,13 @@ def _shallow_copy(self: _IndexT, values, name: Hashable = no_default) -> _IndexT
740752
"""
741753
name = self._name if name is no_default else name
742754

743-
return self._simple_new(values, name=name)
755+
return self._simple_new(values, name=name, refs=self._references)
744756

745757
def _view(self: _IndexT) -> _IndexT:
746758
"""
747759
fastpath to make a shallow copy, i.e. new object with same data.
748760
"""
749-
result = self._simple_new(self._values, name=self._name)
761+
result = self._simple_new(self._values, name=self._name, refs=self._references)
750762

751763
result._cache = self._cache
752764
return result
@@ -956,7 +968,7 @@ def view(self, cls=None):
956968
# of types.
957969
arr_cls = idx_cls._data_cls
958970
arr = arr_cls(self._data.view("i8"), dtype=dtype)
959-
return idx_cls._simple_new(arr, name=self.name)
971+
return idx_cls._simple_new(arr, name=self.name, refs=self._references)
960972

961973
result = self._data.view(cls)
962974
else:
@@ -1012,7 +1024,15 @@ def astype(self, dtype, copy: bool = True):
10121024
new_values = astype_array(values, dtype=dtype, copy=copy)
10131025

10141026
# pass copy=False because any copying will be done in the astype above
1015-
return Index(new_values, name=self.name, dtype=new_values.dtype, copy=False)
1027+
result = Index(new_values, name=self.name, dtype=new_values.dtype, copy=False)
1028+
if (
1029+
not copy
1030+
and self._references is not None
1031+
and astype_is_view(self.dtype, dtype)
1032+
):
1033+
result._references = self._references
1034+
result._references.add_index_reference(result)
1035+
return result
10161036

10171037
_index_shared_docs[
10181038
"take"
@@ -5155,7 +5175,9 @@ def __getitem__(self, key):
51555175
# pessimization com.is_bool_indexer and ndim checks.
51565176
result = getitem(key)
51575177
# Going through simple_new for performance.
5158-
return type(self)._simple_new(result, name=self._name)
5178+
return type(self)._simple_new(
5179+
result, name=self._name, refs=self._references
5180+
)
51595181

51605182
if com.is_bool_indexer(key):
51615183
# if we have list[bools, length=1e5] then doing this check+convert
@@ -5181,7 +5203,7 @@ def _getitem_slice(self: _IndexT, slobj: slice) -> _IndexT:
51815203
Fastpath for __getitem__ when we know we have a slice.
51825204
"""
51835205
res = self._data[slobj]
5184-
return type(self)._simple_new(res, name=self._name)
5206+
return type(self)._simple_new(res, name=self._name, refs=self._references)
51855207

51865208
@final
51875209
def _can_hold_identifiers_and_holds_name(self, name) -> bool:
@@ -6700,7 +6722,11 @@ def infer_objects(self, copy: bool = True) -> Index:
67006722
)
67016723
if copy and res_values is values:
67026724
return self.copy()
6703-
return Index(res_values, name=self.name)
6725+
result = Index(res_values, name=self.name)
6726+
if not copy and res_values is values and self._references is not None:
6727+
result._references = self._references
6728+
result._references.add_index_reference(result)
6729+
return result
67046730

67056731
# --------------------------------------------------------------------
67066732
# Generated Arithmetic, Comparison, and Unary Methods

pandas/core/indexes/datetimes.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
is_datetime64tz_dtype,
4545
is_scalar,
4646
)
47+
from pandas.core.dtypes.generic import ABCSeries
4748
from pandas.core.dtypes.missing import is_valid_na_for_dtype
4849

4950
from pandas.core.arrays.datetimes import (
@@ -266,7 +267,7 @@ def strftime(self, date_format) -> Index:
266267
@doc(DatetimeArray.tz_convert)
267268
def tz_convert(self, tz) -> DatetimeIndex:
268269
arr = self._data.tz_convert(tz)
269-
return type(self)._simple_new(arr, name=self.name)
270+
return type(self)._simple_new(arr, name=self.name, refs=self._references)
270271

271272
@doc(DatetimeArray.tz_localize)
272273
def tz_localize(
@@ -345,8 +346,11 @@ def __new__(
345346
yearfirst=yearfirst,
346347
ambiguous=ambiguous,
347348
)
349+
refs = None
350+
if not copy and isinstance(data, (Index, ABCSeries)):
351+
refs = data._references
348352

349-
subarr = cls._simple_new(dtarr, name=name)
353+
subarr = cls._simple_new(dtarr, name=name, refs=refs)
350354
return subarr
351355

352356
# --------------------------------------------------------------------

pandas/core/indexes/multi.py

+1
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def __new__(
353353
result._codes = new_codes
354354

355355
result._reset_identity()
356+
result._references = None
356357

357358
return result
358359

pandas/core/indexes/period.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from pandas.core.dtypes.common import is_integer
3030
from pandas.core.dtypes.dtypes import PeriodDtype
31+
from pandas.core.dtypes.generic import ABCSeries
3132
from pandas.core.dtypes.missing import is_valid_na_for_dtype
3233

3334
from pandas.core.arrays.period import (
@@ -217,6 +218,10 @@ def __new__(
217218
"second",
218219
}
219220

221+
refs = None
222+
if not copy and isinstance(data, (Index, ABCSeries)):
223+
refs = data._references
224+
220225
if not set(fields).issubset(valid_field_set):
221226
argument = list(set(fields) - valid_field_set)[0]
222227
raise TypeError(f"__new__() got an unexpected keyword argument {argument}")
@@ -257,7 +262,7 @@ def __new__(
257262
if copy:
258263
data = data.copy()
259264

260-
return cls._simple_new(data, name=name)
265+
return cls._simple_new(data, name=name, refs=refs)
261266

262267
# ------------------------------------------------------------------------
263268
# Data

pandas/core/indexes/range.py

+1
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def _simple_new( # type: ignore[override]
175175
result._name = name
176176
result._cache = {}
177177
result._reset_identity()
178+
result._references = None
178179
return result
179180

180181
@classmethod

pandas/core/indexes/timedeltas.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
is_scalar,
1818
is_timedelta64_dtype,
1919
)
20+
from pandas.core.dtypes.generic import ABCSeries
2021

2122
from pandas.core.arrays import datetimelike as dtl
2223
from pandas.core.arrays.timedeltas import TimedeltaArray
@@ -168,7 +169,11 @@ def __new__(
168169
tdarr = TimedeltaArray._from_sequence_not_strict(
169170
data, freq=freq, unit=unit, dtype=dtype, copy=copy
170171
)
171-
return cls._simple_new(tdarr, name=name)
172+
refs = None
173+
if not copy and isinstance(data, (ABCSeries, Index)):
174+
refs = data._references
175+
176+
return cls._simple_new(tdarr, name=name, refs=refs)
172177

173178
# -------------------------------------------------------------------
174179

pandas/tests/copy_view/index/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
3+
from pandas import (
4+
DatetimeIndex,
5+
Series,
6+
Timestamp,
7+
date_range,
8+
)
9+
import pandas._testing as tm
10+
11+
12+
@pytest.mark.parametrize(
13+
"cons",
14+
[
15+
lambda x: DatetimeIndex(x),
16+
lambda x: DatetimeIndex(DatetimeIndex(x)),
17+
],
18+
)
19+
def test_datetimeindex(using_copy_on_write, cons):
20+
dt = date_range("2019-12-31", periods=3, freq="D")
21+
ser = Series(dt)
22+
idx = cons(ser)
23+
expected = idx.copy(deep=True)
24+
ser.iloc[0] = Timestamp("2020-12-31")
25+
if using_copy_on_write:
26+
tm.assert_index_equal(idx, expected)
27+
28+
29+
def test_datetimeindex_tz_convert(using_copy_on_write):
30+
dt = date_range("2019-12-31", periods=3, freq="D", tz="Europe/Berlin")
31+
ser = Series(dt)
32+
idx = DatetimeIndex(ser).tz_convert("US/Eastern")
33+
expected = idx.copy(deep=True)
34+
ser.iloc[0] = Timestamp("2020-12-31", tz="Europe/Berlin")
35+
if using_copy_on_write:
36+
tm.assert_index_equal(idx, expected)
37+
38+
39+
def test_datetimeindex_tz_localize(using_copy_on_write):
40+
dt = date_range("2019-12-31", periods=3, freq="D")
41+
ser = Series(dt)
42+
idx = DatetimeIndex(ser).tz_localize("Europe/Berlin")
43+
expected = idx.copy(deep=True)
44+
ser.iloc[0] = Timestamp("2020-12-31")
45+
if using_copy_on_write:
46+
tm.assert_index_equal(idx, expected)
47+
48+
49+
def test_datetimeindex_isocalendar(using_copy_on_write):
50+
dt = date_range("2019-12-31", periods=3, freq="D")
51+
ser = Series(dt)
52+
df = DatetimeIndex(ser).isocalendar()
53+
expected = df.index.copy(deep=True)
54+
ser.iloc[0] = Timestamp("2020-12-31")
55+
if using_copy_on_write:
56+
tm.assert_index_equal(df.index, expected)

0 commit comments

Comments
 (0)