-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
REF: stricter checks in _simple_new, avoid shallow_copy in EAs #23426
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
10a923f
9060f1a
d5b0bfd
233367a
f37ace3
1860ea0
777ddff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -177,16 +177,11 @@ def _simple_new(cls, values, freq=None, tz=None, **kwargs): | |
we require the we have a dtype compat for the values | ||
if we are passed a non-dtype compat, then coerce using the constructor | ||
""" | ||
assert isinstance(values, np.ndarray), type(values) | ||
if values.dtype == 'i8': | ||
values = values.view('M8[ns]') | ||
|
||
if getattr(values, 'dtype', None) is None: | ||
# empty, but with dtype compat | ||
if values is None: | ||
values = np.empty(0, dtype=_NS_DTYPE) | ||
return cls(values, freq=freq, tz=tz, **kwargs) | ||
values = np.array(values, copy=False) | ||
|
||
if not is_datetime64_dtype(values): | ||
values = ensure_int64(values).view(_NS_DTYPE) | ||
assert values.dtype == 'M8[ns]', values.dtype | ||
|
||
result = object.__new__(cls) | ||
result._data = values | ||
|
@@ -209,6 +204,15 @@ def __new__(cls, values, freq=None, tz=None, dtype=None): | |
# if dtype has an embedded tz, capture it | ||
tz = dtl.validate_tz_from_dtype(dtype, tz) | ||
|
||
if isinstance(values, DatetimeArrayMixin): | ||
values = values.asi8 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why get the integers here and not M8 directly? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because the way the constructors view i8 values is more consistent. For tz-naive, M8 vs i8 are equivalent. For tz-aware, i8 is interpreted as unix timestamps (i.e. UTC), whereas M8 is interpreted as the wall-time in the given timezone. |
||
if values.dtype == 'i8': | ||
values = values.view('M8[ns]') | ||
|
||
assert isinstance(values, np.ndarray), type(values) | ||
assert is_datetime64_dtype(values) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. above you check explicitly for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. correct |
||
values = conversion.ensure_datetime64ns(values, copy=False) | ||
|
||
result = cls._simple_new(values, freq=freq, tz=tz) | ||
if freq_infer: | ||
inferred = result.inferred_freq | ||
|
@@ -253,28 +257,22 @@ def _generate_range(cls, start, end, periods, freq, tz=None, | |
|
||
if tz is not None: | ||
# Localize the start and end arguments | ||
start = _maybe_localize_point( | ||
start, getattr(start, 'tz', None), start, freq, tz | ||
) | ||
end = _maybe_localize_point( | ||
end, getattr(end, 'tz', None), end, freq, tz | ||
) | ||
start = _maybe_localize_point(start, getattr(start, 'tz', None), | ||
start, freq, tz) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general, can you leave such style changes only to lines you actually change anyway? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah and the former is actually more idiomatic, really prefer not to do partial line wrapping There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure |
||
end = _maybe_localize_point(end, getattr(end, 'tz', None), | ||
end, freq, tz) | ||
if start and end: | ||
# Make sure start and end have the same tz | ||
start = _maybe_localize_point( | ||
start, start.tz, end.tz, freq, tz | ||
) | ||
end = _maybe_localize_point( | ||
end, end.tz, start.tz, freq, tz | ||
) | ||
start = _maybe_localize_point(start, start.tz, end.tz, freq, tz) | ||
end = _maybe_localize_point(end, end.tz, start.tz, freq, tz) | ||
|
||
if freq is not None: | ||
# TODO: consider re-implementing _cached_range; GH#17914 | ||
index = _generate_regular_range(cls, start, end, periods, freq) | ||
|
||
if tz is not None and getattr(index, 'tz', None) is None: | ||
arr = conversion.tz_localize_to_utc( | ||
ensure_int64(index.values), | ||
tz, ambiguous=ambiguous) | ||
if tz is not None and index.tz is None: | ||
arr = conversion.tz_localize_to_utc(ensure_int64(index.values), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wrap on the ensure_int64 |
||
tz, ambiguous=ambiguous) | ||
|
||
index = cls(arr) | ||
|
||
|
@@ -287,9 +285,8 @@ def _generate_range(cls, start, end, periods, freq, tz=None, | |
else: | ||
# Create a linearly spaced date_range in local time | ||
arr = np.linspace(start.value, end.value, periods) | ||
index = cls._simple_new( | ||
arr.astype('M8[ns]', copy=False), freq=None, tz=tz | ||
) | ||
arr = arr.astype('M8[ns]', copy=False) | ||
index = cls._simple_new(arr, freq=None, tz=tz) | ||
|
||
if not left_closed and len(index) and index[0] == start: | ||
index = index[1:] | ||
|
@@ -586,7 +583,7 @@ def tz_convert(self, tz): | |
'tz_localize to localize') | ||
|
||
# No conversion since timestamps are all UTC to begin with | ||
return self._shallow_copy(tz=tz) | ||
return self._simple_new(self.asi8, tz=tz, freq=self.freq) | ||
|
||
def tz_localize(self, tz, ambiguous='raise', nonexistent='raise', | ||
errors=None): | ||
|
@@ -708,7 +705,7 @@ def tz_localize(self, tz, ambiguous='raise', nonexistent='raise', | |
self.asi8, tz, ambiguous=ambiguous, nonexistent=nonexistent, | ||
) | ||
new_dates = new_dates.view(_NS_DTYPE) | ||
return self._shallow_copy(new_dates, tz=tz) | ||
return self._simple_new(new_dates, tz=tz, freq=self.freq) | ||
|
||
# ---------------------------------------------------------------- | ||
# Conversion Methods - Vectorized analogues of Timestamp methods | ||
|
@@ -843,7 +840,8 @@ def to_perioddelta(self, freq): | |
# TODO: consider privatizing (discussion in GH#23113) | ||
from pandas.core.arrays.timedeltas import TimedeltaArrayMixin | ||
i8delta = self.asi8 - self.to_period(freq).to_timestamp().asi8 | ||
return TimedeltaArrayMixin(i8delta) | ||
m8delta = i8delta.view('m8[ns]') | ||
return TimedeltaArrayMixin(m8delta) | ||
|
||
# ----------------------------------------------------------------- | ||
# Properties - Vectorized Timestamp Properties/Methods | ||
|
@@ -1320,6 +1318,20 @@ def to_julian_date(self): | |
|
||
|
||
def _generate_regular_range(cls, start, end, periods, freq): | ||
""" | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function summary / description? |
||
Parameters | ||
---------- | ||
cls : class | ||
start : Timestamp or None | ||
end : Timestamp or None | ||
periods : int | ||
freq : DateOffset | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Parameter descriptions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will do |
||
|
||
Returns | ||
------- | ||
ndarray[np.int64] representing nanosecond unix timestamps | ||
""" | ||
if isinstance(freq, Tick): | ||
stride = freq.nanos | ||
if periods is None: | ||
|
@@ -1343,21 +1355,20 @@ def _generate_regular_range(cls, start, end, periods, freq): | |
"if a 'period' is given.") | ||
|
||
data = np.arange(b, e, stride, dtype=np.int64) | ||
data = cls._simple_new(data.view(_NS_DTYPE), None, tz=tz) | ||
else: | ||
tz = None | ||
# start and end should have the same timezone by this point | ||
if isinstance(start, Timestamp): | ||
if start is not None: | ||
tz = start.tz | ||
elif isinstance(end, Timestamp): | ||
elif end is not None: | ||
tz = end.tz | ||
|
||
xdr = generate_range(start=start, end=end, | ||
periods=periods, offset=freq) | ||
|
||
values = np.array([x.value for x in xdr]) | ||
data = cls._simple_new(values, freq=freq, tz=tz) | ||
data = np.array([x.value for x in xdr], dtype=np.int64) | ||
|
||
data = cls._simple_new(data.view(_NS_DTYPE), freq=freq, tz=tz) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's pretty arbitrary that you are viewing as M8[ns] rather than i8 here. Let's be consistent (probably just i8 is fine). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yah, there is a lot of casting back-and-forth. I'll try to cut down on it. General thought is that the values passed to |
||
return data | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ | |
from pandas import compat | ||
|
||
from pandas.core.dtypes.common import ( | ||
_TD_DTYPE, ensure_int64, is_timedelta64_dtype, is_list_like) | ||
_TD_DTYPE, is_list_like) | ||
from pandas.core.dtypes.generic import ABCSeries | ||
from pandas.core.dtypes.missing import isna | ||
|
||
|
@@ -111,16 +111,9 @@ def dtype(self): | |
_attributes = ["freq"] | ||
|
||
@classmethod | ||
def _simple_new(cls, values, freq=None, **kwargs): | ||
values = np.array(values, copy=False) | ||
if values.dtype == np.object_: | ||
values = array_to_timedelta64(values) | ||
if values.dtype != _TD_DTYPE: | ||
if is_timedelta64_dtype(values): | ||
# non-nano unit | ||
values = values.astype(_TD_DTYPE) | ||
else: | ||
values = ensure_int64(values).view(_TD_DTYPE) | ||
def _simple_new(cls, values, freq=None): | ||
assert isinstance(values, np.ndarray), type(values) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add doc-string to indicate possible things values can be (e.g. i8, M8[ns]) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is object even possible? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. docstring: sure object: yes. The point of this PR is to be more strict about what goes to _simple_new, since too much work is going on in some of them, causing some ambiguity. As commented elsewhere, this PR does more things than it needs to, causing some confusion. I'll separate them out. |
||
assert values.dtype == 'm8[ns]', values.dtype | ||
|
||
result = object.__new__(cls) | ||
result._data = values | ||
|
@@ -131,6 +124,10 @@ def __new__(cls, values, freq=None): | |
|
||
freq, freq_infer = dtl.maybe_infer_freq(freq) | ||
|
||
values = np.array(values, copy=False) | ||
if values.dtype == np.object_: | ||
values = array_to_timedelta64(values) | ||
|
||
result = cls._simple_new(values, freq=freq) | ||
if freq_infer: | ||
inferred = result.inferred_freq | ||
|
@@ -166,17 +163,16 @@ def _generate_range(cls, start, end, periods, freq, closed=None): | |
|
||
if freq is not None: | ||
index = _generate_regular_range(start, end, periods, freq) | ||
index = cls._simple_new(index, freq=freq) | ||
else: | ||
index = np.linspace(start.value, end.value, periods).astype('i8') | ||
index = cls._simple_new(index, freq=freq) | ||
|
||
if not left_closed: | ||
index = index[1:] | ||
if not right_closed: | ||
index = index[:-1] | ||
|
||
return index | ||
index = index.view('m8[ns]') | ||
return cls._simple_new(index, freq=freq) | ||
|
||
# ---------------------------------------------------------------- | ||
# Arithmetic Methods | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -189,13 +189,12 @@ def _round(self, freq, mode, ambiguous): | |
result = self._maybe_mask_results(result, fill_value=NaT) | ||
|
||
attribs = self._get_attributes_dict() | ||
if 'freq' in attribs: | ||
attribs['freq'] = None | ||
attribs['freq'] = None | ||
if 'tz' in attribs: | ||
attribs['tz'] = None | ||
return self._ensure_localized( | ||
self._shallow_copy(result, **attribs), ambiguous | ||
) | ||
|
||
result = self._shallow_copy(result, **attribs) | ||
return self._ensure_localized(result, ambiguous) | ||
|
||
@Appender((_round_doc + _round_example).format(op="round")) | ||
def round(self, freq, ambiguous='raise'): | ||
|
@@ -222,6 +221,18 @@ class DatetimeIndexOpsMixin(DatetimeLikeArrayMixin): | |
_resolution = cache_readonly(DatetimeLikeArrayMixin._resolution.fget) | ||
resolution = cache_readonly(DatetimeLikeArrayMixin.resolution.fget) | ||
|
||
def _shallow_copy(self, values=None, **kwargs): | ||
if isinstance(values, list): | ||
# reached via Index.insert | ||
assert len(values) == 0 | ||
values = np.array([], dtype='i8') | ||
|
||
# unwrap for case where e.g. _get_unique_index passes an instance | ||
# of own class instead of ndarray | ||
values = getattr(values, '_data', values) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you do this with an actual check that it is an index? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. (Ideally I'd like to make _get_unique_index pass the "correct" thing, but that can wait for another day) |
||
|
||
return DatetimeLikeArrayMixin._shallow_copy(self, values, **kwargs) | ||
|
||
def equals(self, other): | ||
""" | ||
Determines if two Index objects contain the same elements. | ||
|
@@ -640,8 +651,7 @@ def where(self, cond, other=None): | |
result = np.where(cond, values, other).astype('i8') | ||
|
||
result = self._ensure_localized(result, from_utc=True) | ||
return self._shallow_copy(result, | ||
**self._get_attributes_dict()) | ||
return self._shallow_copy(result) | ||
|
||
def _summary(self, name=None): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -298,6 +298,8 @@ def __new__(cls, data=None, | |
data = data.astype(np.int64, copy=False) | ||
subarr = data.view(_NS_DTYPE) | ||
|
||
assert isinstance(subarr, np.ndarray), type(subarr) | ||
assert subarr.dtype == 'M8[ns]', subarr.dtype | ||
subarr = cls._simple_new(subarr, name=name, freq=freq, tz=tz) | ||
if dtype is not None: | ||
if not is_dtype_equal(subarr.dtype, dtype): | ||
|
@@ -1134,6 +1136,8 @@ def slice_indexer(self, start=None, end=None, step=None, kind=None): | |
is_year_end = wrap_field_accessor(DatetimeArrayMixin.is_year_end) | ||
is_leap_year = wrap_field_accessor(DatetimeArrayMixin.is_leap_year) | ||
|
||
tz_localize = wrap_array_method(DatetimeArrayMixin.tz_localize, True) | ||
tz_convert = wrap_array_method(DatetimeArrayMixin.tz_convert, True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you explain this change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These two methods previously used shallow_copy in the DatetimeArray class, so |
||
to_perioddelta = wrap_array_method(DatetimeArrayMixin.to_perioddelta, | ||
False) | ||
to_period = wrap_array_method(DatetimeArrayMixin.to_period, True) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,7 @@ | |
to_timedelta, _coerce_scalar_to_timedelta_type) | ||
from pandas._libs import (lib, index as libindex, | ||
join as libjoin, Timedelta, NaT) | ||
from pandas._libs.tslibs.timedeltas import array_to_timedelta64 | ||
|
||
|
||
class TimedeltaIndex(TimedeltaArrayMixin, DatetimeIndexOpsMixin, | ||
|
@@ -166,6 +167,19 @@ def __new__(cls, data=None, unit=None, freq=None, start=None, end=None, | |
elif copy: | ||
data = np.array(data, copy=True) | ||
|
||
data = np.array(data, copy=False) | ||
if data.dtype == np.object_: | ||
data = array_to_timedelta64(data) | ||
|
||
if data.dtype != _TD_DTYPE: | ||
if is_timedelta64_dtype(data): | ||
# non-nano unit | ||
# TODO: watch out for overflows | ||
data = data.astype(_TD_DTYPE) | ||
else: | ||
data = ensure_int64(data).view(_TD_DTYPE) | ||
|
||
assert data.dtype == 'm8[ns]', data.dtype | ||
subarr = cls._simple_new(data, name=name, freq=freq) | ||
# check that we are matching freqs | ||
if verify_integrity and len(subarr) > 0: | ||
|
@@ -180,8 +194,18 @@ def __new__(cls, data=None, unit=None, freq=None, start=None, end=None, | |
return subarr | ||
|
||
@classmethod | ||
def _simple_new(cls, values, name=None, freq=None, **kwargs): | ||
result = super(TimedeltaIndex, cls)._simple_new(values, freq, **kwargs) | ||
def _simple_new(cls, values, name=None, freq=None, dtype=_TD_DTYPE): | ||
|
||
# `dtype` is not always passed, but if it is, it should always | ||
# be m8[ns] | ||
assert dtype == _TD_DTYPE | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In what cases is it passed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In _shallow_copy:
|
||
|
||
assert isinstance(values, np.ndarray), type(values) | ||
if values.dtype == 'i8': | ||
values = values.view('m8[ns]') | ||
assert values.dtype == 'm8[ns]', values.dtype | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Two points:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Largely these got added for debugging and left them in since they behave like especially-emphatic comments. On the next pass I'll make sure that they only go in private methods. |
||
|
||
result = super(TimedeltaIndex, cls)._simple_new(values, freq) | ||
result.name = name | ||
result._reset_identity() | ||
return result | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a comment here about what this is doing