1
1
from operator import le , lt
2
2
import textwrap
3
+ from typing import TYPE_CHECKING , Optional , Tuple , Union , cast
3
4
4
5
import numpy as np
5
6
11
12
IntervalMixin ,
12
13
intervals_to_interval_bounds ,
13
14
)
15
+ from pandas ._typing import ArrayLike , Dtype
14
16
from pandas .compat .numpy import function as nv
15
17
from pandas .util ._decorators import Appender
16
18
17
19
from pandas .core .dtypes .cast import maybe_convert_platform
18
20
from pandas .core .dtypes .common import (
19
21
is_categorical_dtype ,
20
22
is_datetime64_any_dtype ,
23
+ is_dtype_equal ,
21
24
is_float_dtype ,
25
+ is_integer ,
22
26
is_integer_dtype ,
23
27
is_interval_dtype ,
24
28
is_list_like ,
45
49
from pandas .core .indexers import check_array_indexer
46
50
from pandas .core .indexes .base import ensure_index
47
51
52
+ if TYPE_CHECKING :
53
+ from pandas import Index
54
+ from pandas .core .arrays import DatetimeArray , TimedeltaArray
55
+
48
56
_interval_shared_docs = {}
49
57
50
58
_shared_docs_kwargs = dict (
@@ -169,6 +177,17 @@ def __new__(
169
177
left = data ._left
170
178
right = data ._right
171
179
closed = closed or data .closed
180
+
181
+ if dtype is None or data .dtype == dtype :
182
+ # This path will preserve id(result._combined)
183
+ # TODO: could also validate dtype before going to simple_new
184
+ combined = data ._combined
185
+ if copy :
186
+ combined = combined .copy ()
187
+ result = cls ._simple_new (combined , closed = closed )
188
+ if verify_integrity :
189
+ result ._validate ()
190
+ return result
172
191
else :
173
192
174
193
# don't allow scalars
@@ -186,83 +205,22 @@ def __new__(
186
205
)
187
206
closed = closed or infer_closed
188
207
189
- return cls ._simple_new (
190
- left ,
191
- right ,
192
- closed ,
193
- copy = copy ,
194
- dtype = dtype ,
195
- verify_integrity = verify_integrity ,
196
- )
208
+ closed = closed or "right"
209
+ left , right = _maybe_cast_inputs (left , right , copy , dtype )
210
+ combined = _get_combined_data (left , right )
211
+ result = cls ._simple_new (combined , closed = closed )
212
+ if verify_integrity :
213
+ result ._validate ()
214
+ return result
197
215
198
216
@classmethod
def _simple_new(cls, data, closed="right"):
    """
    Build an instance directly from already-combined data.

    No casting or validation is performed here; callers that need
    integrity checks call ``_validate`` on the returned object.

    Parameters
    ----------
    data : array-like supporting ``data[:, 0]`` / ``data[:, 1]``
        Combined endpoint data; column 0 is used as the left side and
        column 1 as the right side.
    closed : str, default "right"
        Stored unchanged on the new object as ``_closed``.

    Returns
    -------
    New instance of ``cls``.
    """
    obj = IntervalMixin.__new__(cls)

    obj._combined = data
    # The two sides are taken as column selections of the combined data.
    obj._left, obj._right = data[:, 0], data[:, 1]
    obj._closed = closed
    return obj
267
225
268
226
@classmethod
@@ -397,10 +355,16 @@ def from_breaks(cls, breaks, closed="right", copy=False, dtype=None):
397
355
def from_arrays(cls, left, right, closed="right", copy=False, dtype=None):
    # Normalize both inputs first (left, then right); the length comparison
    # below is made on the converted values.
    left, right = (maybe_convert_platform_interval(side) for side in (left, right))
    if len(left) != len(right):
        raise ValueError("left and right must have the same length")

    # Cast/validate the two sides, then pack them into the combined layout.
    left, right = _maybe_cast_inputs(left, right, copy, dtype)
    new_array = cls._simple_new(_get_combined_data(left, right), closed or "right")

    # This constructor always validates the result.
    new_array._validate()
    return new_array
404
368
405
369
_interval_shared_docs ["from_tuples" ] = textwrap .dedent (
406
370
"""
@@ -506,19 +470,6 @@ def _validate(self):
506
470
msg = "left side of interval must be <= right side"
507
471
raise ValueError (msg )
508
472
509
- def _shallow_copy (self , left , right ):
510
- """
511
- Return a new IntervalArray with the replacement attributes
512
-
513
- Parameters
514
- ----------
515
- left : Index
516
- Values to be used for the left-side of the intervals.
517
- right : Index
518
- Values to be used for the right-side of the intervals.
519
- """
520
- return self ._simple_new (left , right , closed = self .closed , verify_integrity = False )
521
-
522
473
# ---------------------------------------------------------------------
523
474
# Descriptive
524
475
@@ -546,18 +497,20 @@ def __len__(self) -> int:
546
497
547
498
def __getitem__(self, key):
    """
    Select from the array by ``key``.

    An integer key returns a scalar: ``self._fill_value`` when the selected
    left endpoint is NA, otherwise an ``Interval`` built from the selected
    endpoints and ``self.closed``. Any other key returns a new array of the
    same type wrapping the selected data.

    Raises
    ------
    ValueError
        If the selection has more than two dimensions.
    """
    key = check_array_indexer(self, key)
    selected = self._combined[key]

    if not is_integer(key):
        # TODO: need to watch out for incorrectly-reducing getitem
        if np.ndim(selected) > 2:
            # GH#30588 multi-dimensional indexer disallowed
            raise ValueError("multi-dimensional indexing not allowed")
        return type(self)._simple_new(selected, closed=self.closed)

    # Integer key: the selection is a single (left, right) pair.
    left, right = selected[0], selected[1]
    return self._fill_value if isna(left) else Interval(left, right, self.closed)
561
514
562
515
def __setitem__ (self , key , value ):
563
516
value_left , value_right = self ._validate_setitem_value (value )
@@ -651,7 +604,8 @@ def fillna(self, value=None, method=None, limit=None):
651
604
652
605
left = self .left .fillna (value = value_left )
653
606
right = self .right .fillna (value = value_right )
654
- return self ._shallow_copy (left , right )
607
+ combined = _get_combined_data (left , right )
608
+ return type (self )._simple_new (combined , closed = self .closed )
655
609
656
610
def astype (self , dtype , copy = True ):
657
611
"""
@@ -693,7 +647,9 @@ def astype(self, dtype, copy=True):
693
647
f"Cannot convert { self .dtype } to { dtype } ; subtypes are incompatible"
694
648
)
695
649
raise TypeError (msg ) from err
696
- return self ._shallow_copy (new_left , new_right )
650
+ # TODO: do astype directly on self._combined
651
+ combined = _get_combined_data (new_left , new_right )
652
+ return type (self )._simple_new (combined , closed = self .closed )
697
653
elif is_categorical_dtype (dtype ):
698
654
return Categorical (np .asarray (self ))
699
655
elif isinstance (dtype , StringDtype ):
@@ -734,9 +690,11 @@ def _concat_same_type(cls, to_concat):
734
690
raise ValueError ("Intervals must all be closed on the same side." )
735
691
closed = closed .pop ()
736
692
693
+ # TODO: will this mess up on dt64tz?
737
694
left = np .concatenate ([interval .left for interval in to_concat ])
738
695
right = np .concatenate ([interval .right for interval in to_concat ])
739
- return cls ._simple_new (left , right , closed = closed , copy = False )
696
+ combined = _get_combined_data (left , right ) # TODO: 1-stage concat
697
+ return cls ._simple_new (combined , closed = closed )
740
698
741
699
def copy (self ):
742
700
"""
@@ -746,11 +704,8 @@ def copy(self):
746
704
-------
747
705
IntervalArray
748
706
"""
749
- left = self ._left .copy ()
750
- right = self ._right .copy ()
751
- closed = self .closed
752
- # TODO: Could skip verify_integrity here.
753
- return type (self ).from_arrays (left , right , closed = closed )
707
+ combined = self ._combined .copy ()
708
+ return type (self )._simple_new (combined , closed = self .closed )
754
709
755
710
def isna (self ) -> np .ndarray :
756
711
return isna (self ._left )
@@ -843,7 +798,8 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs):
843
798
self ._right , indices , allow_fill = allow_fill , fill_value = fill_right
844
799
)
845
800
846
- return self ._shallow_copy (left_take , right_take )
801
+ combined = _get_combined_data (left_take , right_take )
802
+ return type (self )._simple_new (combined , closed = self .closed )
847
803
848
804
def _validate_listlike (self , value ):
849
805
# list-like of intervals
@@ -1170,10 +1126,7 @@ def set_closed(self, closed):
1170
1126
if closed not in VALID_CLOSED :
1171
1127
msg = f"invalid option for 'closed': { closed } "
1172
1128
raise ValueError (msg )
1173
-
1174
- return type (self )._simple_new (
1175
- left = self ._left , right = self ._right , closed = closed , verify_integrity = False
1176
- )
1129
+ return type (self )._simple_new (self ._combined , closed = closed )
1177
1130
1178
1131
_interval_shared_docs [
1179
1132
"is_non_overlapping_monotonic"
@@ -1314,9 +1267,8 @@ def to_tuples(self, na_tuple=True):
1314
1267
@Appender(_extension_array_shared_docs["repeat"] % _shared_docs_kwargs)
def repeat(self, repeats, axis=None):
    # No docstring is written here: the Appender decorator above attaches
    # the shared "repeat" documentation.
    nv.validate_repeat((), {"axis": axis})

    # Repeat along axis 0 of the combined data so both sides stay paired.
    repeated = self._combined.repeat(repeats, 0)
    return type(self)._simple_new(repeated, closed=self.closed)
1320
1272
1321
1273
_interval_shared_docs ["contains" ] = textwrap .dedent (
1322
1274
"""
@@ -1399,3 +1351,92 @@ def maybe_convert_platform_interval(values):
1399
1351
values = np .asarray (values )
1400
1352
1401
1353
return maybe_convert_platform (values )
1354
+
1355
+
1356
def _maybe_cast_inputs(
    left_orig: Union["Index", ArrayLike],
    right_orig: Union["Index", ArrayLike],
    copy: bool,
    dtype: Optional[Dtype],
) -> Tuple["Index", "Index"]:
    """
    Convert the two endpoint inputs to Index objects of a matching type.

    Parameters
    ----------
    left_orig, right_orig : Index or array-like
        Left and right endpoints; each is passed through ``ensure_index``.
    copy : bool
        Forwarded to ``ensure_index``.
    dtype : dtype-like, optional
        When given, it must resolve to an IntervalDtype. If that dtype has a
        subtype, both sides are cast to the subtype.

    Returns
    -------
    tuple of (Index, Index)
        The converted left and right sides.

    Raises
    ------
    TypeError
        If ``dtype`` is given but is not an IntervalDtype, or if the left
        side has a categorical or string dtype.
    ValueError
        If the two sides end up as different Index types, if the left side
        is a PeriodIndex, or if the left side is a DatetimeIndex whose dtype
        differs from the right side's dtype.
    """
    left = ensure_index(left_orig, copy=copy)
    right = ensure_index(right_orig, copy=copy)

    if dtype is not None:
        # GH#19262: dtype must be an IntervalDtype to override inferred
        dtype = pandas_dtype(dtype)
        if not is_interval_dtype(dtype):
            msg = f"dtype must be an IntervalDtype, got {dtype}"
            raise TypeError(msg)
        dtype = cast(IntervalDtype, dtype)
        if dtype.subtype is not None:
            left = left.astype(dtype.subtype)
            right = right.astype(dtype.subtype)

    # coerce dtypes to match if needed: an integer side is cast to the
    # float dtype of the other side
    if is_float_dtype(left) and is_integer_dtype(right):
        right = right.astype(left.dtype)
    elif is_float_dtype(right) and is_integer_dtype(left):
        left = left.astype(right.dtype)

    if type(left) is not type(right):
        msg = (
            f"must not have differing left [{type(left).__name__}] and "
            f"right [{type(right).__name__}] types"
        )
        raise ValueError(msg)
    elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype):
        # GH#19016
        msg = (
            "category, object, and string subtypes are not supported "
            "for IntervalArray"
        )
        raise TypeError(msg)
    elif isinstance(left, ABCPeriodIndex):
        msg = "Period dtypes are not supported, use a PeriodIndex instead"
        raise ValueError(msg)
    elif isinstance(left, ABCDatetimeIndex) and not is_dtype_equal(
        left.dtype, right.dtype
    ):
        # The casts only narrow the type for the checker; the time zones are
        # read from the underlying arrays for the error message.
        left_arr = cast("DatetimeArray", left._data)
        right_arr = cast("DatetimeArray", right._data)
        msg = (
            "left and right must have the same time zone, got "
            f"'{left_arr.tz}' and '{right_arr.tz}'"
        )
        raise ValueError(msg)

    return left, right
1410
+
1411
+
1412
def _get_combined_data(
    left: Union["Index", ArrayLike], right: Union["Index", ArrayLike]
) -> Union[np.ndarray, "DatetimeArray", "TimedeltaArray"]:
    """
    Pack the left and right sides into one two-column array.

    Each side is reshaped to a single column and the two columns are joined
    along axis 1, so column 0 holds the left values and column 1 the right
    values.

    Parameters
    ----------
    left, right : Index or array-like
        The two sides to combine.

    Returns
    -------
    np.ndarray, DatetimeArray or TimedeltaArray
        An ndarray when the unwrapped left side is an ndarray; otherwise the
        result of the left array type's ``_concat_same_type``.
    """
    # For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
    from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array

    lvalues = extract_array(maybe_upcast_datetimelike_array(left), extract_numpy=True)
    rvalues = extract_array(maybe_upcast_datetimelike_array(right), extract_numpy=True)

    left_base = getattr(lvalues, "_ndarray", lvalues).base
    right_base = getattr(rvalues, "_ndarray", rvalues).base
    shares_memory = left_base is not None and left_base is right_base
    if shares_memory:
        # If these share data, then setitem could corrupt our IA
        rvalues = rvalues.copy()

    if not isinstance(lvalues, np.ndarray):
        # The casts are for the type checker only.
        dt_left = cast(Union["DatetimeArray", "TimedeltaArray"], lvalues)
        dt_right = cast(Union["DatetimeArray", "TimedeltaArray"], rvalues)
        return type(dt_left)._concat_same_type(
            [dt_left.reshape(-1, 1), dt_right.reshape(-1, 1)],
            axis=1,
        )

    assert isinstance(rvalues, np.ndarray)  # for mypy
    columns = [lvalues.reshape(-1, 1), rvalues.reshape(-1, 1)]
    return np.concatenate(columns, axis=1)
0 commit comments