1
+ import operator
1
2
from operator import le , lt
2
3
import textwrap
3
4
from typing import TYPE_CHECKING , Optional , Tuple , Union , cast
12
13
IntervalMixin ,
13
14
intervals_to_interval_bounds ,
14
15
)
16
+ from pandas ._libs .missing import NA
15
17
from pandas ._typing import ArrayLike , Dtype
16
18
from pandas .compat .numpy import function as nv
17
19
from pandas .util ._decorators import Appender
48
50
from pandas .core .construction import array , extract_array
49
51
from pandas .core .indexers import check_array_indexer
50
52
from pandas .core .indexes .base import ensure_index
51
- from pandas .core .ops import unpack_zerodim_and_defer
53
+ from pandas .core .ops import invalid_comparison , unpack_zerodim_and_defer
52
54
53
55
if TYPE_CHECKING :
54
56
from pandas import Index
@@ -520,16 +522,15 @@ def __setitem__(self, key, value):
520
522
self ._left [key ] = value_left
521
523
self ._right [key ] = value_right
522
524
523
- @unpack_zerodim_and_defer ("__eq__" )
524
- def __eq__ (self , other ):
525
+ def _cmp_method (self , other , op ):
525
526
# ensure pandas array for list-like and eliminate non-interval scalars
526
527
if is_list_like (other ):
527
528
if len (self ) != len (other ):
528
529
raise ValueError ("Lengths must match to compare" )
529
530
other = array (other )
530
531
elif not isinstance (other , Interval ):
531
532
# non-interval scalar -> no matches
532
- return np . zeros ( len ( self ), dtype = bool )
533
+ return invalid_comparison ( self , other , op )
533
534
534
535
# determine the dtype of the elements we want to compare
535
536
if isinstance (other , Interval ):
@@ -543,35 +544,79 @@ def __eq__(self, other):
543
544
# extract intervals if we have interval categories with matching closed
544
545
if is_interval_dtype (other_dtype ):
545
546
if self .closed != other .categories .closed :
546
- return np .zeros (len (self ), dtype = bool )
547
+ return invalid_comparison (self , other , op )
548
+
547
549
other = other .categories .take (
548
550
other .codes , allow_fill = True , fill_value = other .categories ._na_value
549
551
)
550
552
551
553
# interval-like -> need same closed and matching endpoints
552
554
if is_interval_dtype (other_dtype ):
553
555
if self .closed != other .closed :
554
- return np .zeros (len (self ), dtype = bool )
555
- return (self ._left == other .left ) & (self ._right == other .right )
556
+ return invalid_comparison (self , other , op )
557
+ elif not isinstance (other , Interval ):
558
+ other = type (self )(other )
559
+
560
+ if op is operator .eq :
561
+ return (self ._left == other .left ) & (self ._right == other .right )
562
+ elif op is operator .ne :
563
+ return (self ._left != other .left ) | (self ._right != other .right )
564
+ elif op is operator .gt :
565
+ return (self ._left > other .left ) | (
566
+ (self ._left == other .left ) & (self ._right > other .right )
567
+ )
568
+ elif op is operator .ge :
569
+ return (self == other ) | (self > other )
570
+ elif op is operator .lt :
571
+ return (self ._left < other .left ) | (
572
+ (self ._left == other .left ) & (self ._right < other .right )
573
+ )
574
+ else :
575
+ # operator.lt
576
+ return (self == other ) | (self < other )
556
577
557
578
# non-interval/non-object dtype -> no matches
558
579
if not is_object_dtype (other_dtype ):
559
- return np . zeros ( len ( self ), dtype = bool )
580
+ return invalid_comparison ( self , other , op )
560
581
561
582
# object dtype -> iteratively check for intervals
562
583
result = np .zeros (len (self ), dtype = bool )
563
584
for i , obj in enumerate (other ):
564
- # need object to be an Interval with same closed and endpoints
565
- if (
566
- isinstance ( obj , Interval )
567
- and self . closed == obj . closed
568
- and self . _left [ i ] == obj . left
569
- and self . _right [ i ] == obj . right
570
- ):
571
- result [ i ] = True
572
-
585
+ try :
586
+ result [ i ] = op ( self [ i ], obj )
587
+ except TypeError :
588
+ if obj is NA :
589
+ # comparison with np.nan returns NA
590
+ # github.com/pandas-dev/pandas/pull/37124#discussion_r509095092
591
+ result [ i ] = op is operator . ne
592
+ else :
593
+ raise
573
594
return result
574
595
596
+ @unpack_zerodim_and_defer ("__eq__" )
597
+ def __eq__ (self , other ):
598
+ return self ._cmp_method (other , operator .eq )
599
+
600
+ @unpack_zerodim_and_defer ("__ne__" )
601
+ def __ne__ (self , other ):
602
+ return self ._cmp_method (other , operator .ne )
603
+
604
+ @unpack_zerodim_and_defer ("__gt__" )
605
+ def __gt__ (self , other ):
606
+ return self ._cmp_method (other , operator .gt )
607
+
608
+ @unpack_zerodim_and_defer ("__ge__" )
609
+ def __ge__ (self , other ):
610
+ return self ._cmp_method (other , operator .ge )
611
+
612
+ @unpack_zerodim_and_defer ("__lt__" )
613
+ def __lt__ (self , other ):
614
+ return self ._cmp_method (other , operator .lt )
615
+
616
+ @unpack_zerodim_and_defer ("__le__" )
617
+ def __le__ (self , other ):
618
+ return self ._cmp_method (other , operator .le )
619
+
575
620
def fillna (self , value = None , method = None , limit = None ):
576
621
"""
577
622
Fill NA/NaN values using the specified method.
0 commit comments