-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
REF: Back IntervalArray by array instead of Index #36310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
8987a0e
153c87a
8164099
d545dac
6050ec8
bd6231c
548efe6
124938e
c479e0a
c4a2229
97a0bed
bfa13bb
b45ed46
e6d4bd9
266512f
f16be73
4efdc08
1ed9623
ed6a932
fee70d8
1a22095
490a8f5
865b3fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,7 +32,6 @@ | |
from pandas.core.dtypes.dtypes import IntervalDtype | ||
from pandas.core.dtypes.generic import ( | ||
ABCDatetimeIndex, | ||
ABCIndexClass, | ||
ABCIntervalIndex, | ||
ABCPeriodIndex, | ||
ABCSeries, | ||
|
@@ -43,7 +42,7 @@ | |
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs | ||
from pandas.core.arrays.categorical import Categorical | ||
import pandas.core.common as com | ||
from pandas.core.construction import array | ||
from pandas.core.construction import array, extract_array | ||
from pandas.core.indexers import check_array_indexer | ||
from pandas.core.indexes.base import ensure_index | ||
|
||
|
@@ -162,12 +161,14 @@ def __new__( | |
verify_integrity: bool = True, | ||
): | ||
|
||
if isinstance(data, ABCSeries) and is_interval_dtype(data.dtype): | ||
data = data._values | ||
if isinstance(data, (ABCSeries, ABCIntervalIndex)) and is_interval_dtype( | ||
data.dtype | ||
): | ||
data = data._values # TODO: extract_array? | ||
|
||
if isinstance(data, (cls, ABCIntervalIndex)): | ||
left = data.left | ||
right = data.right | ||
if isinstance(data, cls): | ||
left = data._left | ||
right = data._right | ||
closed = closed or data.closed | ||
else: | ||
|
||
|
@@ -244,6 +245,19 @@ def _simple_new( | |
) | ||
raise ValueError(msg) | ||
|
||
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
left = maybe_upcast_datetimelike_array(left) | ||
right = extract_array(right, extract_numpy=True) | ||
left = extract_array(left, extract_numpy=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Above we are first ensuring that the arrays passed to […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we can avoid the roundtrip eventually; that will be best accomplished by being stricter in what we pass to `_simple_new`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds good, that would be a nice follow-up. |
||
right = maybe_upcast_datetimelike_array(right) | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
lbase = getattr(left, "_ndarray", left).base | ||
rbase = getattr(right, "_ndarray", right).base | ||
if lbase is not None and lbase is rbase: | ||
# If these share data, then setitem could corrupt our IA | ||
right = right.copy() | ||
|
||
result._left = left | ||
result._right = right | ||
result._closed = closed | ||
|
@@ -477,18 +491,18 @@ def _validate(self): | |
if self.closed not in VALID_CLOSED: | ||
msg = f"invalid option for 'closed': {self.closed}" | ||
raise ValueError(msg) | ||
if len(self.left) != len(self.right): | ||
if len(self._left) != len(self._right): | ||
msg = "left and right must have the same length" | ||
raise ValueError(msg) | ||
left_mask = notna(self.left) | ||
right_mask = notna(self.right) | ||
left_mask = notna(self._left) | ||
right_mask = notna(self._right) | ||
if not (left_mask == right_mask).all(): | ||
msg = ( | ||
"missing values must be missing in the same " | ||
"location both left and right sides" | ||
) | ||
raise ValueError(msg) | ||
if not (self.left[left_mask] <= self.right[left_mask]).all(): | ||
if not (self._left[left_mask] <= self._right[left_mask]).all(): | ||
msg = "left side of interval must be <= right side" | ||
raise ValueError(msg) | ||
|
||
|
@@ -528,22 +542,21 @@ def __iter__(self): | |
return iter(np.asarray(self)) | ||
|
||
def __len__(self) -> int: | ||
return len(self.left) | ||
return len(self._left) | ||
|
||
def __getitem__(self, value): | ||
value = check_array_indexer(self, value) | ||
left = self.left[value] | ||
right = self.right[value] | ||
left = self._left[value] | ||
right = self._right[value] | ||
|
||
# scalar | ||
if not isinstance(left, ABCIndexClass): | ||
if not isinstance(left, (np.ndarray, ExtensionArray)): | ||
# scalar | ||
if is_scalar(left) and isna(left): | ||
return self._fill_value | ||
if np.ndim(left) > 1: | ||
# GH#30588 multi-dimensional indexer disallowed | ||
raise ValueError("multi-dimensional indexing not allowed") | ||
return Interval(left, right, self.closed) | ||
|
||
if np.ndim(left) > 1: | ||
# GH#30588 multi-dimensional indexer disallowed | ||
raise ValueError("multi-dimensional indexing not allowed") | ||
return self._shallow_copy(left, right) | ||
|
||
def __setitem__(self, key, value): | ||
|
@@ -570,7 +583,7 @@ def __setitem__(self, key, value): | |
# list-like of intervals | ||
try: | ||
array = IntervalArray(value) | ||
value_left, value_right = array.left, array.right | ||
value_left, value_right = array._left, array.right | ||
except TypeError as err: | ||
# wrong type: not interval or NA | ||
msg = f"'value' should be an interval type, got {type(value)} instead." | ||
|
@@ -581,15 +594,8 @@ def __setitem__(self, key, value): | |
|
||
key = check_array_indexer(self, key) | ||
|
||
# Need to ensure that left and right are updated atomically, so we're | ||
# forced to copy, update the copy, and swap in the new values. | ||
left = self.left.copy(deep=True) | ||
left._values[key] = value_left | ||
self._left = left | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
right = self.right.copy(deep=True) | ||
right._values[key] = value_right | ||
self._right = right | ||
self._left[key] = value_left | ||
self._right[key] = value_right # TODO: needs tests for not breaking views | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Isn't the un-xfail-ed test doing that? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good catch, will remove the comment. |
||
|
||
def __eq__(self, other): | ||
# ensure pandas array for list-like and eliminate non-interval scalars | ||
|
@@ -620,7 +626,7 @@ def __eq__(self, other): | |
if is_interval_dtype(other_dtype): | ||
if self.closed != other.closed: | ||
return np.zeros(len(self), dtype=bool) | ||
return (self.left == other.left) & (self.right == other.right) | ||
return (self._left == other.left) & (self._right == other.right) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. At this point […] |
||
|
||
# non-interval/non-object dtype -> no matches | ||
if not is_object_dtype(other_dtype): | ||
|
@@ -633,8 +639,8 @@ def __eq__(self, other): | |
if ( | ||
isinstance(obj, Interval) | ||
and self.closed == obj.closed | ||
and self.left[i] == obj.left | ||
and self.right[i] == obj.right | ||
and self._left[i] == obj.left | ||
and self._right[i] == obj.right | ||
): | ||
result[i] = True | ||
|
||
|
@@ -697,6 +703,7 @@ def astype(self, dtype, copy=True): | |
array : ExtensionArray or ndarray | ||
ExtensionArray or NumPy ndarray with 'dtype' for its dtype. | ||
""" | ||
from pandas import Index | ||
from pandas.core.arrays.string_ import StringDtype | ||
|
||
if dtype is not None: | ||
|
@@ -708,8 +715,10 @@ def astype(self, dtype, copy=True): | |
|
||
# need to cast to different subtype | ||
try: | ||
new_left = self.left.astype(dtype.subtype) | ||
new_right = self.right.astype(dtype.subtype) | ||
# We need to use Index rules for astype to prevent casting | ||
# np.nan entries to int subtypes | ||
new_left = Index(self._left).astype(dtype.subtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should this pass `copy=False`? |
||
new_right = Index(self._right).astype(dtype.subtype) | ||
except TypeError as err: | ||
msg = ( | ||
f"Cannot convert {self.dtype} to {dtype}; subtypes are incompatible" | ||
|
@@ -758,14 +767,14 @@ def copy(self): | |
------- | ||
IntervalArray | ||
""" | ||
left = self.left.copy(deep=True) | ||
right = self.right.copy(deep=True) | ||
left = self._left.copy() | ||
right = self._right.copy() | ||
closed = self.closed | ||
# TODO: Could skip verify_integrity here. | ||
return type(self).from_arrays(left, right, closed=closed) | ||
|
||
def isna(self): | ||
return isna(self.left) | ||
def isna(self) -> np.ndarray: | ||
return isna(self._left) | ||
|
||
def shift(self, periods: int = 1, fill_value: object = None) -> "IntervalArray": | ||
if not len(self) or periods == 0: | ||
|
@@ -781,7 +790,9 @@ def shift(self, periods: int = 1, fill_value: object = None) -> "IntervalArray": | |
|
||
empty_len = min(abs(periods), len(self)) | ||
if isna(fill_value): | ||
fill_value = self.left._na_value | ||
from pandas import Index | ||
|
||
fill_value = Index(self._left)._na_value | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why can't we get the fill value directly rather than doing this? If we do need to do this, add `copy=False`. |
||
empty = IntervalArray.from_breaks([fill_value] * (empty_len + 1)) | ||
else: | ||
empty = self._from_sequence([fill_value] * empty_len) | ||
|
@@ -847,10 +858,10 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs): | |
fill_left, fill_right = self._validate_fill_value(fill_value) | ||
|
||
left_take = take( | ||
self.left, indices, allow_fill=allow_fill, fill_value=fill_left | ||
self._left, indices, allow_fill=allow_fill, fill_value=fill_left | ||
) | ||
right_take = take( | ||
self.right, indices, allow_fill=allow_fill, fill_value=fill_right | ||
self._right, indices, allow_fill=allow_fill, fill_value=fill_right | ||
) | ||
|
||
return self._shallow_copy(left_take, right_take) | ||
|
@@ -982,15 +993,19 @@ def left(self): | |
Return the left endpoints of each Interval in the IntervalArray as | ||
an Index. | ||
""" | ||
return self._left | ||
from pandas import Index | ||
|
||
return Index(self._left) | ||
|
||
@property | ||
def right(self): | ||
""" | ||
Return the right endpoints of each Interval in the IntervalArray as | ||
an Index. | ||
""" | ||
return self._right | ||
from pandas import Index | ||
|
||
return Index(self._right) | ||
|
||
@property | ||
def length(self): | ||
|
@@ -1151,7 +1166,7 @@ def set_closed(self, closed): | |
raise ValueError(msg) | ||
|
||
return type(self)._simple_new( | ||
left=self.left, right=self.right, closed=closed, verify_integrity=False | ||
left=self._left, right=self._right, closed=closed, verify_integrity=False | ||
) | ||
|
||
_interval_shared_docs[ | ||
|
@@ -1177,15 +1192,15 @@ def is_non_overlapping_monotonic(self): | |
# at a point when both sides of intervals are included | ||
if self.closed == "both": | ||
return bool( | ||
(self.right[:-1] < self.left[1:]).all() | ||
or (self.left[:-1] > self.right[1:]).all() | ||
(self._right[:-1] < self._left[1:]).all() | ||
or (self._left[:-1] > self._right[1:]).all() | ||
) | ||
|
||
# non-strict inequality when closed != 'both'; at least one side is | ||
# not included in the intervals, so equality does not imply overlapping | ||
return bool( | ||
(self.right[:-1] <= self.left[1:]).all() | ||
or (self.left[:-1] >= self.right[1:]).all() | ||
(self._right[:-1] <= self._left[1:]).all() | ||
or (self._left[:-1] >= self._right[1:]).all() | ||
) | ||
|
||
# --------------------------------------------------------------------- | ||
|
@@ -1196,8 +1211,8 @@ def __array__(self, dtype=None) -> np.ndarray: | |
Return the IntervalArray's data as a numpy array of Interval | ||
objects (with dtype='object') | ||
""" | ||
left = self.left | ||
right = self.right | ||
left = self._left | ||
right = self._right | ||
mask = self.isna() | ||
closed = self._closed | ||
|
||
|
@@ -1227,8 +1242,8 @@ def __arrow_array__(self, type=None): | |
interval_type = ArrowIntervalType(subtype, self.closed) | ||
storage_array = pyarrow.StructArray.from_arrays( | ||
[ | ||
pyarrow.array(self.left, type=subtype, from_pandas=True), | ||
pyarrow.array(self.right, type=subtype, from_pandas=True), | ||
pyarrow.array(self._left, type=subtype, from_pandas=True), | ||
pyarrow.array(self._right, type=subtype, from_pandas=True), | ||
], | ||
names=["left", "right"], | ||
) | ||
|
@@ -1282,7 +1297,7 @@ def __arrow_array__(self, type=None): | |
_interval_shared_docs["to_tuples"] % dict(return_type="ndarray", examples="") | ||
) | ||
def to_tuples(self, na_tuple=True): | ||
tuples = com.asarray_tuplesafe(zip(self.left, self.right)) | ||
tuples = com.asarray_tuplesafe(zip(self._left, self._right)) | ||
if not na_tuple: | ||
# GH 18756 | ||
tuples = np.where(~self.isna(), tuples, np.nan) | ||
|
@@ -1348,8 +1363,8 @@ def contains(self, other): | |
if isinstance(other, Interval): | ||
raise NotImplementedError("contains not implemented for two intervals") | ||
|
||
return (self.left < other if self.open_left else self.left <= other) & ( | ||
other < self.right if self.open_right else other <= self.right | ||
return (self._left < other if self.open_left else self._left <= other) & ( | ||
other < self._right if self.open_right else other <= self._right | ||
) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe better to