1
+ import operator
1
2
from operator import le , lt
2
3
import textwrap
3
4
from typing import TYPE_CHECKING , Optional , Tuple , Union , cast
48
49
from pandas .core .construction import array , extract_array
49
50
from pandas .core .indexers import check_array_indexer
50
51
from pandas .core .indexes .base import ensure_index
51
- from pandas .core .ops import unpack_zerodim_and_defer
52
+ from pandas .core .ops import invalid_comparison , unpack_zerodim_and_defer
52
53
53
54
if TYPE_CHECKING :
54
55
from pandas import Index
@@ -520,16 +521,15 @@ def __setitem__(self, key, value):
520
521
self ._left [key ] = value_left
521
522
self ._right [key ] = value_right
522
523
523
- @unpack_zerodim_and_defer ("__eq__" )
524
- def __eq__ (self , other ):
524
+ def _cmp_method (self , other , op ):
525
525
# ensure pandas array for list-like and eliminate non-interval scalars
526
526
if is_list_like (other ):
527
527
if len (self ) != len (other ):
528
528
raise ValueError ("Lengths must match to compare" )
529
529
other = array (other )
530
530
elif not isinstance (other , Interval ):
531
531
# non-interval scalar -> no matches
532
- return np . zeros ( len ( self ), dtype = bool )
532
+ return invalid_comparison ( self , other , op )
533
533
534
534
# determine the dtype of the elements we want to compare
535
535
if isinstance (other , Interval ):
@@ -543,33 +543,87 @@ def __eq__(self, other):
543
543
# extract intervals if we have interval categories with matching closed
544
544
if is_interval_dtype (other_dtype ):
545
545
if self .closed != other .categories .closed :
546
- return np . zeros ( len ( self ), dtype = bool )
546
+ return invalid_comparison ( self , other , op )
547
547
other = other .categories .take (other .codes )
548
548
549
549
# interval-like -> need same closed and matching endpoints
550
550
if is_interval_dtype (other_dtype ):
551
551
if self .closed != other .closed :
552
- return np .zeros (len (self ), dtype = bool )
553
- return (self ._left == other .left ) & (self ._right == other .right )
552
+ return invalid_comparison (self , other , op )
553
+ if isinstance (other , Interval ):
554
+ other = type (self )._from_sequence ([other ])
555
+ if self ._combined .dtype .kind in ["m" , "M" ]:
556
+ # Need to repeat bc we do not broadcast length-1
557
+ # TODO: would be helpful to have a tile method to do
558
+ # this without copies
559
+ other = other .repeat (len (self ))
560
+ else :
561
+ other = type (self )(other )
562
+
563
+ if op is operator .eq :
564
+ return (self ._combined [:, 0 ] == other ._left ) & (
565
+ self ._combined [:, 1 ] == other ._right
566
+ )
567
+ elif op is operator .ne :
568
+ return (self ._combined [:, 0 ] != other ._left ) | (
569
+ self ._combined [:, 1 ] != other ._right
570
+ )
571
+ elif op is operator .gt :
572
+ return (self ._combined [:, 0 ] > other ._combined [:, 0 ]) | (
573
+ (self ._combined [:, 0 ] == other ._left )
574
+ & (self ._combined [:, 1 ] > other ._right )
575
+ )
576
+ elif op is operator .ge :
577
+ return (self == other ) | (self > other )
578
+ elif op is operator .lt :
579
+ return (self ._combined [:, 0 ] < other ._combined [:, 0 ]) | (
580
+ (self ._combined [:, 0 ] == other ._left )
581
+ & (self ._combined [:, 1 ] < other ._right )
582
+ )
583
+ else :
584
+ # operator.lt
585
+ return (self == other ) | (self < other )
554
586
555
587
# non-interval/non-object dtype -> no matches
556
588
if not is_object_dtype (other_dtype ):
557
- return np . zeros ( len ( self ), dtype = bool )
589
+ return invalid_comparison ( self , other , op )
558
590
559
591
# object dtype -> iteratively check for intervals
560
- result = np .zeros (len (self ), dtype = bool )
561
- for i , obj in enumerate (other ):
562
- # need object to be an Interval with same closed and endpoints
563
- if (
564
- isinstance (obj , Interval )
565
- and self .closed == obj .closed
566
- and self ._left [i ] == obj .left
567
- and self ._right [i ] == obj .right
568
- ):
569
- result [i ] = True
570
-
592
+ try :
593
+ result = np .zeros (len (self ), dtype = bool )
594
+ for i , obj in enumerate (other ):
595
+ result [i ] = op (self [i ], obj )
596
+ except TypeError :
597
+ # pd.NA
598
+ result = np .zeros (len (self ), dtype = object )
599
+ for i , obj in enumerate (other ):
600
+ result [i ] = op (self [i ], obj )
571
601
return result
572
602
603
+ @unpack_zerodim_and_defer ("__eq__" )
604
+ def __eq__ (self , other ):
605
+ return self ._cmp_method (other , operator .eq )
606
+
607
+ @unpack_zerodim_and_defer ("__ne__" )
608
+ def __ne__ (self , other ):
609
+ return self ._cmp_method (other , operator .ne )
610
+
611
+ @unpack_zerodim_and_defer ("__gt__" )
612
+ def __gt__ (self , other ):
613
+ return self ._cmp_method (other , operator .gt )
614
+
615
+ @unpack_zerodim_and_defer ("__ge__" )
616
+ def __ge__ (self , other ):
617
+ return self ._cmp_method (other , operator .ge )
618
+
619
+ @unpack_zerodim_and_defer ("__lt__" )
620
+ def __lt__ (self , other ):
621
+ return self ._cmp_method (other , operator .lt )
622
+
623
+ @unpack_zerodim_and_defer ("__le__" )
624
+ def __le__ (self , other ):
625
+ return self ._cmp_method (other , operator .le )
626
+
573
627
def fillna (self , value = None , method = None , limit = None ):
574
628
"""
575
629
Fill NA/NaN values using the specified method.
0 commit comments