-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
REF: Back IntervalArray by array instead of Index #36310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
8987a0e
153c87a
8164099
d545dac
6050ec8
bd6231c
548efe6
124938e
c479e0a
c4a2229
97a0bed
bfa13bb
b45ed46
e6d4bd9
266512f
f16be73
4efdc08
1ed9623
ed6a932
fee70d8
1a22095
490a8f5
865b3fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -976,8 +976,15 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray") | |
""" | ||
_check_isinstance(left, right, IntervalArray) | ||
|
||
assert_index_equal(left.left, right.left, exact=exact, obj=f"{obj}.left") | ||
assert_index_equal(left.right, right.right, exact=exact, obj=f"{obj}.left") | ||
kwargs = {} | ||
if left._left.dtype.kind in ["m", "M"]: | ||
# We have a DatetimeArray or TimedeltaArray | ||
kwargs["check_freq"] = False | ||
|
||
# TODO: `exact` keyword? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this TODO need to be solved first? It was there before, but you now removed it? (so is not being ignored?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we're ok without the keyword, will remove the comment |
||
assert_equal(left._left, right._left, obj=f"{obj}.left", **kwargs) | ||
assert_equal(left._right, right._right, obj=f"{obj}.left", **kwargs) | ||
|
||
assert_attr_equal("closed", left, right, obj=obj) | ||
|
||
|
||
|
@@ -988,20 +995,22 @@ def assert_period_array_equal(left, right, obj="PeriodArray"): | |
assert_attr_equal("freq", left, right, obj=obj) | ||
|
||
|
||
def assert_datetime_array_equal(left, right, obj="DatetimeArray"): | ||
def assert_datetime_array_equal(left, right, obj="DatetimeArray", check_freq=True): | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
__tracebackhide__ = True | ||
_check_isinstance(left, right, DatetimeArray) | ||
|
||
assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data") | ||
assert_attr_equal("freq", left, right, obj=obj) | ||
if check_freq: | ||
assert_attr_equal("freq", left, right, obj=obj) | ||
assert_attr_equal("tz", left, right, obj=obj) | ||
|
||
|
||
def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"): | ||
def assert_timedelta_array_equal(left, right, obj="TimedeltaArray", check_freq=True): | ||
__tracebackhide__ = True | ||
_check_isinstance(left, right, TimedeltaArray) | ||
assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data") | ||
assert_attr_equal("freq", left, right, obj=obj) | ||
if check_freq: | ||
assert_attr_equal("freq", left, right, obj=obj) | ||
|
||
|
||
def raise_assert_detail(obj, message, left, right, diff=None, index_values=None): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,6 @@ | |
from pandas.core.dtypes.dtypes import IntervalDtype | ||
from pandas.core.dtypes.generic import ( | ||
ABCDatetimeIndex, | ||
ABCIndexClass, | ||
ABCIntervalIndex, | ||
ABCPeriodIndex, | ||
ABCSeries, | ||
|
@@ -42,7 +41,7 @@ | |
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs | ||
from pandas.core.arrays.categorical import Categorical | ||
import pandas.core.common as com | ||
from pandas.core.construction import array | ||
from pandas.core.construction import array, extract_array | ||
from pandas.core.indexers import check_array_indexer | ||
from pandas.core.indexes.base import ensure_index | ||
|
||
|
@@ -161,12 +160,14 @@ def __new__( | |
verify_integrity: bool = True, | ||
): | ||
|
||
if isinstance(data, ABCSeries) and is_interval_dtype(data.dtype): | ||
data = data._values | ||
if isinstance(data, (ABCSeries, ABCIntervalIndex)) and is_interval_dtype( | ||
data.dtype | ||
): | ||
data = data._values # TODO: extract_array? | ||
|
||
if isinstance(data, (cls, ABCIntervalIndex)): | ||
left = data.left | ||
right = data.right | ||
if isinstance(data, cls): | ||
left = data._left | ||
right = data._right | ||
closed = closed or data.closed | ||
else: | ||
|
||
|
@@ -243,6 +244,20 @@ def _simple_new( | |
) | ||
raise ValueError(msg) | ||
|
||
# For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray | ||
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
left = maybe_upcast_datetimelike_array(left) | ||
left = extract_array(left, extract_numpy=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Above we are first ensuring that the arrays passed to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can avoid the roundtrip eventually, will be best accomplished by being stricter in what we pass to _simple_new There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good, would be a nice follow-up |
||
right = maybe_upcast_datetimelike_array(right) | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
right = extract_array(right, extract_numpy=True) | ||
|
||
lbase = getattr(left, "_ndarray", left).base | ||
rbase = getattr(right, "_ndarray", right).base | ||
if lbase is not None and lbase is rbase: | ||
# If these share data, then setitem could corrupt our IA | ||
right = right.copy() | ||
|
||
result._left = left | ||
result._right = right | ||
result._closed = closed | ||
|
@@ -476,18 +491,18 @@ def _validate(self): | |
if self.closed not in VALID_CLOSED: | ||
msg = f"invalid option for 'closed': {self.closed}" | ||
raise ValueError(msg) | ||
if len(self.left) != len(self.right): | ||
if len(self._left) != len(self._right): | ||
msg = "left and right must have the same length" | ||
raise ValueError(msg) | ||
left_mask = notna(self.left) | ||
right_mask = notna(self.right) | ||
left_mask = notna(self._left) | ||
right_mask = notna(self._right) | ||
if not (left_mask == right_mask).all(): | ||
msg = ( | ||
"missing values must be missing in the same " | ||
"location both left and right sides" | ||
) | ||
raise ValueError(msg) | ||
if not (self.left[left_mask] <= self.right[left_mask]).all(): | ||
if not (self._left[left_mask] <= self._right[left_mask]).all(): | ||
msg = "left side of interval must be <= right side" | ||
raise ValueError(msg) | ||
|
||
|
@@ -527,37 +542,29 @@ def __iter__(self): | |
return iter(np.asarray(self)) | ||
|
||
def __len__(self) -> int: | ||
return len(self.left) | ||
return len(self._left) | ||
|
||
def __getitem__(self, value): | ||
value = check_array_indexer(self, value) | ||
left = self.left[value] | ||
right = self.right[value] | ||
left = self._left[value] | ||
right = self._right[value] | ||
|
||
# scalar | ||
if not isinstance(left, ABCIndexClass): | ||
if not isinstance(left, (np.ndarray, ExtensionArray)): | ||
# scalar | ||
if is_scalar(left) and isna(left): | ||
return self._fill_value | ||
if np.ndim(left) > 1: | ||
# GH#30588 multi-dimensional indexer disallowed | ||
raise ValueError("multi-dimensional indexing not allowed") | ||
return Interval(left, right, self.closed) | ||
|
||
if np.ndim(left) > 1: | ||
# GH#30588 multi-dimensional indexer disallowed | ||
raise ValueError("multi-dimensional indexing not allowed") | ||
return self._shallow_copy(left, right) | ||
|
||
def __setitem__(self, key, value): | ||
value_left, value_right = self._validate_setitem_value(value) | ||
key = check_array_indexer(self, key) | ||
|
||
# Need to ensure that left and right are updated atomically, so we're | ||
# forced to copy, update the copy, and swap in the new values. | ||
left = self.left.copy(deep=True) | ||
left._values[key] = value_left | ||
self._left = left | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
right = self.right.copy(deep=True) | ||
right._values[key] = value_right | ||
self._right = right | ||
self._left[key] = value_left | ||
self._right[key] = value_right # TODO: needs tests for not breaking views | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't the un-xfail-ed test doing that? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch, will remove comment |
||
|
||
def __eq__(self, other): | ||
# ensure pandas array for list-like and eliminate non-interval scalars | ||
|
@@ -588,7 +595,7 @@ def __eq__(self, other): | |
if is_interval_dtype(other_dtype): | ||
if self.closed != other.closed: | ||
return np.zeros(len(self), dtype=bool) | ||
return (self.left == other.left) & (self.right == other.right) | ||
return (self._left == other.left) & (self._right == other.right) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. at this point |
||
|
||
# non-interval/non-object dtype -> no matches | ||
if not is_object_dtype(other_dtype): | ||
|
@@ -601,8 +608,8 @@ def __eq__(self, other): | |
if ( | ||
isinstance(obj, Interval) | ||
and self.closed == obj.closed | ||
and self.left[i] == obj.left | ||
and self.right[i] == obj.right | ||
and self._left[i] == obj.left | ||
and self._right[i] == obj.right | ||
): | ||
result[i] = True | ||
|
||
|
@@ -665,6 +672,7 @@ def astype(self, dtype, copy=True): | |
array : ExtensionArray or ndarray | ||
ExtensionArray or NumPy ndarray with 'dtype' for its dtype. | ||
""" | ||
from pandas import Index | ||
from pandas.core.arrays.string_ import StringDtype | ||
|
||
if dtype is not None: | ||
|
@@ -676,8 +684,10 @@ def astype(self, dtype, copy=True): | |
|
||
# need to cast to different subtype | ||
try: | ||
new_left = self.left.astype(dtype.subtype) | ||
new_right = self.right.astype(dtype.subtype) | ||
# We need to use Index rules for astype to prevent casting | ||
# np.nan entries to int subtypes | ||
new_left = Index(self._left, copy=False).astype(dtype.subtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could add copy=False to .astype (not sure how much any of this matters though) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. merge is ok now and can see if this matters on followup |
||
new_right = Index(self._right, copy=False).astype(dtype.subtype) | ||
except TypeError as err: | ||
msg = ( | ||
f"Cannot convert {self.dtype} to {dtype}; subtypes are incompatible" | ||
|
@@ -726,14 +736,14 @@ def copy(self): | |
------- | ||
IntervalArray | ||
""" | ||
left = self.left.copy(deep=True) | ||
right = self.right.copy(deep=True) | ||
left = self._left.copy() | ||
right = self._right.copy() | ||
closed = self.closed | ||
# TODO: Could skip verify_integrity here. | ||
return type(self).from_arrays(left, right, closed=closed) | ||
|
||
def isna(self): | ||
return isna(self.left) | ||
def isna(self) -> np.ndarray: | ||
return isna(self._left) | ||
|
||
def shift(self, periods: int = 1, fill_value: object = None) -> "IntervalArray": | ||
if not len(self) or periods == 0: | ||
|
@@ -749,7 +759,9 @@ def shift(self, periods: int = 1, fill_value: object = None) -> "IntervalArray": | |
|
||
empty_len = min(abs(periods), len(self)) | ||
if isna(fill_value): | ||
fill_value = self.left._na_value | ||
from pandas import Index | ||
|
||
fill_value = Index(self._left, copy=False)._na_value | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
empty = IntervalArray.from_breaks([fill_value] * (empty_len + 1)) | ||
else: | ||
empty = self._from_sequence([fill_value] * empty_len) | ||
|
@@ -815,10 +827,10 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs): | |
fill_left, fill_right = self._validate_fill_value(fill_value) | ||
|
||
left_take = take( | ||
self.left, indices, allow_fill=allow_fill, fill_value=fill_left | ||
self._left, indices, allow_fill=allow_fill, fill_value=fill_left | ||
) | ||
right_take = take( | ||
self.right, indices, allow_fill=allow_fill, fill_value=fill_right | ||
self._right, indices, allow_fill=allow_fill, fill_value=fill_right | ||
) | ||
|
||
return self._shallow_copy(left_take, right_take) | ||
|
@@ -984,15 +996,19 @@ def left(self): | |
Return the left endpoints of each Interval in the IntervalArray as | ||
an Index. | ||
""" | ||
return self._left | ||
from pandas import Index | ||
|
||
return Index(self._left, copy=False) | ||
|
||
@property | ||
def right(self): | ||
""" | ||
Return the right endpoints of each Interval in the IntervalArray as | ||
an Index. | ||
""" | ||
return self._right | ||
from pandas import Index | ||
|
||
return Index(self._right, copy=False) | ||
|
||
@property | ||
def length(self): | ||
|
@@ -1153,7 +1169,7 @@ def set_closed(self, closed): | |
raise ValueError(msg) | ||
|
||
return type(self)._simple_new( | ||
left=self.left, right=self.right, closed=closed, verify_integrity=False | ||
left=self._left, right=self._right, closed=closed, verify_integrity=False | ||
) | ||
|
||
_interval_shared_docs[ | ||
|
@@ -1179,15 +1195,15 @@ def is_non_overlapping_monotonic(self): | |
# at a point when both sides of intervals are included | ||
if self.closed == "both": | ||
return bool( | ||
(self.right[:-1] < self.left[1:]).all() | ||
or (self.left[:-1] > self.right[1:]).all() | ||
(self._right[:-1] < self._left[1:]).all() | ||
or (self._left[:-1] > self._right[1:]).all() | ||
) | ||
|
||
# non-strict inequality when closed != 'both'; at least one side is | ||
# not included in the intervals, so equality does not imply overlapping | ||
return bool( | ||
(self.right[:-1] <= self.left[1:]).all() | ||
or (self.left[:-1] >= self.right[1:]).all() | ||
(self._right[:-1] <= self._left[1:]).all() | ||
or (self._left[:-1] >= self._right[1:]).all() | ||
) | ||
|
||
# --------------------------------------------------------------------- | ||
|
@@ -1198,8 +1214,8 @@ def __array__(self, dtype=None) -> np.ndarray: | |
Return the IntervalArray's data as a numpy array of Interval | ||
objects (with dtype='object') | ||
""" | ||
left = self.left | ||
right = self.right | ||
left = self._left | ||
right = self._right | ||
mask = self.isna() | ||
closed = self._closed | ||
|
||
|
@@ -1229,8 +1245,8 @@ def __arrow_array__(self, type=None): | |
interval_type = ArrowIntervalType(subtype, self.closed) | ||
storage_array = pyarrow.StructArray.from_arrays( | ||
[ | ||
pyarrow.array(self.left, type=subtype, from_pandas=True), | ||
pyarrow.array(self.right, type=subtype, from_pandas=True), | ||
pyarrow.array(self._left, type=subtype, from_pandas=True), | ||
pyarrow.array(self._right, type=subtype, from_pandas=True), | ||
], | ||
names=["left", "right"], | ||
) | ||
|
@@ -1284,7 +1300,7 @@ def __arrow_array__(self, type=None): | |
_interval_shared_docs["to_tuples"] % dict(return_type="ndarray", examples="") | ||
) | ||
def to_tuples(self, na_tuple=True): | ||
tuples = com.asarray_tuplesafe(zip(self.left, self.right)) | ||
tuples = com.asarray_tuplesafe(zip(self._left, self._right)) | ||
if not na_tuple: | ||
# GH 18756 | ||
tuples = np.where(~self.isna(), tuples, np.nan) | ||
|
@@ -1350,8 +1366,8 @@ def contains(self, other): | |
if isinstance(other, Interval): | ||
raise NotImplementedError("contains not implemented for two intervals") | ||
|
||
return (self.left < other if self.open_left else self.left <= other) & ( | ||
other < self.right if self.open_right else other <= self.right | ||
return (self._left < other if self.open_left else self._left <= other) & ( | ||
other < self._right if self.open_right else other <= self._right | ||
) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docs should only reference public functions. How would users actually pass this through a public API? I suspect it's impossible, since
check_freq
in assert_frame_equal
applies to the index rather than values of an array. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So
assert_extension_array_equal
could perhaps take kwargs
and pass it through. But that's maybe not worth the effort. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so remove this note?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep, unless this is user facing in some way (e.g. is assert_index_equal changed)?