Skip to content

Ensure TDA.__init__ validates freq #24666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 9, 2019
91 changes: 30 additions & 61 deletions pandas/core/arrays/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from pandas.util._decorators import Appender

from pandas.core.dtypes.common import (
_NS_DTYPE, _TD_DTYPE, ensure_int64, is_datetime64_dtype, is_float_dtype,
is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
_NS_DTYPE, _TD_DTYPE, ensure_int64, is_datetime64_dtype, is_dtype_equal,
is_float_dtype, is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
is_string_dtype, is_timedelta64_dtype, is_timedelta64_ns_dtype,
pandas_dtype)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
Expand Down Expand Up @@ -134,55 +134,39 @@ def dtype(self):
_attributes = ["freq"]

def __init__(self, values, dtype=_TD_DTYPE, freq=None, copy=False):
if isinstance(values, (ABCSeries, ABCIndexClass)):
values = values._values

if isinstance(values, type(self)):
values, freq, freq_infer = extract_values_freq(values, freq)

if not isinstance(values, np.ndarray):
msg = (
if not hasattr(values, "dtype"):
raise ValueError(
"Unexpected type '{}'. 'values' must be a TimedeltaArray "
"ndarray, or Series or Index containing one of those."
)
raise ValueError(msg.format(type(values).__name__))

if values.dtype == 'i8':
# for compat with datetime/timedelta/period shared methods,
# we can sometimes get here with int64 values. These represent
# nanosecond UTC (or tz-naive) unix timestamps
values = values.view(_TD_DTYPE)

if values.dtype != _TD_DTYPE:
raise TypeError(_BAD_DTYPE.format(dtype=values.dtype))

try:
dtype_mismatch = dtype != _TD_DTYPE
except TypeError:
raise TypeError(_BAD_DTYPE.format(dtype=dtype))
else:
if dtype_mismatch:
raise TypeError(_BAD_DTYPE.format(dtype=dtype))

.format(type(values).__name__))
if freq == "infer":
msg = (
raise ValueError(
"Frequency inference not allowed in TimedeltaArray.__init__. "
"Use 'pd.array()' instead."
)
raise ValueError(msg)
"Use 'pd.array()' instead.")

if copy:
values = values.copy()
if freq:
freq = to_offset(freq)
if dtype is not None and not is_dtype_equal(dtype, _TD_DTYPE):
raise TypeError("dtype {dtype} cannot be converted to "
"timedelta64[ns]".format(dtype=dtype))

if values.dtype == 'i8':
values = values.view('timedelta64[ns]')

self._data = values
self._dtype = dtype
self._freq = freq
result = type(self)._from_sequence(values, dtype=dtype,
copy=copy, freq=freq)
self._data = result._data
self._freq = result._freq
self._dtype = result._dtype

@classmethod
def _simple_new(cls, values, freq=None, dtype=_TD_DTYPE):
return cls(values, dtype=dtype, freq=freq)
assert dtype == _TD_DTYPE, dtype
assert isinstance(values, np.ndarray), type(values)

result = object.__new__(cls)
result._data = values.view(_TD_DTYPE)
result._freq = to_offset(freq)
result._dtype = _TD_DTYPE
return result

@classmethod
def _from_sequence(cls, data, dtype=_TD_DTYPE, copy=False,
Expand Down Expand Up @@ -860,17 +844,17 @@ def sequence_to_td64ns(data, copy=False, unit="ns", errors="raise"):
data = data._data

# Convert whatever we have into timedelta64[ns] dtype
if is_object_dtype(data) or is_string_dtype(data):
if is_object_dtype(data.dtype) or is_string_dtype(data.dtype):
# no need to make a copy, need to convert if string-dtyped
data = objects_to_td64ns(data, unit=unit, errors=errors)
copy = False

elif is_integer_dtype(data):
elif is_integer_dtype(data.dtype):
# treat as multiples of the given unit
data, copy_made = ints_to_td64ns(data, unit=unit)
copy = copy and not copy_made

elif is_float_dtype(data):
elif is_float_dtype(data.dtype):
# treat as multiples of the given unit. If after converting to nanos,
# there are fractional components left, these are truncated
# (i.e. NOT rounded)
Expand All @@ -880,7 +864,7 @@ def sequence_to_td64ns(data, copy=False, unit="ns", errors="raise"):
data[mask] = iNaT
copy = False

elif is_timedelta64_dtype(data):
elif is_timedelta64_dtype(data.dtype):
if data.dtype != _TD_DTYPE:
# non-nano unit
# TODO: watch out for overflows
Expand Down Expand Up @@ -998,18 +982,3 @@ def _generate_regular_range(start, end, periods, offset):

data = np.arange(b, e, stride, dtype=np.int64)
return data


def extract_values_freq(arr, freq):
# type: (TimedeltaArray, Offset) -> Tuple[ndarray, Offset, bool]
freq_infer = False
if freq is None:
freq = arr.freq
elif freq and arr.freq:
freq = to_offset(freq)
freq, freq_infer = dtl.validate_inferred_freq(
freq, arr.freq,
freq_infer=False
)
values = arr._data
return values, freq, freq_infer
6 changes: 4 additions & 2 deletions pandas/core/indexes/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,14 @@ def _simple_new(cls, values, name=None, freq=None, dtype=_TD_DTYPE):
if not isinstance(values, TimedeltaArray):
values = TimedeltaArray._simple_new(values, dtype=dtype,
freq=freq)
else:
if freq is None:
freq = values.freq
assert isinstance(values, TimedeltaArray), type(values)
assert dtype == _TD_DTYPE, dtype
assert values.dtype == 'm8[ns]', values.dtype

freq = to_offset(freq)
tdarr = TimedeltaArray._simple_new(values, freq=freq)
tdarr = TimedeltaArray._simple_new(values._data, freq=freq)
result = object.__new__(cls)
result._data = tdarr
result.name = name
Expand Down
11 changes: 10 additions & 1 deletion pandas/tests/arrays/test_timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@


class TestTimedeltaArrayConstructor(object):
def test_freq_validation(self):
# ensure that the public constructor cannot create an invalid instance
arr = np.array([0, 0, 1], dtype=np.int64) * 3600 * 10**9

msg = ("Inferred frequency None from passed values does not "
"conform to passed frequency D")
with pytest.raises(ValueError, match=msg):
TimedeltaArray(arr.view('timedelta64[ns]'), freq="D")

def test_non_array_raises(self):
with pytest.raises(ValueError, match='list'):
TimedeltaArray([1, 2, 3])
Expand All @@ -34,7 +43,7 @@ def test_incorrect_dtype_raises(self):
def test_copy(self):
data = np.array([1, 2, 3], dtype='m8[ns]')
arr = TimedeltaArray(data, copy=False)
assert arr._data is data
assert arr._data.base is data

arr = TimedeltaArray(data, copy=True)
assert arr._data is not data
Expand Down