Skip to content

REF: back IntervalArray by a single ndarray #37047

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 12, 2020
242 changes: 131 additions & 111 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
is_categorical_dtype,
is_datetime64_any_dtype,
is_float_dtype,
is_integer,
is_integer_dtype,
is_interval_dtype,
is_list_like,
Expand Down Expand Up @@ -169,6 +170,17 @@ def __new__(
left = data._left
right = data._right
closed = closed or data.closed

if dtype is None or data.dtype == dtype:
# This path will preserve id(result._combined)
# TODO: could also validate dtype before going to simple_new
combined = data._combined
if copy:
combined = combined.copy()
result = cls._simple_new(combined, closed=closed)
if verify_integrity:
result._validate()
return result
else:

# don't allow scalars
Expand All @@ -186,83 +198,22 @@ def __new__(
)
closed = closed or infer_closed

return cls._simple_new(
left,
right,
closed,
copy=copy,
dtype=dtype,
verify_integrity=verify_integrity,
)
closed = closed or "right"
left, right = _maybe_cast_inputs(left, right, copy, dtype)
combined = _get_combined_data(left, right)
result = cls._simple_new(combined, closed=closed)
if verify_integrity:
result._validate()
return result

@classmethod
def _simple_new(
cls, left, right, closed=None, copy=False, dtype=None, verify_integrity=True
):
def _simple_new(cls, data, closed="right"):
result = IntervalMixin.__new__(cls)

closed = closed or "right"
left = ensure_index(left, copy=copy)
right = ensure_index(right, copy=copy)

if dtype is not None:
# GH 19262: dtype must be an IntervalDtype to override inferred
dtype = pandas_dtype(dtype)
if not is_interval_dtype(dtype):
msg = f"dtype must be an IntervalDtype, got {dtype}"
raise TypeError(msg)
elif dtype.subtype is not None:
left = left.astype(dtype.subtype)
right = right.astype(dtype.subtype)

# coerce dtypes to match if needed
if is_float_dtype(left) and is_integer_dtype(right):
right = right.astype(left.dtype)
elif is_float_dtype(right) and is_integer_dtype(left):
left = left.astype(right.dtype)

if type(left) != type(right):
msg = (
f"must not have differing left [{type(left).__name__}] and "
f"right [{type(right).__name__}] types"
)
raise ValueError(msg)
elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype):
# GH 19016
msg = (
"category, object, and string subtypes are not supported "
"for IntervalArray"
)
raise TypeError(msg)
elif isinstance(left, ABCPeriodIndex):
msg = "Period dtypes are not supported, use a PeriodIndex instead"
raise ValueError(msg)
elif isinstance(left, ABCDatetimeIndex) and str(left.tz) != str(right.tz):
msg = (
"left and right must have the same time zone, got "
f"'{left.tz}' and '{right.tz}'"
)
raise ValueError(msg)

# For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array

left = maybe_upcast_datetimelike_array(left)
left = extract_array(left, extract_numpy=True)
right = maybe_upcast_datetimelike_array(right)
right = extract_array(right, extract_numpy=True)

lbase = getattr(left, "_ndarray", left).base
rbase = getattr(right, "_ndarray", right).base
if lbase is not None and lbase is rbase:
# If these share data, then setitem could corrupt our IA
right = right.copy()

result._left = left
result._right = right
result._combined = data
result._left = data[:, 0]
result._right = data[:, 1]
result._closed = closed
if verify_integrity:
result._validate()
return result

@classmethod
Expand Down Expand Up @@ -397,10 +348,16 @@ def from_breaks(cls, breaks, closed="right", copy=False, dtype=None):
def from_arrays(cls, left, right, closed="right", copy=False, dtype=None):
left = maybe_convert_platform_interval(left)
right = maybe_convert_platform_interval(right)
if len(left) != len(right):
raise ValueError("left and right must have the same length")

return cls._simple_new(
left, right, closed, copy=copy, dtype=dtype, verify_integrity=True
)
closed = closed or "right"
left, right = _maybe_cast_inputs(left, right, copy, dtype)
combined = _get_combined_data(left, right)

result = cls._simple_new(combined, closed)
result._validate()
return result

_interval_shared_docs["from_tuples"] = textwrap.dedent(
"""
Expand Down Expand Up @@ -506,19 +463,6 @@ def _validate(self):
msg = "left side of interval must be <= right side"
raise ValueError(msg)

def _shallow_copy(self, left, right):
"""
Return a new IntervalArray with the replacement attributes

Parameters
----------
left : Index
Values to be used for the left-side of the intervals.
right : Index
Values to be used for the right-side of the intervals.
"""
return self._simple_new(left, right, closed=self.closed, verify_integrity=False)

# ---------------------------------------------------------------------
# Descriptive

Expand Down Expand Up @@ -546,18 +490,20 @@ def __len__(self) -> int:

def __getitem__(self, key):
key = check_array_indexer(self, key)
left = self._left[key]
right = self._right[key]

if not isinstance(left, (np.ndarray, ExtensionArray)):
# scalar
if is_scalar(left) and isna(left):
result = self._combined[key]

if is_integer(key):
left, right = result[0], result[1]
if isna(left):
return self._fill_value
return Interval(left, right, self.closed)
if np.ndim(left) > 1:

# TODO: need to watch out for incorrectly-reducing getitem
if np.ndim(result) > 2:
# GH#30588 multi-dimensional indexer disallowed
raise ValueError("multi-dimensional indexing not allowed")
return self._shallow_copy(left, right)
return type(self)._simple_new(result, closed=self.closed)

def __setitem__(self, key, value):
value_left, value_right = self._validate_setitem_value(value)
Expand Down Expand Up @@ -651,7 +597,8 @@ def fillna(self, value=None, method=None, limit=None):

left = self.left.fillna(value=value_left)
right = self.right.fillna(value=value_right)
return self._shallow_copy(left, right)
combined = _get_combined_data(left, right)
return type(self)._simple_new(combined, closed=self.closed)

def astype(self, dtype, copy=True):
"""
Expand Down Expand Up @@ -693,7 +640,9 @@ def astype(self, dtype, copy=True):
f"Cannot convert {self.dtype} to {dtype}; subtypes are incompatible"
)
raise TypeError(msg) from err
return self._shallow_copy(new_left, new_right)
# TODO: do astype directly on self._combined
combined = _get_combined_data(new_left, new_right)
return type(self)._simple_new(combined, closed=self.closed)
elif is_categorical_dtype(dtype):
return Categorical(np.asarray(self))
elif isinstance(dtype, StringDtype):
Expand Down Expand Up @@ -734,9 +683,11 @@ def _concat_same_type(cls, to_concat):
raise ValueError("Intervals must all be closed on the same side.")
closed = closed.pop()

# TODO: will this mess up on dt64tz?
left = np.concatenate([interval.left for interval in to_concat])
right = np.concatenate([interval.right for interval in to_concat])
return cls._simple_new(left, right, closed=closed, copy=False)
combined = _get_combined_data(left, right) # TODO: 1-stage concat
return cls._simple_new(combined, closed=closed)

def copy(self):
"""
Expand All @@ -746,11 +697,8 @@ def copy(self):
-------
IntervalArray
"""
left = self._left.copy()
right = self._right.copy()
closed = self.closed
# TODO: Could skip verify_integrity here.
return type(self).from_arrays(left, right, closed=closed)
combined = self._combined.copy()
return type(self)._simple_new(combined, closed=self.closed)

def isna(self) -> np.ndarray:
return isna(self._left)
Expand Down Expand Up @@ -843,7 +791,8 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs):
self._right, indices, allow_fill=allow_fill, fill_value=fill_right
)

return self._shallow_copy(left_take, right_take)
combined = _get_combined_data(left_take, right_take)
return type(self)._simple_new(combined, closed=self.closed)

def _validate_listlike(self, value):
# list-like of intervals
Expand Down Expand Up @@ -1170,10 +1119,7 @@ def set_closed(self, closed):
if closed not in VALID_CLOSED:
msg = f"invalid option for 'closed': {closed}"
raise ValueError(msg)

return type(self)._simple_new(
left=self._left, right=self._right, closed=closed, verify_integrity=False
)
return type(self)._simple_new(self._combined, closed=closed)

_interval_shared_docs[
"is_non_overlapping_monotonic"
Expand Down Expand Up @@ -1314,9 +1260,8 @@ def to_tuples(self, na_tuple=True):
@Appender(_extension_array_shared_docs["repeat"] % _shared_docs_kwargs)
def repeat(self, repeats, axis=None):
nv.validate_repeat(tuple(), dict(axis=axis))
left_repeat = self.left.repeat(repeats)
right_repeat = self.right.repeat(repeats)
return self._shallow_copy(left=left_repeat, right=right_repeat)
combined = self._combined.repeat(repeats, 0)
return type(self)._simple_new(combined, closed=self.closed)

_interval_shared_docs["contains"] = textwrap.dedent(
"""
Expand Down Expand Up @@ -1399,3 +1344,78 @@ def maybe_convert_platform_interval(values):
values = np.asarray(values)

return maybe_convert_platform(values)


def _maybe_cast_inputs(left, right, copy, dtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add type annotations if you can (especially for the return value).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do

left = ensure_index(left, copy=copy)
right = ensure_index(right, copy=copy)

if dtype is not None:
# GH#19262: dtype must be an IntervalDtype to override inferred
dtype = pandas_dtype(dtype)
if not is_interval_dtype(dtype):
msg = f"dtype must be an IntervalDtype, got {dtype}"
raise TypeError(msg)
elif dtype.subtype is not None:
left = left.astype(dtype.subtype)
right = right.astype(dtype.subtype)

# coerce dtypes to match if needed
if is_float_dtype(left) and is_integer_dtype(right):
right = right.astype(left.dtype)
elif is_float_dtype(right) and is_integer_dtype(left):
left = left.astype(right.dtype)

if type(left) != type(right):
msg = (
f"must not have differing left [{type(left).__name__}] and "
f"right [{type(right).__name__}] types"
)
raise ValueError(msg)
elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype):
# GH#19016
msg = (
"category, object, and string subtypes are not supported "
"for IntervalArray"
)
raise TypeError(msg)
elif isinstance(left, ABCPeriodIndex):
msg = "Period dtypes are not supported, use a PeriodIndex instead"
raise ValueError(msg)
elif isinstance(left, ABCDatetimeIndex) and str(left.tz) != str(right.tz):
# TODO: use tz_compare?
msg = (
"left and right must have the same time zone, got "
f"'{left.tz}' and '{right.tz}'"
)
raise ValueError(msg)

return left, right


def _get_combined_data(left, right):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add type annotations if you can, especially for the return value.

# For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array

left = maybe_upcast_datetimelike_array(left)
left = extract_array(left, extract_numpy=True)
right = maybe_upcast_datetimelike_array(right)
right = extract_array(right, extract_numpy=True)

lbase = getattr(left, "_ndarray", left).base
rbase = getattr(right, "_ndarray", right).base
if lbase is not None and lbase is rbase:
# If these share data, then setitem could corrupt our IA
right = right.copy()

if isinstance(left, np.ndarray):
combined = np.concatenate(
[left.reshape(-1, 1), right.reshape(-1, 1)],
axis=1,
)
else:
combined = type(left)._concat_same_type(
[left.reshape(-1, 1), right.reshape(-1, 1)],
axis=1,
)
return combined
4 changes: 2 additions & 2 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ def delete(self, loc):
"""
new_left = self.left.delete(loc)
new_right = self.right.delete(loc)
result = self._data._shallow_copy(new_left, new_right)
result = IntervalArray.from_arrays(new_left, new_right, closed=self.closed)
return self._shallow_copy(result)

def insert(self, loc, item):
Expand All @@ -918,7 +918,7 @@ def insert(self, loc, item):

new_left = self.left.insert(loc, left_insert)
new_right = self.right.insert(loc, right_insert)
result = self._data._shallow_copy(new_left, new_right)
result = IntervalArray.from_arrays(new_left, new_right, closed=self.closed)
return self._shallow_copy(result)

@Appender(_index_shared_docs["take"] % _index_doc_kwargs)
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/base/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def test_numpy_array_all_dtypes(any_numpy_dtype):
(pd.Categorical(["a", "b"]), "_codes"),
(pd.core.arrays.period_array(["2000", "2001"], freq="D"), "_data"),
(pd.core.arrays.integer_array([0, np.nan]), "_data"),
(IntervalArray.from_breaks([0, 1]), "_left"),
(IntervalArray.from_breaks([0, 1]), "_combined"),
(SparseArray([0, 1]), "_sparse_values"),
(DatetimeArray(np.array([1, 2], dtype="datetime64[ns]")), "_data"),
# tz-aware Datetime
Expand Down
6 changes: 5 additions & 1 deletion pandas/tests/indexes/interval/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,11 @@ def test_left_right_dont_share_data(self):
# GH#36310
breaks = np.arange(5)
result = IntervalIndex.from_breaks(breaks)._data
assert result._left.base is None or result._left.base is not result._right.base
left = result._left
right = result._right

left[:] = 10000
assert not (right == 10000).any()


class TestFromTuples(Base):
Expand Down