Skip to content

Commit d5bf954

Browse files
jbrockmendel authored and Pingviinituutti committed
strictness and checks for Timedelta _simple_new (pandas-dev#23433)
1 parent ebf92b8 commit d5bf954

File tree

2 files changed

+43
-16
lines changed

2 files changed

+43
-16
lines changed

pandas/core/arrays/timedeltas.py

+16 -14
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from pandas import compat
1212

1313
from pandas.core.dtypes.common import (
14-
_TD_DTYPE, ensure_int64, is_timedelta64_dtype, is_list_like)
14+
_TD_DTYPE, is_list_like)
1515
from pandas.core.dtypes.generic import ABCSeries
1616
from pandas.core.dtypes.missing import isna
1717

@@ -111,16 +111,16 @@ def dtype(self):
111111
_attributes = ["freq"]
112112

113113
@classmethod
114-
def _simple_new(cls, values, freq=None, **kwargs):
115-
values = np.array(values, copy=False)
116-
if values.dtype == np.object_:
117-
values = array_to_timedelta64(values)
118-
if values.dtype != _TD_DTYPE:
119-
if is_timedelta64_dtype(values):
120-
# non-nano unit
121-
values = values.astype(_TD_DTYPE)
122-
else:
123-
values = ensure_int64(values).view(_TD_DTYPE)
114+
def _simple_new(cls, values, freq=None, dtype=_TD_DTYPE):
115+
# `dtype` is passed by _shallow_copy in corner cases, should always
116+
# be timedelta64[ns] if present
117+
assert dtype == _TD_DTYPE
118+
assert isinstance(values, np.ndarray), type(values)
119+
120+
if values.dtype == 'i8':
121+
values = values.view('m8[ns]')
122+
123+
assert values.dtype == 'm8[ns]'
124124

125125
result = object.__new__(cls)
126126
result._data = values
@@ -131,6 +131,10 @@ def __new__(cls, values, freq=None):
131131

132132
freq, freq_infer = dtl.maybe_infer_freq(freq)
133133

134+
values = np.array(values, copy=False)
135+
if values.dtype == np.object_:
136+
values = array_to_timedelta64(values)
137+
134138
result = cls._simple_new(values, freq=freq)
135139
if freq_infer:
136140
inferred = result.inferred_freq
@@ -166,17 +170,15 @@ def _generate_range(cls, start, end, periods, freq, closed=None):
166170

167171
if freq is not None:
168172
index = _generate_regular_range(start, end, periods, freq)
169-
index = cls._simple_new(index, freq=freq)
170173
else:
171174
index = np.linspace(start.value, end.value, periods).astype('i8')
172-
index = cls._simple_new(index, freq=freq)
173175

174176
if not left_closed:
175177
index = index[1:]
176178
if not right_closed:
177179
index = index[:-1]
178180

179-
return index
181+
return cls._simple_new(index, freq=freq)
180182

181183
# ----------------------------------------------------------------
182184
# Arithmetic Methods

pandas/core/indexes/timedeltas.py

+27 -2
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
to_timedelta, _coerce_scalar_to_timedelta_type)
3636
from pandas._libs import (lib, index as libindex,
3737
join as libjoin, Timedelta, NaT)
38+
from pandas._libs.tslibs.timedeltas import array_to_timedelta64
3839

3940

4041
class TimedeltaIndex(TimedeltaArrayMixin, DatetimeIndexOpsMixin,
@@ -166,6 +167,19 @@ def __new__(cls, data=None, unit=None, freq=None, start=None, end=None,
166167
elif copy:
167168
data = np.array(data, copy=True)
168169

170+
data = np.array(data, copy=False)
171+
if data.dtype == np.object_:
172+
data = array_to_timedelta64(data)
173+
if data.dtype != _TD_DTYPE:
174+
if is_timedelta64_dtype(data):
175+
# non-nano unit
176+
# TODO: watch out for overflows
177+
data = data.astype(_TD_DTYPE)
178+
else:
179+
data = ensure_int64(data).view(_TD_DTYPE)
180+
181+
assert data.dtype == 'm8[ns]', data.dtype
182+
169183
subarr = cls._simple_new(data, name=name, freq=freq)
170184
# check that we are matching freqs
171185
if verify_integrity and len(subarr) > 0:
@@ -180,12 +194,23 @@ def __new__(cls, data=None, unit=None, freq=None, start=None, end=None,
180194
return subarr
181195

182196
@classmethod
183-
def _simple_new(cls, values, name=None, freq=None, **kwargs):
184-
result = super(TimedeltaIndex, cls)._simple_new(values, freq, **kwargs)
197+
def _simple_new(cls, values, name=None, freq=None, dtype=_TD_DTYPE):
198+
# `dtype` is passed by _shallow_copy in corner cases, should always
199+
# be timedelta64[ns] if present
200+
assert dtype == _TD_DTYPE
201+
202+
assert isinstance(values, np.ndarray), type(values)
203+
if values.dtype == 'i8':
204+
values = values.view('m8[ns]')
205+
assert values.dtype == 'm8[ns]', values.dtype
206+
207+
result = super(TimedeltaIndex, cls)._simple_new(values, freq)
185208
result.name = name
186209
result._reset_identity()
187210
return result
188211

212+
_shallow_copy = Index._shallow_copy
213+
189214
@property
190215
def _formatter_func(self):
191216
from pandas.io.formats.format import _get_format_timedelta64

0 commit comments

Comments
 (0)