diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py
index 46a7351443883..efb66c9a47a97 100644
--- a/pandas/core/arrays/interval.py
+++ b/pandas/core/arrays/interval.py
@@ -1,7 +1,7 @@
 import operator
 from operator import le, lt
 import textwrap
-from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Type, TypeVar, Union, cast
+from typing import Sequence, Type, TypeVar
 
 import numpy as np
 
@@ -14,7 +14,6 @@
     intervals_to_interval_bounds,
 )
 from pandas._libs.missing import NA
-from pandas._typing import ArrayLike, Dtype
 from pandas.compat.numpy import function as nv
 from pandas.util._decorators import Appender
 
@@ -22,9 +21,7 @@
 from pandas.core.dtypes.common import (
     is_categorical_dtype,
     is_datetime64_any_dtype,
-    is_dtype_equal,
     is_float_dtype,
-    is_integer,
     is_integer_dtype,
     is_interval_dtype,
     is_list_like,
@@ -52,10 +49,6 @@
 from pandas.core.indexes.base import ensure_index
 from pandas.core.ops import invalid_comparison, unpack_zerodim_and_defer
 
-if TYPE_CHECKING:
-    from pandas import Index
-    from pandas.core.arrays import DatetimeArray, TimedeltaArray
-
 IntervalArrayT = TypeVar("IntervalArrayT", bound="IntervalArray")
 
 _interval_shared_docs = {}
@@ -182,17 +175,6 @@ def __new__(
             left = data._left
             right = data._right
             closed = closed or data.closed
-
-            if dtype is None or data.dtype == dtype:
-                # This path will preserve id(result._combined)
-                # TODO: could also validate dtype before going to simple_new
-                combined = data._combined
-                if copy:
-                    combined = combined.copy()
-                result = cls._simple_new(combined, closed=closed)
-                if verify_integrity:
-                    result._validate()
-                return result
         else:
 
             # don't allow scalars
@@ -210,22 +192,83 @@ def __new__(
             )
             closed = closed or infer_closed
 
-        closed = closed or "right"
-        left, right = _maybe_cast_inputs(left, right, copy, dtype)
-        combined = _get_combined_data(left, right)
-        result = cls._simple_new(combined, closed=closed)
-        if verify_integrity:
-            result._validate()
-        return result
+        return cls._simple_new(
+            left,
+            right,
+            closed,
+            copy=copy,
+            dtype=dtype,
+            verify_integrity=verify_integrity,
+        )
 
     @classmethod
-    def _simple_new(cls, data, closed="right"):
+    def _simple_new(
+        cls, left, right, closed=None, copy=False, dtype=None, verify_integrity=True
+    ):
         result = IntervalMixin.__new__(cls)
-        result._combined = data
-        result._left = data[:, 0]
-        result._right = data[:, 1]
+        closed = closed or "right"
+        left = ensure_index(left, copy=copy)
+        right = ensure_index(right, copy=copy)
+
+        if dtype is not None:
+            # GH 19262: dtype must be an IntervalDtype to override inferred
+            dtype = pandas_dtype(dtype)
+            if not is_interval_dtype(dtype):
+                msg = f"dtype must be an IntervalDtype, got {dtype}"
+                raise TypeError(msg)
+            elif dtype.subtype is not None:
+                left = left.astype(dtype.subtype)
+                right = right.astype(dtype.subtype)
+
+        # coerce dtypes to match if needed
+        if is_float_dtype(left) and is_integer_dtype(right):
+            right = right.astype(left.dtype)
+        elif is_float_dtype(right) and is_integer_dtype(left):
+            left = left.astype(right.dtype)
+
+        if type(left) != type(right):
+            msg = (
+                f"must not have differing left [{type(left).__name__}] and "
+                f"right [{type(right).__name__}] types"
+            )
+            raise ValueError(msg)
+        elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype):
+            # GH 19016
+            msg = (
+                "category, object, and string subtypes are not supported "
+                "for IntervalArray"
+            )
+            raise TypeError(msg)
+        elif isinstance(left, ABCPeriodIndex):
+            msg = "Period dtypes are not supported, use a PeriodIndex instead"
+            raise ValueError(msg)
+        elif isinstance(left, ABCDatetimeIndex) and str(left.tz) != str(right.tz):
+            msg = (
+                "left and right must have the same time zone, got "
+                f"'{left.tz}' and '{right.tz}'"
+            )
+            raise ValueError(msg)
+
+        # For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
+        from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array
+
+        left = maybe_upcast_datetimelike_array(left)
+        left = extract_array(left, extract_numpy=True)
+        right = maybe_upcast_datetimelike_array(right)
+        right = extract_array(right, extract_numpy=True)
+
+        lbase = getattr(left, "_ndarray", left).base
+        rbase = getattr(right, "_ndarray", right).base
+        if lbase is not None and lbase is rbase:
+            # If these share data, then setitem could corrupt our IA
+            right = right.copy()
+
+        result._left = left
+        result._right = right
         result._closed = closed
+        if verify_integrity:
+            result._validate()
         return result
 
     @classmethod
@@ -360,16 +403,10 @@ def from_breaks(cls, breaks, closed="right", copy=False, dtype=None):
     def from_arrays(cls, left, right, closed="right", copy=False, dtype=None):
         left = maybe_convert_platform_interval(left)
         right = maybe_convert_platform_interval(right)
 
-        if len(left) != len(right):
-            raise ValueError("left and right must have the same length")
-        closed = closed or "right"
-        left, right = _maybe_cast_inputs(left, right, copy, dtype)
-        combined = _get_combined_data(left, right)
-
-        result = cls._simple_new(combined, closed)
-        result._validate()
-        return result
+        return cls._simple_new(
+            left, right, closed, copy=copy, dtype=dtype, verify_integrity=True
+        )
 
     _interval_shared_docs["from_tuples"] = textwrap.dedent(
         """
@@ -475,6 +512,19 @@ def _validate(self):
             msg = "left side of interval must be <= right side"
             raise ValueError(msg)
 
+    def _shallow_copy(self, left, right):
+        """
+        Return a new IntervalArray with the replacement attributes
+
+        Parameters
+        ----------
+        left : Index
+            Values to be used for the left-side of the intervals.
+        right : Index
+            Values to be used for the right-side of the intervals.
+        """
+        return self._simple_new(left, right, closed=self.closed, verify_integrity=False)
+
     # ---------------------------------------------------------------------
     # Descriptive
 
@@ -502,20 +552,18 @@ def __len__(self) -> int:
 
     def __getitem__(self, key):
         key = check_array_indexer(self, key)
+        left = self._left[key]
+        right = self._right[key]
 
-        result = self._combined[key]
-
-        if is_integer(key):
-            left, right = result[0], result[1]
-            if isna(left):
+        if not isinstance(left, (np.ndarray, ExtensionArray)):
+            # scalar
+            if is_scalar(left) and isna(left):
                 return self._fill_value
             return Interval(left, right, self.closed)
-
-        # TODO: need to watch out for incorrectly-reducing getitem
-        if np.ndim(result) > 2:
+        if np.ndim(left) > 1:
             # GH#30588 multi-dimensional indexer disallowed
             raise ValueError("multi-dimensional indexing not allowed")
-        return type(self)._simple_new(result, closed=self.closed)
+        return self._shallow_copy(left, right)
 
     def __setitem__(self, key, value):
         value_left, value_right = self._validate_setitem_value(value)
@@ -673,8 +721,7 @@ def fillna(self, value=None, method=None, limit=None):
 
         left = self.left.fillna(value=value_left)
         right = self.right.fillna(value=value_right)
-        combined = _get_combined_data(left, right)
-        return type(self)._simple_new(combined, closed=self.closed)
+        return self._shallow_copy(left, right)
 
     def astype(self, dtype, copy=True):
         """
@@ -716,9 +763,7 @@ def astype(self, dtype, copy=True):
                     f"Cannot convert {self.dtype} to {dtype}; subtypes are incompatible"
                 )
                 raise TypeError(msg) from err
-            # TODO: do astype directly on self._combined
-            combined = _get_combined_data(new_left, new_right)
-            return type(self)._simple_new(combined, closed=self.closed)
+            return self._shallow_copy(new_left, new_right)
         elif is_categorical_dtype(dtype):
            return Categorical(np.asarray(self), dtype=dtype)
         elif isinstance(dtype, StringDtype):
@@ -761,11 +806,9 @@ def _concat_same_type(
             raise ValueError("Intervals must all be closed on the same side.")
         closed = closed.pop()
 
-        # TODO: will this mess up on dt64tz?
         left = np.concatenate([interval.left for interval in to_concat])
         right = np.concatenate([interval.right for interval in to_concat])
-        combined = _get_combined_data(left, right)  # TODO: 1-stage concat
-        return cls._simple_new(combined, closed=closed)
+        return cls._simple_new(left, right, closed=closed, copy=False)
 
     def copy(self: IntervalArrayT) -> IntervalArrayT:
         """
@@ -775,8 +818,11 @@ def copy(self: IntervalArrayT) -> IntervalArrayT:
         -------
         IntervalArray
         """
-        combined = self._combined.copy()
-        return type(self)._simple_new(combined, closed=self.closed)
+        left = self._left.copy()
+        right = self._right.copy()
+        closed = self.closed
+        # TODO: Could skip verify_integrity here.
+        return type(self).from_arrays(left, right, closed=closed)
 
     def isna(self) -> np.ndarray:
         return isna(self._left)
@@ -869,8 +915,7 @@ def take(self, indices, *, allow_fill=False, fill_value=None, axis=None, **kwarg
             self._right, indices, allow_fill=allow_fill, fill_value=fill_right
         )
 
-        combined = _get_combined_data(left_take, right_take)
-        return type(self)._simple_new(combined, closed=self.closed)
+        return self._shallow_copy(left_take, right_take)
 
     def _validate_listlike(self, value):
         # list-like of intervals
@@ -1183,7 +1228,10 @@ def set_closed(self, closed):
         if closed not in VALID_CLOSED:
             msg = f"invalid option for 'closed': {closed}"
             raise ValueError(msg)
-        return type(self)._simple_new(self._combined, closed=closed)
+
+        return type(self)._simple_new(
+            left=self._left, right=self._right, closed=closed, verify_integrity=False
+        )
 
     _interval_shared_docs[
         "is_non_overlapping_monotonic"
     ] = """
@@ -1324,8 +1372,9 @@ def to_tuples(self, na_tuple=True):
     @Appender(_extension_array_shared_docs["repeat"] % _shared_docs_kwargs)
     def repeat(self, repeats, axis=None):
         nv.validate_repeat(tuple(), dict(axis=axis))
-        combined = self._combined.repeat(repeats, 0)
-        return type(self)._simple_new(combined, closed=self.closed)
+        left_repeat = self.left.repeat(repeats)
+        right_repeat = self.right.repeat(repeats)
+        return self._shallow_copy(left=left_repeat, right=right_repeat)
 
     _interval_shared_docs["contains"] = textwrap.dedent(
         """
@@ -1408,101 +1457,3 @@ def maybe_convert_platform_interval(values):
         values = np.asarray(values)
 
     return maybe_convert_platform(values)
-
-
-def _maybe_cast_inputs(
-    left_orig: Union["Index", ArrayLike],
-    right_orig: Union["Index", ArrayLike],
-    copy: bool,
-    dtype: Optional[Dtype],
-) -> Tuple["Index", "Index"]:
-    left = ensure_index(left_orig, copy=copy)
-    right = ensure_index(right_orig, copy=copy)
-
-    if dtype is not None:
-        # GH#19262: dtype must be an IntervalDtype to override inferred
-        dtype = pandas_dtype(dtype)
-        if not is_interval_dtype(dtype):
-            msg = f"dtype must be an IntervalDtype, got {dtype}"
-            raise TypeError(msg)
-        dtype = cast(IntervalDtype, dtype)
-        if dtype.subtype is not None:
-            left = left.astype(dtype.subtype)
-            right = right.astype(dtype.subtype)
-
-    # coerce dtypes to match if needed
-    if is_float_dtype(left) and is_integer_dtype(right):
-        right = right.astype(left.dtype)
-    elif is_float_dtype(right) and is_integer_dtype(left):
-        left = left.astype(right.dtype)
-
-    if type(left) != type(right):
-        msg = (
-            f"must not have differing left [{type(left).__name__}] and "
-            f"right [{type(right).__name__}] types"
-        )
-        raise ValueError(msg)
-    elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype):
-        # GH#19016
-        msg = (
-            "category, object, and string subtypes are not supported "
-            "for IntervalArray"
-        )
-        raise TypeError(msg)
-    elif isinstance(left, ABCPeriodIndex):
-        msg = "Period dtypes are not supported, use a PeriodIndex instead"
-        raise ValueError(msg)
-    elif isinstance(left, ABCDatetimeIndex) and not is_dtype_equal(
-        left.dtype, right.dtype
-    ):
-        left_arr = cast("DatetimeArray", left._data)
-        right_arr = cast("DatetimeArray", right._data)
-        msg = (
-            "left and right must have the same time zone, got "
-            f"'{left_arr.tz}' and '{right_arr.tz}'"
-        )
-        raise ValueError(msg)
-
-    return left, right
-
-
-def _get_combined_data(
-    left: Union["Index", ArrayLike], right: Union["Index", ArrayLike]
-) -> Union[np.ndarray, "DatetimeArray", "TimedeltaArray"]:
-    # For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
-    from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array
-
-    left = maybe_upcast_datetimelike_array(left)
-    left = extract_array(left, extract_numpy=True)
-    right = maybe_upcast_datetimelike_array(right)
-    right = extract_array(right, extract_numpy=True)
-
-    lbase = getattr(left, "_ndarray", left).base
-    rbase = getattr(right, "_ndarray", right).base
-    if lbase is not None and lbase is rbase:
-        # If these share data, then setitem could corrupt our IA
-        right = right.copy()
-
-    if isinstance(left, np.ndarray):
-        assert isinstance(right, np.ndarray)  # for mypy
-        combined = np.concatenate(
-            [left.reshape(-1, 1), right.reshape(-1, 1)],
-            axis=1,
-        )
-    else:
-        # error: Item "type" of "Union[Type[Index], Type[ExtensionArray]]" has
-        #  no attribute "_concat_same_type" [union-attr]
-
-        # error: Unexpected keyword argument "axis" for "_concat_same_type" of
-        #  "ExtensionArray" [call-arg]
-
-        # error: Item "Index" of "Union[Index, ExtensionArray]" has no
-        #  attribute "reshape" [union-attr]
-
-        # error: Item "ExtensionArray" of "Union[Index, ExtensionArray]" has no
-        #  attribute "reshape" [union-attr]
-        combined = type(left)._concat_same_type(  # type: ignore[union-attr,call-arg]
-            [left.reshape(-1, 1), right.reshape(-1, 1)],  # type: ignore[union-attr]
-            axis=1,
-        )
-    return combined
diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py
index 0a10191bfac52..98752a21e44a2 100644
--- a/pandas/core/indexes/interval.py
+++ b/pandas/core/indexes/interval.py
@@ -872,7 +872,7 @@ def delete(self, loc):
         """
         new_left = self.left.delete(loc)
         new_right = self.right.delete(loc)
-        result = IntervalArray.from_arrays(new_left, new_right, closed=self.closed)
+        result = self._data._shallow_copy(new_left, new_right)
         return type(self)._simple_new(result, name=self.name)
 
     def insert(self, loc, item):
@@ -894,7 +894,7 @@ def insert(self, loc, item):
 
         new_left = self.left.insert(loc, left_insert)
         new_right = self.right.insert(loc, right_insert)
-        result = IntervalArray.from_arrays(new_left, new_right, closed=self.closed)
+        result = self._data._shallow_copy(new_left, new_right)
         return type(self)._simple_new(result, name=self.name)
 
     # --------------------------------------------------------------------
diff --git a/pandas/tests/base/test_conversion.py b/pandas/tests/base/test_conversion.py
index 24e88824088be..63280f5ccf8cd 100644
--- a/pandas/tests/base/test_conversion.py
+++ b/pandas/tests/base/test_conversion.py
@@ -241,7 +241,7 @@ def test_numpy_array_all_dtypes(any_numpy_dtype):
         (pd.Categorical(["a", "b"]), "_codes"),
         (pd.core.arrays.period_array(["2000", "2001"], freq="D"), "_data"),
         (pd.core.arrays.integer_array([0, np.nan]), "_data"),
-        (IntervalArray.from_breaks([0, 1]), "_combined"),
+        (IntervalArray.from_breaks([0, 1]), "_left"),
         (SparseArray([0, 1]), "_sparse_values"),
         (DatetimeArray(np.array([1, 2], dtype="datetime64[ns]")), "_data"),
         # tz-aware Datetime
diff --git a/pandas/tests/indexes/interval/test_constructors.py b/pandas/tests/indexes/interval/test_constructors.py
index c0ca0b415ba8e..aec7de549744f 100644
--- a/pandas/tests/indexes/interval/test_constructors.py
+++ b/pandas/tests/indexes/interval/test_constructors.py
@@ -266,11 +266,7 @@ def test_left_right_dont_share_data(self):
         # GH#36310
         breaks = np.arange(5)
         result = IntervalIndex.from_breaks(breaks)._data
-        left = result._left
-        right = result._right
-
-        left[:] = 10000
-        assert not (right == 10000).any()
+        assert result._left.base is None or result._left.base is not result._right.base
 
 
 class TestFromTuples(Base):
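Note (not part of the patch): a minimal sketch of the invariant that the rewritten _simple_new enforces and that the updated test_left_right_dont_share_data exercises. When both endpoint arrays are views of the same buffer, as with from_breaks, one side is copied so that writing through one side cannot corrupt the other (GH 36310). The _left and _right attributes are internal and are used here only for illustration; the snippet assumes a pandas build that includes this change.

import numpy as np
import pandas as pd

# from_breaks slices the same `breaks` buffer for both endpoint arrays;
# _simple_new copies one side when their ndarray bases match (assumes a
# pandas build containing this patch), so the assertion below holds.
breaks = np.arange(5)
arr = pd.arrays.IntervalArray.from_breaks(breaks)  # (0, 1], (1, 2], (2, 3], (3, 4]

left, right = arr._left, arr._right  # internal attributes, for illustration only
assert left.base is None or left.base is not right.base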