Skip to content

REF: stricter checks in _simple_new, avoid shallow_copy in EAs #23426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,9 @@ def _add_nat(self):
# and datetime dtypes
result = np.zeros(len(self), dtype=np.int64)
result.fill(iNaT)
return self._shallow_copy(result, freq=None)
if is_timedelta64_dtype(self):
return type(self)(result, freq=None)
return type(self)(result, tz=self.tz, freq=None)

def _sub_nat(self):
"""Subtract pd.NaT from self"""
Expand Down
83 changes: 47 additions & 36 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,11 @@ def _simple_new(cls, values, freq=None, tz=None, **kwargs):
we require the we have a dtype compat for the values
if we are passed a non-dtype compat, then coerce using the constructor
"""
assert isinstance(values, np.ndarray), type(values)
if values.dtype == 'i8':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a comment here about what this is doing

values = values.view('M8[ns]')

if getattr(values, 'dtype', None) is None:
# empty, but with dtype compat
if values is None:
values = np.empty(0, dtype=_NS_DTYPE)
return cls(values, freq=freq, tz=tz, **kwargs)
values = np.array(values, copy=False)

if not is_datetime64_dtype(values):
values = ensure_int64(values).view(_NS_DTYPE)
assert values.dtype == 'M8[ns]', values.dtype

result = object.__new__(cls)
result._data = values
Expand All @@ -209,6 +204,15 @@ def __new__(cls, values, freq=None, tz=None, dtype=None):
# if dtype has an embedded tz, capture it
tz = dtl.validate_tz_from_dtype(dtype, tz)

if isinstance(values, DatetimeArrayMixin):
values = values.asi8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why getting the integers here and not M8 directly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the way the constructors view i8 values is more consistent. For tz-naive, M8 and i8 are equivalent. For tz-aware, i8 values are interpreted as unix timestamps (i.e. UTC), whereas M8 values are interpreted as wall times in the given timezone.

if values.dtype == 'i8':
values = values.view('M8[ns]')

assert isinstance(values, np.ndarray), type(values)
assert is_datetime64_dtype(values)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

above you check explicitly for 'M8[ns]', because here it can still be another resolution?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct

values = conversion.ensure_datetime64ns(values, copy=False)

result = cls._simple_new(values, freq=freq, tz=tz)
if freq_infer:
inferred = result.inferred_freq
Expand Down Expand Up @@ -253,28 +257,22 @@ def _generate_range(cls, start, end, periods, freq, tz=None,

if tz is not None:
# Localize the start and end arguments
start = _maybe_localize_point(
start, getattr(start, 'tz', None), start, freq, tz
)
end = _maybe_localize_point(
end, getattr(end, 'tz', None), end, freq, tz
)
start = _maybe_localize_point(start, getattr(start, 'tz', None),
start, freq, tz)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, can you leave such style changes only to lines you actually change anyway?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah and the former is actually more idiomatic, really prefer not to do partial line wrapping

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

end = _maybe_localize_point(end, getattr(end, 'tz', None),
end, freq, tz)
if start and end:
# Make sure start and end have the same tz
start = _maybe_localize_point(
start, start.tz, end.tz, freq, tz
)
end = _maybe_localize_point(
end, end.tz, start.tz, freq, tz
)
start = _maybe_localize_point(start, start.tz, end.tz, freq, tz)
end = _maybe_localize_point(end, end.tz, start.tz, freq, tz)

if freq is not None:
# TODO: consider re-implementing _cached_range; GH#17914
index = _generate_regular_range(cls, start, end, periods, freq)

if tz is not None and getattr(index, 'tz', None) is None:
arr = conversion.tz_localize_to_utc(
ensure_int64(index.values),
tz, ambiguous=ambiguous)
if tz is not None and index.tz is None:
arr = conversion.tz_localize_to_utc(ensure_int64(index.values),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wrap on the ensure_int64

tz, ambiguous=ambiguous)

index = cls(arr)

Expand All @@ -287,9 +285,8 @@ def _generate_range(cls, start, end, periods, freq, tz=None,
else:
# Create a linearly spaced date_range in local time
arr = np.linspace(start.value, end.value, periods)
index = cls._simple_new(
arr.astype('M8[ns]', copy=False), freq=None, tz=tz
)
arr = arr.astype('M8[ns]', copy=False)
index = cls._simple_new(arr, freq=None, tz=tz)

if not left_closed and len(index) and index[0] == start:
index = index[1:]
Expand Down Expand Up @@ -586,7 +583,7 @@ def tz_convert(self, tz):
'tz_localize to localize')

# No conversion since timestamps are all UTC to begin with
return self._shallow_copy(tz=tz)
return self._simple_new(self.asi8, tz=tz, freq=self.freq)

def tz_localize(self, tz, ambiguous='raise', nonexistent='raise',
errors=None):
Expand Down Expand Up @@ -708,7 +705,7 @@ def tz_localize(self, tz, ambiguous='raise', nonexistent='raise',
self.asi8, tz, ambiguous=ambiguous, nonexistent=nonexistent,
)
new_dates = new_dates.view(_NS_DTYPE)
return self._shallow_copy(new_dates, tz=tz)
return self._simple_new(new_dates, tz=tz, freq=self.freq)

# ----------------------------------------------------------------
# Conversion Methods - Vectorized analogues of Timestamp methods
Expand Down Expand Up @@ -843,7 +840,8 @@ def to_perioddelta(self, freq):
# TODO: consider privatizing (discussion in GH#23113)
from pandas.core.arrays.timedeltas import TimedeltaArrayMixin
i8delta = self.asi8 - self.to_period(freq).to_timestamp().asi8
return TimedeltaArrayMixin(i8delta)
m8delta = i8delta.view('m8[ns]')
return TimedeltaArrayMixin(m8delta)

# -----------------------------------------------------------------
# Properties - Vectorized Timestamp Properties/Methods
Expand Down Expand Up @@ -1320,6 +1318,20 @@ def to_julian_date(self):


def _generate_regular_range(cls, start, end, periods, freq):
"""

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function summary / description?

Parameters
----------
cls : class
start : Timestamp or None
end : Timestamp or None
periods : int
freq : DateOffset
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parameter descriptions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do


Returns
-------
ndarray[np.int64] representing nanosecond unix timestamps
"""
if isinstance(freq, Tick):
stride = freq.nanos
if periods is None:
Expand All @@ -1343,21 +1355,20 @@ def _generate_regular_range(cls, start, end, periods, freq):
"if a 'period' is given.")

data = np.arange(b, e, stride, dtype=np.int64)
data = cls._simple_new(data.view(_NS_DTYPE), None, tz=tz)
else:
tz = None
# start and end should have the same timezone by this point
if isinstance(start, Timestamp):
if start is not None:
tz = start.tz
elif isinstance(end, Timestamp):
elif end is not None:
tz = end.tz

xdr = generate_range(start=start, end=end,
periods=periods, offset=freq)

values = np.array([x.value for x in xdr])
data = cls._simple_new(values, freq=freq, tz=tz)
data = np.array([x.value for x in xdr], dtype=np.int64)

data = cls._simple_new(data.view(_NS_DTYPE), freq=freq, tz=tz)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's pretty arbitrary that you are viewing as M8[ns] rather than i8 here. Let's be consistent (probably just i8 is fine),
though I think this IS i8 already?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yah, there is a lot of casting back-and-forth. I'll try to cut down on it.

General thought is that the values passed to _simple_new should already be in their correct forms (and master currently has some weird behavior, like passing a list in one case). Sharing code between Datetime/Timedelta/Period pretty much requires that an exception be made for i8.

return data


Expand Down
24 changes: 10 additions & 14 deletions pandas/core/arrays/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pandas import compat

from pandas.core.dtypes.common import (
_TD_DTYPE, ensure_int64, is_timedelta64_dtype, is_list_like)
_TD_DTYPE, is_list_like)
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.missing import isna

Expand Down Expand Up @@ -111,16 +111,9 @@ def dtype(self):
_attributes = ["freq"]

@classmethod
def _simple_new(cls, values, freq=None, **kwargs):
values = np.array(values, copy=False)
if values.dtype == np.object_:
values = array_to_timedelta64(values)
if values.dtype != _TD_DTYPE:
if is_timedelta64_dtype(values):
# non-nano unit
values = values.astype(_TD_DTYPE)
else:
values = ensure_int64(values).view(_TD_DTYPE)
def _simple_new(cls, values, freq=None):
assert isinstance(values, np.ndarray), type(values)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a docstring to indicate the possible things values can be (e.g. i8, m8[ns])?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is object even possible?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring: sure

object: yes. The point of this PR is to be more strict about what goes to _simple_new, since too much work is going on in some of them, causing some ambiguity.

As commented elsewhere, this PR does more things than it needs to, causing some confusion. I'll separate them out.

assert values.dtype == 'm8[ns]', values.dtype

result = object.__new__(cls)
result._data = values
Expand All @@ -131,6 +124,10 @@ def __new__(cls, values, freq=None):

freq, freq_infer = dtl.maybe_infer_freq(freq)

values = np.array(values, copy=False)
if values.dtype == np.object_:
values = array_to_timedelta64(values)

result = cls._simple_new(values, freq=freq)
if freq_infer:
inferred = result.inferred_freq
Expand Down Expand Up @@ -166,17 +163,16 @@ def _generate_range(cls, start, end, periods, freq, closed=None):

if freq is not None:
index = _generate_regular_range(start, end, periods, freq)
index = cls._simple_new(index, freq=freq)
else:
index = np.linspace(start.value, end.value, periods).astype('i8')
index = cls._simple_new(index, freq=freq)

if not left_closed:
index = index[1:]
if not right_closed:
index = index[:-1]

return index
index = index.view('m8[ns]')
return cls._simple_new(index, freq=freq)

# ----------------------------------------------------------------
# Arithmetic Methods
Expand Down
24 changes: 17 additions & 7 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,12 @@ def _round(self, freq, mode, ambiguous):
result = self._maybe_mask_results(result, fill_value=NaT)

attribs = self._get_attributes_dict()
if 'freq' in attribs:
attribs['freq'] = None
attribs['freq'] = None
if 'tz' in attribs:
attribs['tz'] = None
return self._ensure_localized(
self._shallow_copy(result, **attribs), ambiguous
)

result = self._shallow_copy(result, **attribs)
return self._ensure_localized(result, ambiguous)

@Appender((_round_doc + _round_example).format(op="round"))
def round(self, freq, ambiguous='raise'):
Expand All @@ -222,6 +221,18 @@ class DatetimeIndexOpsMixin(DatetimeLikeArrayMixin):
_resolution = cache_readonly(DatetimeLikeArrayMixin._resolution.fget)
resolution = cache_readonly(DatetimeLikeArrayMixin.resolution.fget)

def _shallow_copy(self, values=None, **kwargs):
if isinstance(values, list):
# reached via Index.insert
assert len(values) == 0
values = np.array([], dtype='i8')

# unwrap for case where e.g. _get_unique_index passes an instance
# of own class instead of ndarray
values = getattr(values, '_data', values)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you do this with an actual check that it is an index?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. (Ideally I'd like to make _get_unique_index pass the "correct" thing, but that can wait for another day)


return DatetimeLikeArrayMixin._shallow_copy(self, values, **kwargs)

def equals(self, other):
"""
Determines if two Index objects contain the same elements.
Expand Down Expand Up @@ -640,8 +651,7 @@ def where(self, cond, other=None):
result = np.where(cond, values, other).astype('i8')

result = self._ensure_localized(result, from_utc=True)
return self._shallow_copy(result,
**self._get_attributes_dict())
return self._shallow_copy(result)

def _summary(self, name=None):
"""
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ def __new__(cls, data=None,
data = data.astype(np.int64, copy=False)
subarr = data.view(_NS_DTYPE)

assert isinstance(subarr, np.ndarray), type(subarr)
assert subarr.dtype == 'M8[ns]', subarr.dtype
subarr = cls._simple_new(subarr, name=name, freq=freq, tz=tz)
if dtype is not None:
if not is_dtype_equal(subarr.dtype, dtype):
Expand Down Expand Up @@ -1134,6 +1136,8 @@ def slice_indexer(self, start=None, end=None, step=None, kind=None):
is_year_end = wrap_field_accessor(DatetimeArrayMixin.is_year_end)
is_leap_year = wrap_field_accessor(DatetimeArrayMixin.is_leap_year)

tz_localize = wrap_array_method(DatetimeArrayMixin.tz_localize, True)
tz_convert = wrap_array_method(DatetimeArrayMixin.tz_convert, True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two methods previously used shallow_copy in the DatetimeArray class, so name was inherited automatically. This PR avoids the use of shallow_copy in the DatetimeArray class, so we need the extra step to pin name.

to_perioddelta = wrap_array_method(DatetimeArrayMixin.to_perioddelta,
False)
to_period = wrap_array_method(DatetimeArrayMixin.to_period, True)
Expand Down
28 changes: 26 additions & 2 deletions pandas/core/indexes/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
to_timedelta, _coerce_scalar_to_timedelta_type)
from pandas._libs import (lib, index as libindex,
join as libjoin, Timedelta, NaT)
from pandas._libs.tslibs.timedeltas import array_to_timedelta64


class TimedeltaIndex(TimedeltaArrayMixin, DatetimeIndexOpsMixin,
Expand Down Expand Up @@ -166,6 +167,19 @@ def __new__(cls, data=None, unit=None, freq=None, start=None, end=None,
elif copy:
data = np.array(data, copy=True)

data = np.array(data, copy=False)
if data.dtype == np.object_:
data = array_to_timedelta64(data)

if data.dtype != _TD_DTYPE:
if is_timedelta64_dtype(data):
# non-nano unit
# TODO: watch out for overflows
data = data.astype(_TD_DTYPE)
else:
data = ensure_int64(data).view(_TD_DTYPE)

assert data.dtype == 'm8[ns]', data.dtype
subarr = cls._simple_new(data, name=name, freq=freq)
# check that we are matching freqs
if verify_integrity and len(subarr) > 0:
Expand All @@ -180,8 +194,18 @@ def __new__(cls, data=None, unit=None, freq=None, start=None, end=None,
return subarr

@classmethod
def _simple_new(cls, values, name=None, freq=None, **kwargs):
result = super(TimedeltaIndex, cls)._simple_new(values, freq, **kwargs)
def _simple_new(cls, values, name=None, freq=None, dtype=_TD_DTYPE):

# `dtype` is not always passed, but if it is, it should always
# be m8[ns]
assert dtype == _TD_DTYPE
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In what cases is it passed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In _shallow_copy:

        if not len(values) and 'dtype' not in kwargs:
            attributes['dtype'] = self.dtype


assert isinstance(values, np.ndarray), type(values)
if values.dtype == 'i8':
values = values.view('m8[ns]')
assert values.dtype == 'm8[ns]', values.dtype
Copy link
Member

@gfyoung gfyoung Oct 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two points:

  • Generally not a big fan of bare asserts like these, unless they're internal (in which case that might be fine). Are these user-facing in any way?
  • Even if they're internal, are these assert statements tested?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Largely these got added for debugging, and I left them in since they behave like especially-emphatic comments. On the next pass I'll make sure that they only go in private methods.


result = super(TimedeltaIndex, cls)._simple_new(values, freq)
result.name = name
result._reset_identity()
return result
Expand Down