Skip to content

Commit 9cb3723

Browse files
authored
REF: back IntervalArray by a single ndarray (#37047)
1 parent b526620 commit 9cb3723

File tree

4 files changed

+160
-115
lines changed

4 files changed

+160
-115
lines changed

pandas/core/arrays/interval.py

+152-111
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from operator import le, lt
22
import textwrap
3+
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
34

45
import numpy as np
56

@@ -11,14 +12,17 @@
1112
IntervalMixin,
1213
intervals_to_interval_bounds,
1314
)
15+
from pandas._typing import ArrayLike, Dtype
1416
from pandas.compat.numpy import function as nv
1517
from pandas.util._decorators import Appender
1618

1719
from pandas.core.dtypes.cast import maybe_convert_platform
1820
from pandas.core.dtypes.common import (
1921
is_categorical_dtype,
2022
is_datetime64_any_dtype,
23+
is_dtype_equal,
2124
is_float_dtype,
25+
is_integer,
2226
is_integer_dtype,
2327
is_interval_dtype,
2428
is_list_like,
@@ -45,6 +49,10 @@
4549
from pandas.core.indexers import check_array_indexer
4650
from pandas.core.indexes.base import ensure_index
4751

52+
if TYPE_CHECKING:
53+
from pandas import Index
54+
from pandas.core.arrays import DatetimeArray, TimedeltaArray
55+
4856
_interval_shared_docs = {}
4957

5058
_shared_docs_kwargs = dict(
@@ -169,6 +177,17 @@ def __new__(
169177
left = data._left
170178
right = data._right
171179
closed = closed or data.closed
180+
181+
if dtype is None or data.dtype == dtype:
182+
# This path will preserve id(result._combined)
183+
# TODO: could also validate dtype before going to simple_new
184+
combined = data._combined
185+
if copy:
186+
combined = combined.copy()
187+
result = cls._simple_new(combined, closed=closed)
188+
if verify_integrity:
189+
result._validate()
190+
return result
172191
else:
173192

174193
# don't allow scalars
@@ -186,83 +205,22 @@ def __new__(
186205
)
187206
closed = closed or infer_closed
188207

189-
return cls._simple_new(
190-
left,
191-
right,
192-
closed,
193-
copy=copy,
194-
dtype=dtype,
195-
verify_integrity=verify_integrity,
196-
)
208+
closed = closed or "right"
209+
left, right = _maybe_cast_inputs(left, right, copy, dtype)
210+
combined = _get_combined_data(left, right)
211+
result = cls._simple_new(combined, closed=closed)
212+
if verify_integrity:
213+
result._validate()
214+
return result
197215

198216
@classmethod
199-
def _simple_new(
200-
cls, left, right, closed=None, copy=False, dtype=None, verify_integrity=True
201-
):
217+
def _simple_new(cls, data, closed="right"):
202218
result = IntervalMixin.__new__(cls)
203219

204-
closed = closed or "right"
205-
left = ensure_index(left, copy=copy)
206-
right = ensure_index(right, copy=copy)
207-
208-
if dtype is not None:
209-
# GH 19262: dtype must be an IntervalDtype to override inferred
210-
dtype = pandas_dtype(dtype)
211-
if not is_interval_dtype(dtype):
212-
msg = f"dtype must be an IntervalDtype, got {dtype}"
213-
raise TypeError(msg)
214-
elif dtype.subtype is not None:
215-
left = left.astype(dtype.subtype)
216-
right = right.astype(dtype.subtype)
217-
218-
# coerce dtypes to match if needed
219-
if is_float_dtype(left) and is_integer_dtype(right):
220-
right = right.astype(left.dtype)
221-
elif is_float_dtype(right) and is_integer_dtype(left):
222-
left = left.astype(right.dtype)
223-
224-
if type(left) != type(right):
225-
msg = (
226-
f"must not have differing left [{type(left).__name__}] and "
227-
f"right [{type(right).__name__}] types"
228-
)
229-
raise ValueError(msg)
230-
elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype):
231-
# GH 19016
232-
msg = (
233-
"category, object, and string subtypes are not supported "
234-
"for IntervalArray"
235-
)
236-
raise TypeError(msg)
237-
elif isinstance(left, ABCPeriodIndex):
238-
msg = "Period dtypes are not supported, use a PeriodIndex instead"
239-
raise ValueError(msg)
240-
elif isinstance(left, ABCDatetimeIndex) and str(left.tz) != str(right.tz):
241-
msg = (
242-
"left and right must have the same time zone, got "
243-
f"'{left.tz}' and '{right.tz}'"
244-
)
245-
raise ValueError(msg)
246-
247-
# For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
248-
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array
249-
250-
left = maybe_upcast_datetimelike_array(left)
251-
left = extract_array(left, extract_numpy=True)
252-
right = maybe_upcast_datetimelike_array(right)
253-
right = extract_array(right, extract_numpy=True)
254-
255-
lbase = getattr(left, "_ndarray", left).base
256-
rbase = getattr(right, "_ndarray", right).base
257-
if lbase is not None and lbase is rbase:
258-
# If these share data, then setitem could corrupt our IA
259-
right = right.copy()
260-
261-
result._left = left
262-
result._right = right
220+
result._combined = data
221+
result._left = data[:, 0]
222+
result._right = data[:, 1]
263223
result._closed = closed
264-
if verify_integrity:
265-
result._validate()
266224
return result
267225

268226
@classmethod
@@ -397,10 +355,16 @@ def from_breaks(cls, breaks, closed="right", copy=False, dtype=None):
397355
def from_arrays(cls, left, right, closed="right", copy=False, dtype=None):
398356
left = maybe_convert_platform_interval(left)
399357
right = maybe_convert_platform_interval(right)
358+
if len(left) != len(right):
359+
raise ValueError("left and right must have the same length")
400360

401-
return cls._simple_new(
402-
left, right, closed, copy=copy, dtype=dtype, verify_integrity=True
403-
)
361+
closed = closed or "right"
362+
left, right = _maybe_cast_inputs(left, right, copy, dtype)
363+
combined = _get_combined_data(left, right)
364+
365+
result = cls._simple_new(combined, closed)
366+
result._validate()
367+
return result
404368

405369
_interval_shared_docs["from_tuples"] = textwrap.dedent(
406370
"""
@@ -506,19 +470,6 @@ def _validate(self):
506470
msg = "left side of interval must be <= right side"
507471
raise ValueError(msg)
508472

509-
def _shallow_copy(self, left, right):
510-
"""
511-
Return a new IntervalArray with the replacement attributes
512-
513-
Parameters
514-
----------
515-
left : Index
516-
Values to be used for the left-side of the intervals.
517-
right : Index
518-
Values to be used for the right-side of the intervals.
519-
"""
520-
return self._simple_new(left, right, closed=self.closed, verify_integrity=False)
521-
522473
# ---------------------------------------------------------------------
523474
# Descriptive
524475

@@ -546,18 +497,20 @@ def __len__(self) -> int:
546497

547498
def __getitem__(self, key):
548499
key = check_array_indexer(self, key)
549-
left = self._left[key]
550-
right = self._right[key]
551500

552-
if not isinstance(left, (np.ndarray, ExtensionArray)):
553-
# scalar
554-
if is_scalar(left) and isna(left):
501+
result = self._combined[key]
502+
503+
if is_integer(key):
504+
left, right = result[0], result[1]
505+
if isna(left):
555506
return self._fill_value
556507
return Interval(left, right, self.closed)
557-
if np.ndim(left) > 1:
508+
509+
# TODO: need to watch out for incorrectly-reducing getitem
510+
if np.ndim(result) > 2:
558511
# GH#30588 multi-dimensional indexer disallowed
559512
raise ValueError("multi-dimensional indexing not allowed")
560-
return self._shallow_copy(left, right)
513+
return type(self)._simple_new(result, closed=self.closed)
561514

562515
def __setitem__(self, key, value):
563516
value_left, value_right = self._validate_setitem_value(value)
@@ -651,7 +604,8 @@ def fillna(self, value=None, method=None, limit=None):
651604

652605
left = self.left.fillna(value=value_left)
653606
right = self.right.fillna(value=value_right)
654-
return self._shallow_copy(left, right)
607+
combined = _get_combined_data(left, right)
608+
return type(self)._simple_new(combined, closed=self.closed)
655609

656610
def astype(self, dtype, copy=True):
657611
"""
@@ -693,7 +647,9 @@ def astype(self, dtype, copy=True):
693647
f"Cannot convert {self.dtype} to {dtype}; subtypes are incompatible"
694648
)
695649
raise TypeError(msg) from err
696-
return self._shallow_copy(new_left, new_right)
650+
# TODO: do astype directly on self._combined
651+
combined = _get_combined_data(new_left, new_right)
652+
return type(self)._simple_new(combined, closed=self.closed)
697653
elif is_categorical_dtype(dtype):
698654
return Categorical(np.asarray(self))
699655
elif isinstance(dtype, StringDtype):
@@ -734,9 +690,11 @@ def _concat_same_type(cls, to_concat):
734690
raise ValueError("Intervals must all be closed on the same side.")
735691
closed = closed.pop()
736692

693+
# TODO: will this mess up on dt64tz?
737694
left = np.concatenate([interval.left for interval in to_concat])
738695
right = np.concatenate([interval.right for interval in to_concat])
739-
return cls._simple_new(left, right, closed=closed, copy=False)
696+
combined = _get_combined_data(left, right) # TODO: 1-stage concat
697+
return cls._simple_new(combined, closed=closed)
740698

741699
def copy(self):
742700
"""
@@ -746,11 +704,8 @@ def copy(self):
746704
-------
747705
IntervalArray
748706
"""
749-
left = self._left.copy()
750-
right = self._right.copy()
751-
closed = self.closed
752-
# TODO: Could skip verify_integrity here.
753-
return type(self).from_arrays(left, right, closed=closed)
707+
combined = self._combined.copy()
708+
return type(self)._simple_new(combined, closed=self.closed)
754709

755710
def isna(self) -> np.ndarray:
756711
return isna(self._left)
@@ -843,7 +798,8 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs):
843798
self._right, indices, allow_fill=allow_fill, fill_value=fill_right
844799
)
845800

846-
return self._shallow_copy(left_take, right_take)
801+
combined = _get_combined_data(left_take, right_take)
802+
return type(self)._simple_new(combined, closed=self.closed)
847803

848804
def _validate_listlike(self, value):
849805
# list-like of intervals
@@ -1170,10 +1126,7 @@ def set_closed(self, closed):
11701126
if closed not in VALID_CLOSED:
11711127
msg = f"invalid option for 'closed': {closed}"
11721128
raise ValueError(msg)
1173-
1174-
return type(self)._simple_new(
1175-
left=self._left, right=self._right, closed=closed, verify_integrity=False
1176-
)
1129+
return type(self)._simple_new(self._combined, closed=closed)
11771130

11781131
_interval_shared_docs[
11791132
"is_non_overlapping_monotonic"
@@ -1314,9 +1267,8 @@ def to_tuples(self, na_tuple=True):
13141267
@Appender(_extension_array_shared_docs["repeat"] % _shared_docs_kwargs)
13151268
def repeat(self, repeats, axis=None):
13161269
nv.validate_repeat(tuple(), dict(axis=axis))
1317-
left_repeat = self.left.repeat(repeats)
1318-
right_repeat = self.right.repeat(repeats)
1319-
return self._shallow_copy(left=left_repeat, right=right_repeat)
1270+
combined = self._combined.repeat(repeats, 0)
1271+
return type(self)._simple_new(combined, closed=self.closed)
13201272

13211273
_interval_shared_docs["contains"] = textwrap.dedent(
13221274
"""
@@ -1399,3 +1351,92 @@ def maybe_convert_platform_interval(values):
13991351
values = np.asarray(values)
14001352

14011353
return maybe_convert_platform(values)
1354+
1355+
1356+
def _maybe_cast_inputs(
1357+
left_orig: Union["Index", ArrayLike],
1358+
right_orig: Union["Index", ArrayLike],
1359+
copy: bool,
1360+
dtype: Optional[Dtype],
1361+
) -> Tuple["Index", "Index"]:
1362+
left = ensure_index(left_orig, copy=copy)
1363+
right = ensure_index(right_orig, copy=copy)
1364+
1365+
if dtype is not None:
1366+
# GH#19262: dtype must be an IntervalDtype to override inferred
1367+
dtype = pandas_dtype(dtype)
1368+
if not is_interval_dtype(dtype):
1369+
msg = f"dtype must be an IntervalDtype, got {dtype}"
1370+
raise TypeError(msg)
1371+
dtype = cast(IntervalDtype, dtype)
1372+
if dtype.subtype is not None:
1373+
left = left.astype(dtype.subtype)
1374+
right = right.astype(dtype.subtype)
1375+
1376+
# coerce dtypes to match if needed
1377+
if is_float_dtype(left) and is_integer_dtype(right):
1378+
right = right.astype(left.dtype)
1379+
elif is_float_dtype(right) and is_integer_dtype(left):
1380+
left = left.astype(right.dtype)
1381+
1382+
if type(left) != type(right):
1383+
msg = (
1384+
f"must not have differing left [{type(left).__name__}] and "
1385+
f"right [{type(right).__name__}] types"
1386+
)
1387+
raise ValueError(msg)
1388+
elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype):
1389+
# GH#19016
1390+
msg = (
1391+
"category, object, and string subtypes are not supported "
1392+
"for IntervalArray"
1393+
)
1394+
raise TypeError(msg)
1395+
elif isinstance(left, ABCPeriodIndex):
1396+
msg = "Period dtypes are not supported, use a PeriodIndex instead"
1397+
raise ValueError(msg)
1398+
elif isinstance(left, ABCDatetimeIndex) and not is_dtype_equal(
1399+
left.dtype, right.dtype
1400+
):
1401+
left_arr = cast("DatetimeArray", left._data)
1402+
right_arr = cast("DatetimeArray", right._data)
1403+
msg = (
1404+
"left and right must have the same time zone, got "
1405+
f"'{left_arr.tz}' and '{right_arr.tz}'"
1406+
)
1407+
raise ValueError(msg)
1408+
1409+
return left, right
1410+
1411+
1412+
def _get_combined_data(
1413+
left: Union["Index", ArrayLike], right: Union["Index", ArrayLike]
1414+
) -> Union[np.ndarray, "DatetimeArray", "TimedeltaArray"]:
1415+
# For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
1416+
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array
1417+
1418+
left = maybe_upcast_datetimelike_array(left)
1419+
left = extract_array(left, extract_numpy=True)
1420+
right = maybe_upcast_datetimelike_array(right)
1421+
right = extract_array(right, extract_numpy=True)
1422+
1423+
lbase = getattr(left, "_ndarray", left).base
1424+
rbase = getattr(right, "_ndarray", right).base
1425+
if lbase is not None and lbase is rbase:
1426+
# If these share data, then setitem could corrupt our IA
1427+
right = right.copy()
1428+
1429+
if isinstance(left, np.ndarray):
1430+
assert isinstance(right, np.ndarray) # for mypy
1431+
combined = np.concatenate(
1432+
[left.reshape(-1, 1), right.reshape(-1, 1)],
1433+
axis=1,
1434+
)
1435+
else:
1436+
left = cast(Union["DatetimeArray", "TimedeltaArray"], left)
1437+
right = cast(Union["DatetimeArray", "TimedeltaArray"], right)
1438+
combined = type(left)._concat_same_type(
1439+
[left.reshape(-1, 1), right.reshape(-1, 1)],
1440+
axis=1,
1441+
)
1442+
return combined

0 commit comments

Comments
 (0)