Skip to content

Commit 337bf20

Browse files
authored
REF: IntervalArray comparisons (#37124)
1 parent 28a0f66 commit 337bf20

File tree

4 files changed

+68
-34
lines changed

4 files changed

+68
-34
lines changed

pandas/core/arrays/interval.py

+62-17
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import operator
12
from operator import le, lt
23
import textwrap
34
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
@@ -12,6 +13,7 @@
1213
IntervalMixin,
1314
intervals_to_interval_bounds,
1415
)
16+
from pandas._libs.missing import NA
1517
from pandas._typing import ArrayLike, Dtype
1618
from pandas.compat.numpy import function as nv
1719
from pandas.util._decorators import Appender
@@ -48,7 +50,7 @@
4850
from pandas.core.construction import array, extract_array
4951
from pandas.core.indexers import check_array_indexer
5052
from pandas.core.indexes.base import ensure_index
51-
from pandas.core.ops import unpack_zerodim_and_defer
53+
from pandas.core.ops import invalid_comparison, unpack_zerodim_and_defer
5254

5355
if TYPE_CHECKING:
5456
from pandas import Index
@@ -520,16 +522,15 @@ def __setitem__(self, key, value):
520522
self._left[key] = value_left
521523
self._right[key] = value_right
522524

523-
@unpack_zerodim_and_defer("__eq__")
524-
def __eq__(self, other):
525+
def _cmp_method(self, other, op):
525526
# ensure pandas array for list-like and eliminate non-interval scalars
526527
if is_list_like(other):
527528
if len(self) != len(other):
528529
raise ValueError("Lengths must match to compare")
529530
other = array(other)
530531
elif not isinstance(other, Interval):
531532
# non-interval scalar -> no matches
532-
return np.zeros(len(self), dtype=bool)
533+
return invalid_comparison(self, other, op)
533534

534535
# determine the dtype of the elements we want to compare
535536
if isinstance(other, Interval):
@@ -543,35 +544,79 @@ def __eq__(self, other):
543544
# extract intervals if we have interval categories with matching closed
544545
if is_interval_dtype(other_dtype):
545546
if self.closed != other.categories.closed:
546-
return np.zeros(len(self), dtype=bool)
547+
return invalid_comparison(self, other, op)
548+
547549
other = other.categories.take(
548550
other.codes, allow_fill=True, fill_value=other.categories._na_value
549551
)
550552

551553
# interval-like -> need same closed and matching endpoints
552554
if is_interval_dtype(other_dtype):
553555
if self.closed != other.closed:
554-
return np.zeros(len(self), dtype=bool)
555-
return (self._left == other.left) & (self._right == other.right)
556+
return invalid_comparison(self, other, op)
557+
elif not isinstance(other, Interval):
558+
other = type(self)(other)
559+
560+
if op is operator.eq:
561+
return (self._left == other.left) & (self._right == other.right)
562+
elif op is operator.ne:
563+
return (self._left != other.left) | (self._right != other.right)
564+
elif op is operator.gt:
565+
return (self._left > other.left) | (
566+
(self._left == other.left) & (self._right > other.right)
567+
)
568+
elif op is operator.ge:
569+
return (self == other) | (self > other)
570+
elif op is operator.lt:
571+
return (self._left < other.left) | (
572+
(self._left == other.left) & (self._right < other.right)
573+
)
574+
else:
575+
# operator.lt
576+
return (self == other) | (self < other)
556577

557578
# non-interval/non-object dtype -> no matches
558579
if not is_object_dtype(other_dtype):
559-
return np.zeros(len(self), dtype=bool)
580+
return invalid_comparison(self, other, op)
560581

561582
# object dtype -> iteratively check for intervals
562583
result = np.zeros(len(self), dtype=bool)
563584
for i, obj in enumerate(other):
564-
# need object to be an Interval with same closed and endpoints
565-
if (
566-
isinstance(obj, Interval)
567-
and self.closed == obj.closed
568-
and self._left[i] == obj.left
569-
and self._right[i] == obj.right
570-
):
571-
result[i] = True
572-
585+
try:
586+
result[i] = op(self[i], obj)
587+
except TypeError:
588+
if obj is NA:
589+
# comparison with np.nan returns NA
590+
# github.com/pandas-dev/pandas/pull/37124#discussion_r509095092
591+
result[i] = op is operator.ne
592+
else:
593+
raise
573594
return result
574595

596+
@unpack_zerodim_and_defer("__eq__")
597+
def __eq__(self, other):
598+
return self._cmp_method(other, operator.eq)
599+
600+
@unpack_zerodim_and_defer("__ne__")
601+
def __ne__(self, other):
602+
return self._cmp_method(other, operator.ne)
603+
604+
@unpack_zerodim_and_defer("__gt__")
605+
def __gt__(self, other):
606+
return self._cmp_method(other, operator.gt)
607+
608+
@unpack_zerodim_and_defer("__ge__")
609+
def __ge__(self, other):
610+
return self._cmp_method(other, operator.ge)
611+
612+
@unpack_zerodim_and_defer("__lt__")
613+
def __lt__(self, other):
614+
return self._cmp_method(other, operator.lt)
615+
616+
@unpack_zerodim_and_defer("__le__")
617+
def __le__(self, other):
618+
return self._cmp_method(other, operator.le)
619+
575620
def fillna(self, value=None, method=None, limit=None):
576621
"""
577622
Fill NA/NaN values using the specified method.

pandas/core/indexes/interval.py

-13
Original file line numberDiff line numberDiff line change
@@ -1074,19 +1074,6 @@ def _is_all_dates(self) -> bool:
10741074

10751075
# TODO: arithmetic operations
10761076

1077-
# GH#30817 until IntervalArray implements inequalities, get them from Index
1078-
def __lt__(self, other):
1079-
return Index.__lt__(self, other)
1080-
1081-
def __le__(self, other):
1082-
return Index.__le__(self, other)
1083-
1084-
def __gt__(self, other):
1085-
return Index.__gt__(self, other)
1086-
1087-
def __ge__(self, other):
1088-
return Index.__ge__(self, other)
1089-
10901077

10911078
def _is_valid_endpoint(endpoint) -> bool:
10921079
"""

pandas/tests/extension/base/methods.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def test_repeat(self, data, repeats, as_series, use_numpy):
447447
@pytest.mark.parametrize(
448448
"repeats, kwargs, error, msg",
449449
[
450-
(2, dict(axis=1), ValueError, "'axis"),
450+
(2, dict(axis=1), ValueError, "axis"),
451451
(-1, dict(), ValueError, "negative"),
452452
([1, 2], dict(), ValueError, "shape"),
453453
(2, dict(foo="bar"), TypeError, "'foo'"),

pandas/tests/indexes/interval/test_interval.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -579,9 +579,11 @@ def test_comparison(self):
579579
actual = self.index == self.index.left
580580
tm.assert_numpy_array_equal(actual, np.array([False, False]))
581581

582-
msg = (
583-
"not supported between instances of 'int' and "
584-
"'pandas._libs.interval.Interval'"
582+
msg = "|".join(
583+
[
584+
"not supported between instances of 'int' and '.*.Interval'",
585+
r"Invalid comparison between dtype=interval\[int64\] and ",
586+
]
585587
)
586588
with pytest.raises(TypeError, match=msg):
587589
self.index > 0

0 commit comments

Comments
 (0)