From 5edf59a2aa6cce0b450b2830eb360b0c49152931 Mon Sep 17 00:00:00 2001 From: Brock Mendel Date: Wed, 31 Oct 2018 12:36:49 -0700 Subject: [PATCH] strictness and checks for Timedelta _simple_new --- pandas/core/arrays/timedeltas.py | 30 ++++++++++++++++-------------- pandas/core/indexes/base.py | 6 ++++-- pandas/core/indexes/timedeltas.py | 29 +++++++++++++++++++++++++++-- 3 files changed, 47 insertions(+), 18 deletions(-) diff --git a/pandas/core/arrays/timedeltas.py b/pandas/core/arrays/timedeltas.py index 397297c1b88d0..9653121879c0d 100644 --- a/pandas/core/arrays/timedeltas.py +++ b/pandas/core/arrays/timedeltas.py @@ -11,7 +11,7 @@ from pandas import compat from pandas.core.dtypes.common import ( - _TD_DTYPE, ensure_int64, is_timedelta64_dtype, is_list_like) + _TD_DTYPE, is_list_like) from pandas.core.dtypes.generic import ABCSeries from pandas.core.dtypes.missing import isna @@ -111,16 +111,16 @@ def dtype(self): _attributes = ["freq"] @classmethod - def _simple_new(cls, values, freq=None, **kwargs): - values = np.array(values, copy=False) - if values.dtype == np.object_: - values = array_to_timedelta64(values) - if values.dtype != _TD_DTYPE: - if is_timedelta64_dtype(values): - # non-nano unit - values = values.astype(_TD_DTYPE) - else: - values = ensure_int64(values).view(_TD_DTYPE) + def _simple_new(cls, values, freq=None, dtype=_TD_DTYPE): + # `dtype` is passed by _shallow_copy in corner cases, should always + # be timedelta64[ns] if present + assert dtype == _TD_DTYPE + assert isinstance(values, np.ndarray), type(values) + + if values.dtype == 'i8': + values = values.view('m8[ns]') + + assert values.dtype == 'm8[ns]' result = object.__new__(cls) result._data = values @@ -131,6 +131,10 @@ def __new__(cls, values, freq=None): freq, freq_infer = dtl.maybe_infer_freq(freq) + values = np.array(values, copy=False) + if values.dtype == np.object_: + values = array_to_timedelta64(values) + result = cls._simple_new(values, freq=freq) if freq_infer: inferred = result.inferred_freq @@ -166,17 +170,15 @@ def _generate_range(cls, start, end, periods, freq, closed=None): if freq is not None: index = _generate_regular_range(start, end, periods, freq) - index = cls._simple_new(index, freq=freq) else: index = np.linspace(start.value, end.value, periods).astype('i8') - index = cls._simple_new(index, freq=freq) if not left_closed: index = index[1:] if not right_closed: index = index[:-1] - return index + return cls._simple_new(index, freq=freq) # ---------------------------------------------------------------- # Arithmetic Methods diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 1ffdac1989129..a8a298a231de4 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2947,7 +2947,8 @@ def difference(self, other): self._assert_can_do_setop(other) if self.equals(other): - return self._shallow_copy([]) + # pass an empty np.ndarray with the appropriate dtype + return self._shallow_copy(self._data[:0]) other, result_name = self._convert_can_do_setop(other) @@ -3715,7 +3716,8 @@ def reindex(self, target, method=None, level=None, limit=None, if not isinstance(target, Index) and len(target) == 0: attrs = self._get_attributes_dict() attrs.pop('freq', None) # don't preserve freq - target = self._simple_new(None, dtype=self.dtype, **attrs) + values = self._data[:0] # empty array with appropriate dtype + target = self._simple_new(values, dtype=self.dtype, **attrs) else: target = ensure_index(target) diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py index e5da21478d0a4..22ecefae8cbe2 100644 --- a/pandas/core/indexes/timedeltas.py +++ b/pandas/core/indexes/timedeltas.py @@ -35,6 +35,7 @@ to_timedelta, _coerce_scalar_to_timedelta_type) from pandas._libs import (lib, index as libindex, join as libjoin, Timedelta, NaT) +from pandas._libs.tslibs.timedeltas import array_to_timedelta64 class TimedeltaIndex(TimedeltaArrayMixin, DatetimeIndexOpsMixin, @@ -166,6 +167,19 @@ def __new__(cls, data=None, unit=None, freq=None, start=None, end=None, elif copy: data = np.array(data, copy=True) + data = np.array(data, copy=False) + if data.dtype == np.object_: + data = array_to_timedelta64(data) + if data.dtype != _TD_DTYPE: + if is_timedelta64_dtype(data): + # non-nano unit + # TODO: watch out for overflows + data = data.astype(_TD_DTYPE) + else: + data = ensure_int64(data).view(_TD_DTYPE) + + assert data.dtype == 'm8[ns]', data.dtype + subarr = cls._simple_new(data, name=name, freq=freq) # check that we are matching freqs if verify_integrity and len(subarr) > 0: @@ -180,12 +194,23 @@ def __new__(cls, data=None, unit=None, freq=None, start=None, end=None, return subarr @classmethod - def _simple_new(cls, values, name=None, freq=None, **kwargs): - result = super(TimedeltaIndex, cls)._simple_new(values, freq, **kwargs) + def _simple_new(cls, values, name=None, freq=None, dtype=_TD_DTYPE): + # `dtype` is passed by _shallow_copy in corner cases, should always + # be timedelta64[ns] if present + assert dtype == _TD_DTYPE + + assert isinstance(values, np.ndarray), type(values) + if values.dtype == 'i8': + values = values.view('m8[ns]') + assert values.dtype == 'm8[ns]', values.dtype + + result = super(TimedeltaIndex, cls)._simple_new(values, freq) result.name = name result._reset_identity() return result + _shallow_copy = Index._shallow_copy + @property def _formatter_func(self): from pandas.io.formats.format import _get_format_timedelta64