Skip to content

Commit 089fad9

Browse files
authored
REF: Back IntervalArray by array instead of Index (#36310)
1 parent 456fcb9 commit 089fad9

File tree

8 files changed

+123
-80
lines changed

8 files changed

+123
-80
lines changed

doc/source/whatsnew/v1.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ Performance improvements
282282
- Performance improvement in :meth:`GroupBy.transform` with the ``numba`` engine (:issue:`36240`)
283283
- ``Styler`` uuid method altered to compress data transmission over web whilst maintaining reasonably low table collision probability (:issue:`36345`)
284284
- Performance improvement in :meth:`pd.to_datetime` with non-ns time unit for ``float`` ``dtype`` columns (:issue:`20445`)
285+
- Performance improvement in setting values on a :class:`IntervalArray` (:issue:`36310`)
285286

286287
.. ---------------------------------------------------------------------------
287288

pandas/_testing.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -977,8 +977,14 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray")
977977
"""
978978
_check_isinstance(left, right, IntervalArray)
979979

980-
assert_index_equal(left.left, right.left, exact=exact, obj=f"{obj}.left")
981-
assert_index_equal(left.right, right.right, exact=exact, obj=f"{obj}.left")
980+
kwargs = {}
981+
if left._left.dtype.kind in ["m", "M"]:
982+
# We have a DatetimeArray or TimedeltaArray
983+
kwargs["check_freq"] = False
984+
985+
assert_equal(left._left, right._left, obj=f"{obj}.left", **kwargs)
986+
assert_equal(left._right, right._right, obj=f"{obj}.left", **kwargs)
987+
982988
assert_attr_equal("closed", left, right, obj=obj)
983989

984990

@@ -989,20 +995,22 @@ def assert_period_array_equal(left, right, obj="PeriodArray"):
989995
assert_attr_equal("freq", left, right, obj=obj)
990996

991997

992-
def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
998+
def assert_datetime_array_equal(left, right, obj="DatetimeArray", check_freq=True):
993999
__tracebackhide__ = True
9941000
_check_isinstance(left, right, DatetimeArray)
9951001

9961002
assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
997-
assert_attr_equal("freq", left, right, obj=obj)
1003+
if check_freq:
1004+
assert_attr_equal("freq", left, right, obj=obj)
9981005
assert_attr_equal("tz", left, right, obj=obj)
9991006

10001007

1001-
def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"):
1008+
def assert_timedelta_array_equal(left, right, obj="TimedeltaArray", check_freq=True):
10021009
__tracebackhide__ = True
10031010
_check_isinstance(left, right, TimedeltaArray)
10041011
assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
1005-
assert_attr_equal("freq", left, right, obj=obj)
1012+
if check_freq:
1013+
assert_attr_equal("freq", left, right, obj=obj)
10061014

10071015

10081016
def raise_assert_detail(obj, message, left, right, diff=None, index_values=None):

pandas/core/arrays/interval.py

+71-55
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from pandas.core.dtypes.dtypes import IntervalDtype
3232
from pandas.core.dtypes.generic import (
3333
ABCDatetimeIndex,
34-
ABCIndexClass,
3534
ABCIntervalIndex,
3635
ABCPeriodIndex,
3736
ABCSeries,
@@ -42,7 +41,7 @@
4241
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs
4342
from pandas.core.arrays.categorical import Categorical
4443
import pandas.core.common as com
45-
from pandas.core.construction import array
44+
from pandas.core.construction import array, extract_array
4645
from pandas.core.indexers import check_array_indexer
4746
from pandas.core.indexes.base import ensure_index
4847

@@ -161,12 +160,14 @@ def __new__(
161160
verify_integrity: bool = True,
162161
):
163162

164-
if isinstance(data, ABCSeries) and is_interval_dtype(data.dtype):
165-
data = data._values
163+
if isinstance(data, (ABCSeries, ABCIntervalIndex)) and is_interval_dtype(
164+
data.dtype
165+
):
166+
data = data._values # TODO: extract_array?
166167

167-
if isinstance(data, (cls, ABCIntervalIndex)):
168-
left = data.left
169-
right = data.right
168+
if isinstance(data, cls):
169+
left = data._left
170+
right = data._right
170171
closed = closed or data.closed
171172
else:
172173

@@ -243,6 +244,20 @@ def _simple_new(
243244
)
244245
raise ValueError(msg)
245246

247+
# For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
248+
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array
249+
250+
left = maybe_upcast_datetimelike_array(left)
251+
left = extract_array(left, extract_numpy=True)
252+
right = maybe_upcast_datetimelike_array(right)
253+
right = extract_array(right, extract_numpy=True)
254+
255+
lbase = getattr(left, "_ndarray", left).base
256+
rbase = getattr(right, "_ndarray", right).base
257+
if lbase is not None and lbase is rbase:
258+
# If these share data, then setitem could corrupt our IA
259+
right = right.copy()
260+
246261
result._left = left
247262
result._right = right
248263
result._closed = closed
@@ -476,18 +491,18 @@ def _validate(self):
476491
if self.closed not in VALID_CLOSED:
477492
msg = f"invalid option for 'closed': {self.closed}"
478493
raise ValueError(msg)
479-
if len(self.left) != len(self.right):
494+
if len(self._left) != len(self._right):
480495
msg = "left and right must have the same length"
481496
raise ValueError(msg)
482-
left_mask = notna(self.left)
483-
right_mask = notna(self.right)
497+
left_mask = notna(self._left)
498+
right_mask = notna(self._right)
484499
if not (left_mask == right_mask).all():
485500
msg = (
486501
"missing values must be missing in the same "
487502
"location both left and right sides"
488503
)
489504
raise ValueError(msg)
490-
if not (self.left[left_mask] <= self.right[left_mask]).all():
505+
if not (self._left[left_mask] <= self._right[left_mask]).all():
491506
msg = "left side of interval must be <= right side"
492507
raise ValueError(msg)
493508

@@ -527,37 +542,29 @@ def __iter__(self):
527542
return iter(np.asarray(self))
528543

529544
def __len__(self) -> int:
530-
return len(self.left)
545+
return len(self._left)
531546

532547
def __getitem__(self, value):
533548
value = check_array_indexer(self, value)
534-
left = self.left[value]
535-
right = self.right[value]
549+
left = self._left[value]
550+
right = self._right[value]
536551

537-
# scalar
538-
if not isinstance(left, ABCIndexClass):
552+
if not isinstance(left, (np.ndarray, ExtensionArray)):
553+
# scalar
539554
if is_scalar(left) and isna(left):
540555
return self._fill_value
541-
if np.ndim(left) > 1:
542-
# GH#30588 multi-dimensional indexer disallowed
543-
raise ValueError("multi-dimensional indexing not allowed")
544556
return Interval(left, right, self.closed)
545-
557+
if np.ndim(left) > 1:
558+
# GH#30588 multi-dimensional indexer disallowed
559+
raise ValueError("multi-dimensional indexing not allowed")
546560
return self._shallow_copy(left, right)
547561

548562
def __setitem__(self, key, value):
549563
value_left, value_right = self._validate_setitem_value(value)
550564
key = check_array_indexer(self, key)
551565

552-
# Need to ensure that left and right are updated atomically, so we're
553-
# forced to copy, update the copy, and swap in the new values.
554-
left = self.left.copy(deep=True)
555-
left._values[key] = value_left
556-
self._left = left
557-
558-
right = self.right.copy(deep=True)
559-
right._values[key] = value_right
560-
self._right = right
566+
self._left[key] = value_left
567+
self._right[key] = value_right
561568

562569
def __eq__(self, other):
563570
# ensure pandas array for list-like and eliminate non-interval scalars
@@ -588,7 +595,7 @@ def __eq__(self, other):
588595
if is_interval_dtype(other_dtype):
589596
if self.closed != other.closed:
590597
return np.zeros(len(self), dtype=bool)
591-
return (self.left == other.left) & (self.right == other.right)
598+
return (self._left == other.left) & (self._right == other.right)
592599

593600
# non-interval/non-object dtype -> no matches
594601
if not is_object_dtype(other_dtype):
@@ -601,8 +608,8 @@ def __eq__(self, other):
601608
if (
602609
isinstance(obj, Interval)
603610
and self.closed == obj.closed
604-
and self.left[i] == obj.left
605-
and self.right[i] == obj.right
611+
and self._left[i] == obj.left
612+
and self._right[i] == obj.right
606613
):
607614
result[i] = True
608615

@@ -665,6 +672,7 @@ def astype(self, dtype, copy=True):
665672
array : ExtensionArray or ndarray
666673
ExtensionArray or NumPy ndarray with 'dtype' for its dtype.
667674
"""
675+
from pandas import Index
668676
from pandas.core.arrays.string_ import StringDtype
669677

670678
if dtype is not None:
@@ -676,8 +684,10 @@ def astype(self, dtype, copy=True):
676684

677685
# need to cast to different subtype
678686
try:
679-
new_left = self.left.astype(dtype.subtype)
680-
new_right = self.right.astype(dtype.subtype)
687+
# We need to use Index rules for astype to prevent casting
688+
# np.nan entries to int subtypes
689+
new_left = Index(self._left, copy=False).astype(dtype.subtype)
690+
new_right = Index(self._right, copy=False).astype(dtype.subtype)
681691
except TypeError as err:
682692
msg = (
683693
f"Cannot convert {self.dtype} to {dtype}; subtypes are incompatible"
@@ -726,14 +736,14 @@ def copy(self):
726736
-------
727737
IntervalArray
728738
"""
729-
left = self.left.copy(deep=True)
730-
right = self.right.copy(deep=True)
739+
left = self._left.copy()
740+
right = self._right.copy()
731741
closed = self.closed
732742
# TODO: Could skip verify_integrity here.
733743
return type(self).from_arrays(left, right, closed=closed)
734744

735-
def isna(self):
736-
return isna(self.left)
745+
def isna(self) -> np.ndarray:
746+
return isna(self._left)
737747

738748
def shift(self, periods: int = 1, fill_value: object = None) -> "IntervalArray":
739749
if not len(self) or periods == 0:
@@ -749,7 +759,9 @@ def shift(self, periods: int = 1, fill_value: object = None) -> "IntervalArray":
749759

750760
empty_len = min(abs(periods), len(self))
751761
if isna(fill_value):
752-
fill_value = self.left._na_value
762+
from pandas import Index
763+
764+
fill_value = Index(self._left, copy=False)._na_value
753765
empty = IntervalArray.from_breaks([fill_value] * (empty_len + 1))
754766
else:
755767
empty = self._from_sequence([fill_value] * empty_len)
@@ -815,10 +827,10 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs):
815827
fill_left, fill_right = self._validate_fill_value(fill_value)
816828

817829
left_take = take(
818-
self.left, indices, allow_fill=allow_fill, fill_value=fill_left
830+
self._left, indices, allow_fill=allow_fill, fill_value=fill_left
819831
)
820832
right_take = take(
821-
self.right, indices, allow_fill=allow_fill, fill_value=fill_right
833+
self._right, indices, allow_fill=allow_fill, fill_value=fill_right
822834
)
823835

824836
return self._shallow_copy(left_take, right_take)
@@ -977,15 +989,19 @@ def left(self):
977989
Return the left endpoints of each Interval in the IntervalArray as
978990
an Index.
979991
"""
980-
return self._left
992+
from pandas import Index
993+
994+
return Index(self._left, copy=False)
981995

982996
@property
983997
def right(self):
984998
"""
985999
Return the right endpoints of each Interval in the IntervalArray as
9861000
an Index.
9871001
"""
988-
return self._right
1002+
from pandas import Index
1003+
1004+
return Index(self._right, copy=False)
9891005

9901006
@property
9911007
def length(self):
@@ -1146,7 +1162,7 @@ def set_closed(self, closed):
11461162
raise ValueError(msg)
11471163

11481164
return type(self)._simple_new(
1149-
left=self.left, right=self.right, closed=closed, verify_integrity=False
1165+
left=self._left, right=self._right, closed=closed, verify_integrity=False
11501166
)
11511167

11521168
_interval_shared_docs[
@@ -1172,15 +1188,15 @@ def is_non_overlapping_monotonic(self):
11721188
# at a point when both sides of intervals are included
11731189
if self.closed == "both":
11741190
return bool(
1175-
(self.right[:-1] < self.left[1:]).all()
1176-
or (self.left[:-1] > self.right[1:]).all()
1191+
(self._right[:-1] < self._left[1:]).all()
1192+
or (self._left[:-1] > self._right[1:]).all()
11771193
)
11781194

11791195
# non-strict inequality when closed != 'both'; at least one side is
11801196
# not included in the intervals, so equality does not imply overlapping
11811197
return bool(
1182-
(self.right[:-1] <= self.left[1:]).all()
1183-
or (self.left[:-1] >= self.right[1:]).all()
1198+
(self._right[:-1] <= self._left[1:]).all()
1199+
or (self._left[:-1] >= self._right[1:]).all()
11841200
)
11851201

11861202
# ---------------------------------------------------------------------
@@ -1191,8 +1207,8 @@ def __array__(self, dtype=None) -> np.ndarray:
11911207
Return the IntervalArray's data as a numpy array of Interval
11921208
objects (with dtype='object')
11931209
"""
1194-
left = self.left
1195-
right = self.right
1210+
left = self._left
1211+
right = self._right
11961212
mask = self.isna()
11971213
closed = self._closed
11981214

@@ -1222,8 +1238,8 @@ def __arrow_array__(self, type=None):
12221238
interval_type = ArrowIntervalType(subtype, self.closed)
12231239
storage_array = pyarrow.StructArray.from_arrays(
12241240
[
1225-
pyarrow.array(self.left, type=subtype, from_pandas=True),
1226-
pyarrow.array(self.right, type=subtype, from_pandas=True),
1241+
pyarrow.array(self._left, type=subtype, from_pandas=True),
1242+
pyarrow.array(self._right, type=subtype, from_pandas=True),
12271243
],
12281244
names=["left", "right"],
12291245
)
@@ -1277,7 +1293,7 @@ def __arrow_array__(self, type=None):
12771293
_interval_shared_docs["to_tuples"] % dict(return_type="ndarray", examples="")
12781294
)
12791295
def to_tuples(self, na_tuple=True):
1280-
tuples = com.asarray_tuplesafe(zip(self.left, self.right))
1296+
tuples = com.asarray_tuplesafe(zip(self._left, self._right))
12811297
if not na_tuple:
12821298
# GH 18756
12831299
tuples = np.where(~self.isna(), tuples, np.nan)
@@ -1343,8 +1359,8 @@ def contains(self, other):
13431359
if isinstance(other, Interval):
13441360
raise NotImplementedError("contains not implemented for two intervals")
13451361

1346-
return (self.left < other if self.open_left else self.left <= other) & (
1347-
other < self.right if self.open_right else other <= self.right
1362+
return (self._left < other if self.open_left else self._left <= other) & (
1363+
other < self._right if self.open_right else other <= self._right
13481364
)
13491365

13501366

0 commit comments

Comments
 (0)