Skip to content

Commit decc8ce

Browse files
jbrockmendeljreback
authored andcommitted
Ensure TDA.__init__ validates freq (#24666)
1 parent 46a31c9 commit decc8ce

File tree

3 files changed

+44
-64
lines changed

3 files changed

+44
-64
lines changed

pandas/core/arrays/timedeltas.py

+30-61
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from pandas.util._decorators import Appender
1616

1717
from pandas.core.dtypes.common import (
18-
_NS_DTYPE, _TD_DTYPE, ensure_int64, is_datetime64_dtype, is_float_dtype,
19-
is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
18+
_NS_DTYPE, _TD_DTYPE, ensure_int64, is_datetime64_dtype, is_dtype_equal,
19+
is_float_dtype, is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
2020
is_string_dtype, is_timedelta64_dtype, is_timedelta64_ns_dtype,
2121
pandas_dtype)
2222
from pandas.core.dtypes.dtypes import DatetimeTZDtype
@@ -134,55 +134,39 @@ def dtype(self):
134134
_attributes = ["freq"]
135135

136136
def __init__(self, values, dtype=_TD_DTYPE, freq=None, copy=False):
137-
if isinstance(values, (ABCSeries, ABCIndexClass)):
138-
values = values._values
139-
140-
if isinstance(values, type(self)):
141-
values, freq, freq_infer = extract_values_freq(values, freq)
142-
143-
if not isinstance(values, np.ndarray):
144-
msg = (
137+
if not hasattr(values, "dtype"):
138+
raise ValueError(
145139
"Unexpected type '{}'. 'values' must be a TimedeltaArray "
146140
"ndarray, or Series or Index containing one of those."
147-
)
148-
raise ValueError(msg.format(type(values).__name__))
149-
150-
if values.dtype == 'i8':
151-
# for compat with datetime/timedelta/period shared methods,
152-
# we can sometimes get here with int64 values. These represent
153-
# nanosecond UTC (or tz-naive) unix timestamps
154-
values = values.view(_TD_DTYPE)
155-
156-
if values.dtype != _TD_DTYPE:
157-
raise TypeError(_BAD_DTYPE.format(dtype=values.dtype))
158-
159-
try:
160-
dtype_mismatch = dtype != _TD_DTYPE
161-
except TypeError:
162-
raise TypeError(_BAD_DTYPE.format(dtype=dtype))
163-
else:
164-
if dtype_mismatch:
165-
raise TypeError(_BAD_DTYPE.format(dtype=dtype))
166-
141+
.format(type(values).__name__))
167142
if freq == "infer":
168-
msg = (
143+
raise ValueError(
169144
"Frequency inference not allowed in TimedeltaArray.__init__. "
170-
"Use 'pd.array()' instead."
171-
)
172-
raise ValueError(msg)
145+
"Use 'pd.array()' instead.")
173146

174-
if copy:
175-
values = values.copy()
176-
if freq:
177-
freq = to_offset(freq)
147+
if dtype is not None and not is_dtype_equal(dtype, _TD_DTYPE):
148+
raise TypeError("dtype {dtype} cannot be converted to "
149+
"timedelta64[ns]".format(dtype=dtype))
150+
151+
if values.dtype == 'i8':
152+
values = values.view('timedelta64[ns]')
178153

179-
self._data = values
180-
self._dtype = dtype
181-
self._freq = freq
154+
result = type(self)._from_sequence(values, dtype=dtype,
155+
copy=copy, freq=freq)
156+
self._data = result._data
157+
self._freq = result._freq
158+
self._dtype = result._dtype
182159

183160
@classmethod
184161
def _simple_new(cls, values, freq=None, dtype=_TD_DTYPE):
185-
return cls(values, dtype=dtype, freq=freq)
162+
assert dtype == _TD_DTYPE, dtype
163+
assert isinstance(values, np.ndarray), type(values)
164+
165+
result = object.__new__(cls)
166+
result._data = values.view(_TD_DTYPE)
167+
result._freq = to_offset(freq)
168+
result._dtype = _TD_DTYPE
169+
return result
186170

187171
@classmethod
188172
def _from_sequence(cls, data, dtype=_TD_DTYPE, copy=False,
@@ -860,17 +844,17 @@ def sequence_to_td64ns(data, copy=False, unit="ns", errors="raise"):
860844
data = data._data
861845

862846
# Convert whatever we have into timedelta64[ns] dtype
863-
if is_object_dtype(data) or is_string_dtype(data):
847+
if is_object_dtype(data.dtype) or is_string_dtype(data.dtype):
864848
# no need to make a copy, need to convert if string-dtyped
865849
data = objects_to_td64ns(data, unit=unit, errors=errors)
866850
copy = False
867851

868-
elif is_integer_dtype(data):
852+
elif is_integer_dtype(data.dtype):
869853
# treat as multiples of the given unit
870854
data, copy_made = ints_to_td64ns(data, unit=unit)
871855
copy = copy and not copy_made
872856

873-
elif is_float_dtype(data):
857+
elif is_float_dtype(data.dtype):
874858
# treat as multiples of the given unit. If after converting to nanos,
875859
# there are fractional components left, these are truncated
876860
# (i.e. NOT rounded)
@@ -880,7 +864,7 @@ def sequence_to_td64ns(data, copy=False, unit="ns", errors="raise"):
880864
data[mask] = iNaT
881865
copy = False
882866

883-
elif is_timedelta64_dtype(data):
867+
elif is_timedelta64_dtype(data.dtype):
884868
if data.dtype != _TD_DTYPE:
885869
# non-nano unit
886870
# TODO: watch out for overflows
@@ -998,18 +982,3 @@ def _generate_regular_range(start, end, periods, offset):
998982

999983
data = np.arange(b, e, stride, dtype=np.int64)
1000984
return data
1001-
1002-
1003-
def extract_values_freq(arr, freq):
1004-
# type: (TimedeltaArray, Offset) -> Tuple[ndarray, Offset, bool]
1005-
freq_infer = False
1006-
if freq is None:
1007-
freq = arr.freq
1008-
elif freq and arr.freq:
1009-
freq = to_offset(freq)
1010-
freq, freq_infer = dtl.validate_inferred_freq(
1011-
freq, arr.freq,
1012-
freq_infer=False
1013-
)
1014-
values = arr._data
1015-
return values, freq, freq_infer

pandas/core/indexes/timedeltas.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -233,12 +233,14 @@ def _simple_new(cls, values, name=None, freq=None, dtype=_TD_DTYPE):
233233
if not isinstance(values, TimedeltaArray):
234234
values = TimedeltaArray._simple_new(values, dtype=dtype,
235235
freq=freq)
236+
else:
237+
if freq is None:
238+
freq = values.freq
236239
assert isinstance(values, TimedeltaArray), type(values)
237240
assert dtype == _TD_DTYPE, dtype
238241
assert values.dtype == 'm8[ns]', values.dtype
239242

240-
freq = to_offset(freq)
241-
tdarr = TimedeltaArray._simple_new(values, freq=freq)
243+
tdarr = TimedeltaArray._simple_new(values._data, freq=freq)
242244
result = object.__new__(cls)
243245
result._data = tdarr
244246
result.name = name

pandas/tests/arrays/test_timedeltas.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@
99

1010

1111
class TestTimedeltaArrayConstructor(object):
12+
def test_freq_validation(self):
13+
# ensure that the public constructor cannot create an invalid instance
14+
arr = np.array([0, 0, 1], dtype=np.int64) * 3600 * 10**9
15+
16+
msg = ("Inferred frequency None from passed values does not "
17+
"conform to passed frequency D")
18+
with pytest.raises(ValueError, match=msg):
19+
TimedeltaArray(arr.view('timedelta64[ns]'), freq="D")
20+
1221
def test_non_array_raises(self):
1322
with pytest.raises(ValueError, match='list'):
1423
TimedeltaArray([1, 2, 3])
@@ -34,7 +43,7 @@ def test_incorrect_dtype_raises(self):
3443
def test_copy(self):
3544
data = np.array([1, 2, 3], dtype='m8[ns]')
3645
arr = TimedeltaArray(data, copy=False)
37-
assert arr._data is data
46+
assert arr._data.base is data
3847

3948
arr = TimedeltaArray(data, copy=True)
4049
assert arr._data is not data

0 commit comments

Comments
 (0)