Skip to content

Commit 74a178a

Browse files
jbrockmendelSeeminSyed
authored andcommitted
REF: make DatetimeIndex._simple_new actually simple (pandas-dev#32282)
1 parent 5d0792c commit 74a178a

File tree

8 files changed

+48
-70
lines changed

8 files changed

+48
-70
lines changed

pandas/_libs/tslibs/offsets.pyx

+12-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,18 @@ def apply_index_wraps(func):
114114
# Note: normally we would use `@functools.wraps(func)`, but this does
115115
# not play nicely with cython class methods
116116
def wrapper(self, other):
117-
result = func(self, other)
117+
118+
is_index = getattr(other, "_typ", "") == "datetimeindex"
119+
120+
# operate on DatetimeArray
121+
arr = other._data if is_index else other
122+
123+
result = func(self, arr)
124+
125+
if is_index:
126+
# Wrap DatetimeArray result back to DatetimeIndex
127+
result = type(other)._simple_new(result, name=other.name)
128+
118129
if self.normalize:
119130
result = result.to_period('D').to_timestamp()
120131
return result

pandas/core/indexes/base.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -3281,13 +3281,11 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
32813281
target = _ensure_has_len(target) # target may be an iterator
32823282

32833283
if not isinstance(target, Index) and len(target) == 0:
3284-
attrs = self._get_attributes_dict()
3285-
attrs.pop("freq", None) # don't preserve freq
32863284
if isinstance(self, ABCRangeIndex):
32873285
values = range(0)
32883286
else:
32893287
values = self._data[:0] # appropriately-dtyped empty array
3290-
target = self._simple_new(values, **attrs)
3288+
target = self._simple_new(values, name=self.name)
32913289
else:
32923290
target = ensure_index(target)
32933291

pandas/core/indexes/datetimelike.py

+9-13
Original file line numberDiff line numberDiff line change
@@ -622,21 +622,11 @@ def _shallow_copy(self, values=None, name: Label = lib.no_default):
622622
if values is None:
623623
values = self._data
624624

625-
if isinstance(values, type(self)):
626-
values = values._data
627625
if isinstance(values, np.ndarray):
628626
# TODO: We would rather not get here
629627
values = type(self._data)(values, dtype=self.dtype)
630628

631-
attributes = self._get_attributes_dict()
632-
633-
if self.freq is not None:
634-
if isinstance(values, (DatetimeArray, TimedeltaArray)):
635-
if values.freq is None:
636-
del attributes["freq"]
637-
638-
attributes["name"] = name
639-
result = self._simple_new(values, **attributes)
629+
result = type(self)._simple_new(values, name=name)
640630
result._cache = cache
641631
return result
642632

@@ -780,7 +770,10 @@ def _fast_union(self, other, sort=None):
780770
loc = right.searchsorted(left_start, side="left")
781771
right_chunk = right.values[:loc]
782772
dates = concat_compat((left.values, right_chunk))
783-
return self._shallow_copy(dates)
773+
result = self._shallow_copy(dates)
774+
result._set_freq("infer")
775+
# TODO: can we infer that it has self.freq?
776+
return result
784777
else:
785778
left, right = other, self
786779

@@ -792,7 +785,10 @@ def _fast_union(self, other, sort=None):
792785
loc = right.searchsorted(left_end, side="right")
793786
right_chunk = right.values[loc:]
794787
dates = concat_compat((left.values, right_chunk))
795-
return self._shallow_copy(dates)
788+
result = self._shallow_copy(dates)
789+
result._set_freq("infer")
790+
# TODO: can we infer that it has self.freq?
791+
return result
796792
else:
797793
return left
798794

pandas/core/indexes/datetimes.py

+20-29
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,13 @@
77

88
from pandas._libs import NaT, Period, Timestamp, index as libindex, lib, tslib as libts
99
from pandas._libs.tslibs import fields, parsing, timezones
10+
from pandas._typing import Label
1011
from pandas.util._decorators import cache_readonly
1112

1213
from pandas.core.dtypes.common import _NS_DTYPE, is_float, is_integer, is_scalar
13-
from pandas.core.dtypes.dtypes import DatetimeTZDtype
1414
from pandas.core.dtypes.missing import is_valid_nat_for_dtype
1515

16-
from pandas.core.arrays.datetimes import (
17-
DatetimeArray,
18-
tz_to_dtype,
19-
validate_tz_from_dtype,
20-
)
16+
from pandas.core.arrays.datetimes import DatetimeArray, tz_to_dtype
2117
import pandas.core.common as com
2218
from pandas.core.indexes.base import Index, InvalidIndexError, maybe_extract_name
2319
from pandas.core.indexes.datetimelike import DatetimeTimedeltaMixin
@@ -36,7 +32,20 @@ def _new_DatetimeIndex(cls, d):
3632
if "data" in d and not isinstance(d["data"], DatetimeIndex):
3733
# Avoid need to verify integrity by calling simple_new directly
3834
data = d.pop("data")
39-
result = cls._simple_new(data, **d)
35+
if not isinstance(data, DatetimeArray):
36+
# For backward compat with older pickles, we may need to construct
37+
# a DatetimeArray to adapt to the newer _simple_new signature
38+
tz = d.pop("tz")
39+
freq = d.pop("freq")
40+
dta = DatetimeArray._simple_new(data, dtype=tz_to_dtype(tz), freq=freq)
41+
else:
42+
dta = data
43+
for key in ["tz", "freq"]:
44+
# These are already stored in our DatetimeArray; if they are
45+
# also in the pickle and don't match, we have a problem.
46+
if key in d:
47+
assert d.pop(key) == getattr(dta, key)
48+
result = cls._simple_new(dta, **d)
4049
else:
4150
with warnings.catch_warnings():
4251
# TODO: If we knew what was going in to **d, we might be able to
@@ -244,34 +253,16 @@ def __new__(
244253
return subarr
245254

246255
@classmethod
247-
def _simple_new(cls, values, name=None, freq=None, tz=None, dtype=None):
248-
"""
249-
We require the we have a dtype compat for the values
250-
if we are passed a non-dtype compat, then coerce using the constructor
251-
"""
252-
if isinstance(values, DatetimeArray):
253-
if tz:
254-
tz = validate_tz_from_dtype(dtype, tz)
255-
dtype = DatetimeTZDtype(tz=tz)
256-
elif dtype is None:
257-
dtype = _NS_DTYPE
258-
259-
values = DatetimeArray(values, freq=freq, dtype=dtype)
260-
tz = values.tz
261-
freq = values.freq
262-
values = values._data
263-
264-
dtype = tz_to_dtype(tz)
265-
dtarr = DatetimeArray._simple_new(values, freq=freq, dtype=dtype)
266-
assert isinstance(dtarr, DatetimeArray)
256+
def _simple_new(cls, values: DatetimeArray, name: Label = None):
257+
assert isinstance(values, DatetimeArray), type(values)
267258

268259
result = object.__new__(cls)
269-
result._data = dtarr
260+
result._data = values
270261
result.name = name
271262
result._cache = {}
272263
result._no_setting_name = False
273264
# For groupby perf. See note in indexes/base about _index_data
274-
result._index_data = dtarr._data
265+
result._index_data = values._data
275266
result._reset_identity()
276267
return result
277268

pandas/core/indexes/timedeltas.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
""" implement the TimedeltaIndex """
22

33
from pandas._libs import NaT, Timedelta, index as libindex
4+
from pandas._typing import Label
45
from pandas.util._decorators import Appender
56

67
from pandas.core.dtypes.common import (
@@ -154,7 +155,7 @@ def __new__(
154155
if isinstance(data, TimedeltaArray) and freq is None:
155156
if copy:
156157
data = data.copy()
157-
return cls._simple_new(data, name=name, freq=freq)
158+
return cls._simple_new(data, name=name)
158159

159160
if isinstance(data, TimedeltaIndex) and freq is None and name is None:
160161
if copy:
@@ -170,12 +171,8 @@ def __new__(
170171
return cls._simple_new(tdarr, name=name)
171172

172173
@classmethod
173-
def _simple_new(cls, values, name=None, freq=None, dtype=_TD_DTYPE):
174-
# `dtype` is passed by _shallow_copy in corner cases, should always
175-
# be timedelta64[ns] if present
176-
assert dtype == _TD_DTYPE, dtype
174+
def _simple_new(cls, values: TimedeltaArray, name: Label = None):
177175
assert isinstance(values, TimedeltaArray)
178-
assert freq is None or values.freq == freq
179176

180177
result = object.__new__(cls)
181178
result._data = values

pandas/tests/arrays/test_datetimelike.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def test_compare_len1_raises(self):
6565
# to the case where one has length-1, which numpy would broadcast
6666
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
6767

68-
idx = self.array_cls._simple_new(data, freq="D")
69-
arr = self.index_cls(idx)
68+
arr = self.array_cls._simple_new(data, freq="D")
69+
idx = self.index_cls(arr)
7070

7171
with pytest.raises(ValueError, match="Lengths must match"):
7272
arr == arr[:1]

pandas/tests/indexes/datetimes/test_ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def test_equals(self):
363363
assert not idx.equals(pd.Series(idx2))
364364

365365
# same internal, different tz
366-
idx3 = pd.DatetimeIndex._simple_new(idx.asi8, tz="US/Pacific")
366+
idx3 = pd.DatetimeIndex(idx.asi8, tz="US/Pacific")
367367
tm.assert_numpy_array_equal(idx.asi8, idx3.asi8)
368368
assert not idx.equals(idx3)
369369
assert not idx.equals(idx3.copy())

pandas/tseries/offsets.py

-15
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,6 @@ def apply_index(self, i):
337337
# integer addition on PeriodIndex is deprecated,
338338
# so we directly use _time_shift instead
339339
asper = i.to_period("W")
340-
if not isinstance(asper._data, np.ndarray):
341-
# unwrap PeriodIndex --> PeriodArray
342-
asper = asper._data
343340
shifted = asper._time_shift(weeks)
344341
i = shifted.to_timestamp() + i.to_perioddelta("W")
345342

@@ -629,9 +626,6 @@ def apply_index(self, i):
629626
# to_period rolls forward to next BDay; track and
630627
# reduce n where it does when rolling forward
631628
asper = i.to_period("B")
632-
if not isinstance(asper._data, np.ndarray):
633-
# unwrap PeriodIndex --> PeriodArray
634-
asper = asper._data
635629

636630
if self.n > 0:
637631
shifted = (i.to_perioddelta("B") - time).asi8 != 0
@@ -1384,9 +1378,6 @@ def apply_index(self, i):
13841378
# integer-array addition on PeriodIndex is deprecated,
13851379
# so we use _addsub_int_array directly
13861380
asper = i.to_period("M")
1387-
if not isinstance(asper._data, np.ndarray):
1388-
# unwrap PeriodIndex --> PeriodArray
1389-
asper = asper._data
13901381

13911382
shifted = asper._addsub_int_array(roll // 2, operator.add)
13921383
i = type(dti)(shifted.to_timestamp())
@@ -1582,9 +1573,6 @@ def apply_index(self, i):
15821573
# integer addition on PeriodIndex is deprecated,
15831574
# so we use _time_shift directly
15841575
asper = i.to_period("W")
1585-
if not isinstance(asper._data, np.ndarray):
1586-
# unwrap PeriodIndex --> PeriodArray
1587-
asper = asper._data
15881576

15891577
shifted = asper._time_shift(self.n)
15901578
return shifted.to_timestamp() + i.to_perioddelta("W")
@@ -1608,9 +1596,6 @@ def _end_apply_index(self, dtindex):
16081596

16091597
base, mult = libfrequencies.get_freq_code(self.freqstr)
16101598
base_period = dtindex.to_period(base)
1611-
if not isinstance(base_period._data, np.ndarray):
1612-
# unwrap PeriodIndex --> PeriodArray
1613-
base_period = base_period._data
16141599

16151600
if self.n > 0:
16161601
# when adding, dates on end roll to next

0 commit comments

Comments
 (0)