diff --git a/pandas/_typing.py b/pandas/_typing.py index 1ba5be8b5b0ed..8d3044a978291 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -44,6 +44,10 @@ from pandas.core.dtypes.dtypes import ExtensionDtype from pandas import Interval + from pandas.arrays import ( + DatetimeArray, + TimedeltaArray, + ) from pandas.core.arrays.base import ExtensionArray from pandas.core.frame import DataFrame from pandas.core.generic import NDFrame @@ -88,6 +92,7 @@ ArrayLike = Union["ExtensionArray", np.ndarray] AnyArrayLike = Union[ArrayLike, "Index", "Series"] +TimeArrayLike = Union["DatetimeArray", "TimedeltaArray"] # scalars diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 3c6686b5c0173..cc72e7a290f62 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -39,6 +39,7 @@ ScalarIndexer, SequenceIndexer, SortKind, + TimeArrayLike, npt, ) from pandas.compat.numpy import function as nv @@ -82,6 +83,8 @@ ExtensionArray, _extension_array_shared_docs, ) +from pandas.core.arrays.datetimes import DatetimeArray +from pandas.core.arrays.timedeltas import TimedeltaArray import pandas.core.common as com from pandas.core.construction import ( array as pd_array, @@ -102,6 +105,7 @@ IntervalArrayT = TypeVar("IntervalArrayT", bound="IntervalArray") +IntervalSideT = Union[TimeArrayLike, np.ndarray] IntervalOrNA = Union[Interval, float] _interval_shared_docs: dict[str, str] = {} @@ -123,8 +127,8 @@ Parameters ---------- data : array-like (1-dimensional) - Array-like containing Interval objects from which to build the - %(klass)s. + Array-like (ndarray, :class:`DatetimeArray`, :class:`TimedeltaArray`) containing + Interval objects from which to build the %(klass)s. closed : {'left', 'right', 'both', 'neither'}, default 'right' Whether the intervals are closed on the left-side, right-side, both or neither. 
@@ -213,8 +217,8 @@ def ndim(self) -> Literal[1]: return 1 # To make mypy recognize the fields - _left: np.ndarray - _right: np.ndarray + _left: IntervalSideT + _right: IntervalSideT _dtype: IntervalDtype # --------------------------------------------------------------------- @@ -232,9 +236,10 @@ def __new__( data = extract_array(data, extract_numpy=True) if isinstance(data, cls): - left = data._left - right = data._right + left: IntervalSideT = data._left + right: IntervalSideT = data._right closed = closed or data.closed + dtype = IntervalDtype(left.dtype, closed=closed) else: # don't allow scalars @@ -255,37 +260,57 @@ def __new__( right = lib.maybe_convert_objects(right) closed = closed or infer_closed + left, right, dtype = cls._ensure_simple_new_inputs( + left, + right, + closed=closed, + copy=copy, + dtype=dtype, + ) + + if verify_integrity: + cls._validate(left, right, dtype=dtype) + return cls._simple_new( left, right, - closed, - copy=copy, dtype=dtype, - verify_integrity=verify_integrity, ) @classmethod def _simple_new( cls: type[IntervalArrayT], + left: IntervalSideT, + right: IntervalSideT, + dtype: IntervalDtype, + ) -> IntervalArrayT: + result = IntervalMixin.__new__(cls) + result._left = left + result._right = right + result._dtype = dtype + + return result + + @classmethod + def _ensure_simple_new_inputs( + cls, left, right, closed: IntervalClosedType | None = None, copy: bool = False, dtype: Dtype | None = None, - verify_integrity: bool = True, - ) -> IntervalArrayT: - result = IntervalMixin.__new__(cls) + ) -> tuple[IntervalSideT, IntervalSideT, IntervalDtype]: + """Ensure correctness of input parameters for cls._simple_new.""" + from pandas.core.indexes.base import ensure_index + + left = ensure_index(left, copy=copy) + right = ensure_index(right, copy=copy) if closed is None and isinstance(dtype, IntervalDtype): closed = dtype.closed closed = closed or "right" - from pandas.core.indexes.base import ensure_index - - left = ensure_index(left, 
copy=copy) - right = ensure_index(right, copy=copy) - if dtype is not None: # GH 19262: dtype must be an IntervalDtype to override inferred dtype = pandas_dtype(dtype) @@ -346,13 +371,8 @@ def _simple_new( right = right.copy() dtype = IntervalDtype(left.dtype, closed=closed) - result._dtype = dtype - result._left = left - result._right = right - if verify_integrity: - result._validate() - return result + return left, right, dtype @classmethod def _from_sequence( @@ -512,9 +532,16 @@ def from_arrays( left = _maybe_convert_platform_interval(left) right = _maybe_convert_platform_interval(right) - return cls._simple_new( - left, right, closed, copy=copy, dtype=dtype, verify_integrity=True + left, right, dtype = cls._ensure_simple_new_inputs( + left, + right, + closed=closed, + copy=copy, + dtype=dtype, ) + cls._validate(left, right, dtype=dtype) + + return cls._simple_new(left, right, dtype=dtype) _interval_shared_docs["from_tuples"] = textwrap.dedent( """ @@ -599,32 +626,33 @@ def from_tuples( return cls.from_arrays(left, right, closed, copy=False, dtype=dtype) - def _validate(self): + @classmethod + def _validate(cls, left, right, dtype: IntervalDtype) -> None: """ Verify that the IntervalArray is valid. 
Checks that - * closed is valid + * dtype is correct * left and right match lengths * left and right have the same missing values * left is always below right """ - if self.closed not in VALID_CLOSED: - msg = f"invalid option for 'closed': {self.closed}" + if not isinstance(dtype, IntervalDtype): + msg = f"invalid dtype: {dtype}" raise ValueError(msg) - if len(self._left) != len(self._right): + if len(left) != len(right): msg = "left and right must have the same length" raise ValueError(msg) - left_mask = notna(self._left) - right_mask = notna(self._right) + left_mask = notna(left) + right_mask = notna(right) if not (left_mask == right_mask).all(): msg = ( "missing values must be missing in the same " "location both left and right sides" ) raise ValueError(msg) - if not (self._left[left_mask] <= self._right[left_mask]).all(): + if not (left[left_mask] <= right[left_mask]).all(): msg = "left side of interval must be <= right side" raise ValueError(msg) @@ -639,7 +667,11 @@ def _shallow_copy(self: IntervalArrayT, left, right) -> IntervalArrayT: right : Index Values to be used for the right-side of the intervals. 
""" - return self._simple_new(left, right, closed=self.closed, verify_integrity=False) + dtype = IntervalDtype(left.dtype, closed=self.closed) + left, right, dtype = self._ensure_simple_new_inputs(left, right, dtype=dtype) + self._validate(left, right, dtype=dtype) + + return self._simple_new(left, right, dtype=dtype) # --------------------------------------------------------------------- # Descriptive @@ -986,7 +1018,10 @@ def _concat_same_type( left = np.concatenate([interval.left for interval in to_concat]) right = np.concatenate([interval.right for interval in to_concat]) - return cls._simple_new(left, right, closed=closed, copy=False) + + left, right, dtype = cls._ensure_simple_new_inputs(left, right, closed=closed) + + return cls._simple_new(left, right, dtype=dtype) def copy(self: IntervalArrayT) -> IntervalArrayT: """ @@ -998,9 +1033,8 @@ def copy(self: IntervalArrayT) -> IntervalArrayT: """ left = self._left.copy() right = self._right.copy() - closed = self.closed - # TODO: Could skip verify_integrity here. 
- return type(self).from_arrays(left, right, closed=closed) + dtype = self.dtype + return self._simple_new(left, right, dtype=dtype) def isna(self) -> np.ndarray: return isna(self._left) @@ -1400,9 +1434,9 @@ def set_closed(self: IntervalArrayT, closed: IntervalClosedType) -> IntervalArra msg = f"invalid option for 'closed': {closed}" raise ValueError(msg) - return type(self)._simple_new( - left=self._left, right=self._right, closed=closed, verify_integrity=False - ) + left, right = self._left, self._right + dtype = IntervalDtype(left.dtype, closed=closed) + return self._simple_new(left, right, dtype=dtype) _interval_shared_docs[ "is_non_overlapping_monotonic" @@ -1544,9 +1578,11 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None: if isinstance(self._left, np.ndarray): np.putmask(self._left, mask, value_left) + assert isinstance(self._right, np.ndarray) np.putmask(self._right, mask, value_right) else: self._left._putmask(mask, value_left) + assert not isinstance(self._right, np.ndarray) self._right._putmask(mask, value_right) def insert(self: IntervalArrayT, loc: int, item: Interval) -> IntervalArrayT: @@ -1574,9 +1610,11 @@ def insert(self: IntervalArrayT, loc: int, item: Interval) -> IntervalArrayT: def delete(self: IntervalArrayT, loc) -> IntervalArrayT: if isinstance(self._left, np.ndarray): new_left = np.delete(self._left, loc) + assert isinstance(self._right, np.ndarray) new_right = np.delete(self._right, loc) else: new_left = self._left.delete(loc) + assert not isinstance(self._right, np.ndarray) new_right = self._right.delete(loc) return self._shallow_copy(left=new_left, right=new_right) @@ -1677,7 +1715,7 @@ def isin(self, values) -> npt.NDArray[np.bool_]: return isin(self.astype(object), values.astype(object)) @property - def _combined(self) -> ArrayLike: + def _combined(self) -> IntervalSideT: left = self.left._values.reshape(-1, 1) right = self.right._values.reshape(-1, 1) if needs_i8_conversion(left.dtype): @@ -1694,15 +1732,12 @@ def 
_from_combined(self, combined: np.ndarray) -> IntervalArray: dtype = self._left.dtype if needs_i8_conversion(dtype): - # error: "Type[ndarray[Any, Any]]" has no attribute "_from_sequence" - new_left = type(self._left)._from_sequence( # type: ignore[attr-defined] - nc[:, 0], dtype=dtype - ) - # error: "Type[ndarray[Any, Any]]" has no attribute "_from_sequence" - new_right = type(self._right)._from_sequence( # type: ignore[attr-defined] - nc[:, 1], dtype=dtype - ) + assert isinstance(self._left, (DatetimeArray, TimedeltaArray)) + new_left = type(self._left)._from_sequence(nc[:, 0], dtype=dtype) + assert isinstance(self._right, (DatetimeArray, TimedeltaArray)) + new_right = type(self._right)._from_sequence(nc[:, 1], dtype=dtype) else: + assert isinstance(dtype, np.dtype) new_left = nc[:, 0].view(dtype) new_right = nc[:, 1].view(dtype) return self._shallow_copy(left=new_left, right=new_right)