diff --git a/pandas/core/arrays/timedeltas.py b/pandas/core/arrays/timedeltas.py index 1ec37c9f228a6..47b3f93f88b78 100644 --- a/pandas/core/arrays/timedeltas.py +++ b/pandas/core/arrays/timedeltas.py @@ -15,8 +15,8 @@ from pandas.util._decorators import Appender from pandas.core.dtypes.common import ( - _NS_DTYPE, _TD_DTYPE, ensure_int64, is_datetime64_dtype, is_float_dtype, - is_integer_dtype, is_list_like, is_object_dtype, is_scalar, + _NS_DTYPE, _TD_DTYPE, ensure_int64, is_datetime64_dtype, is_dtype_equal, + is_float_dtype, is_integer_dtype, is_list_like, is_object_dtype, is_scalar, is_string_dtype, is_timedelta64_dtype, is_timedelta64_ns_dtype, pandas_dtype) from pandas.core.dtypes.dtypes import DatetimeTZDtype @@ -134,55 +134,39 @@ def dtype(self): _attributes = ["freq"] def __init__(self, values, dtype=_TD_DTYPE, freq=None, copy=False): - if isinstance(values, (ABCSeries, ABCIndexClass)): - values = values._values - - if isinstance(values, type(self)): - values, freq, freq_infer = extract_values_freq(values, freq) - - if not isinstance(values, np.ndarray): - msg = ( + if not hasattr(values, "dtype"): + raise ValueError( "Unexpected type '{}'. 'values' must be a TimedeltaArray " "ndarray, or Series or Index containing one of those." - ) - raise ValueError(msg.format(type(values).__name__)) - - if values.dtype == 'i8': - # for compat with datetime/timedelta/period shared methods, - # we can sometimes get here with int64 values. These represent - # nanosecond UTC (or tz-naive) unix timestamps - values = values.view(_TD_DTYPE) - - if values.dtype != _TD_DTYPE: - raise TypeError(_BAD_DTYPE.format(dtype=values.dtype)) - - try: - dtype_mismatch = dtype != _TD_DTYPE - except TypeError: - raise TypeError(_BAD_DTYPE.format(dtype=dtype)) - else: - if dtype_mismatch: - raise TypeError(_BAD_DTYPE.format(dtype=dtype)) - + .format(type(values).__name__)) if freq == "infer": - msg = ( + raise ValueError( "Frequency inference not allowed in TimedeltaArray.__init__. " - "Use 'pd.array()' instead." - ) - raise ValueError(msg) + "Use 'pd.array()' instead.") - if copy: - values = values.copy() - if freq: - freq = to_offset(freq) + if dtype is not None and not is_dtype_equal(dtype, _TD_DTYPE): + raise TypeError("dtype {dtype} cannot be converted to " + "timedelta64[ns]".format(dtype=dtype)) + + if values.dtype == 'i8': + values = values.view('timedelta64[ns]') - self._data = values - self._dtype = dtype - self._freq = freq + result = type(self)._from_sequence(values, dtype=dtype, + copy=copy, freq=freq) + self._data = result._data + self._freq = result._freq + self._dtype = result._dtype @classmethod def _simple_new(cls, values, freq=None, dtype=_TD_DTYPE): - return cls(values, dtype=dtype, freq=freq) + assert dtype == _TD_DTYPE, dtype + assert isinstance(values, np.ndarray), type(values) + + result = object.__new__(cls) + result._data = values.view(_TD_DTYPE) + result._freq = to_offset(freq) + result._dtype = _TD_DTYPE + return result @classmethod def _from_sequence(cls, data, dtype=_TD_DTYPE, copy=False, @@ -860,17 +844,17 @@ def sequence_to_td64ns(data, copy=False, unit="ns", errors="raise"): data = data._data # Convert whatever we have into timedelta64[ns] dtype - if is_object_dtype(data) or is_string_dtype(data): + if is_object_dtype(data.dtype) or is_string_dtype(data.dtype): # no need to make a copy, need to convert if string-dtyped data = objects_to_td64ns(data, unit=unit, errors=errors) copy = False - elif is_integer_dtype(data): + elif is_integer_dtype(data.dtype): # treat as multiples of the given unit data, copy_made = ints_to_td64ns(data, unit=unit) copy = copy and not copy_made - elif is_float_dtype(data): + elif is_float_dtype(data.dtype): # treat as multiples of the given unit. If after converting to nanos, # there are fractional components left, these are truncated # (i.e. NOT rounded) @@ -880,7 +864,7 @@ def sequence_to_td64ns(data, copy=False, unit="ns", errors="raise"): data[mask] = iNaT copy = False - elif is_timedelta64_dtype(data): + elif is_timedelta64_dtype(data.dtype): if data.dtype != _TD_DTYPE: # non-nano unit # TODO: watch out for overflows @@ -998,18 +982,3 @@ def _generate_regular_range(start, end, periods, offset): data = np.arange(b, e, stride, dtype=np.int64) return data - - -def extract_values_freq(arr, freq): - # type: (TimedeltaArray, Offset) -> Tuple[ndarray, Offset, bool] - freq_infer = False - if freq is None: - freq = arr.freq - elif freq and arr.freq: - freq = to_offset(freq) - freq, freq_infer = dtl.validate_inferred_freq( - freq, arr.freq, - freq_infer=False - ) - values = arr._data - return values, freq, freq_infer diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py index b9d6b8da2cada..893926cc076ab 100644 --- a/pandas/core/indexes/timedeltas.py +++ b/pandas/core/indexes/timedeltas.py @@ -233,12 +233,14 @@ def _simple_new(cls, values, name=None, freq=None, dtype=_TD_DTYPE): if not isinstance(values, TimedeltaArray): values = TimedeltaArray._simple_new(values, dtype=dtype, freq=freq) + else: + if freq is None: + freq = values.freq assert isinstance(values, TimedeltaArray), type(values) assert dtype == _TD_DTYPE, dtype assert values.dtype == 'm8[ns]', values.dtype - freq = to_offset(freq) - tdarr = TimedeltaArray._simple_new(values, freq=freq) + tdarr = TimedeltaArray._simple_new(values._data, freq=freq) result = object.__new__(cls) result._data = tdarr result.name = name diff --git a/pandas/tests/arrays/test_timedeltas.py b/pandas/tests/arrays/test_timedeltas.py index 481350640e1a6..af23b2467fcdf 100644 --- a/pandas/tests/arrays/test_timedeltas.py +++ b/pandas/tests/arrays/test_timedeltas.py @@ -9,6 +9,15 @@ class TestTimedeltaArrayConstructor(object): + def test_freq_validation(self): + # ensure that the public constructor cannot create an invalid instance + arr = np.array([0, 0, 1], dtype=np.int64) * 3600 * 10**9 + + msg = ("Inferred frequency None from passed values does not " + "conform to passed frequency D") + with pytest.raises(ValueError, match=msg): + TimedeltaArray(arr.view('timedelta64[ns]'), freq="D") + def test_non_array_raises(self): with pytest.raises(ValueError, match='list'): TimedeltaArray([1, 2, 3]) @@ -34,7 +43,7 @@ def test_incorrect_dtype_raises(self): def test_copy(self): data = np.array([1, 2, 3], dtype='m8[ns]') arr = TimedeltaArray(data, copy=False) - assert arr._data is data + assert arr._data.base is data arr = TimedeltaArray(data, copy=True) assert arr._data is not data