Skip to content

Commit a3b7c45

Browse files
committed
ENH: IntervalArray comparisons
1 parent 9c202a1 commit a3b7c45

File tree

5 files changed

+79
-30
lines changed

5 files changed

+79
-30
lines changed

pandas/core/arrays/_mixins.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44

55
from pandas._libs import lib
6-
from pandas.compat.numpy import function as nv
76
from pandas.errors import AbstractMethodError
87
from pandas.util._decorators import cache_readonly, doc
98
from pandas.util._validators import validate_fillna_kwargs
@@ -139,7 +138,6 @@ def repeat(self: _T, repeats, axis=None) -> _T:
139138
--------
140139
numpy.ndarray.repeat
141140
"""
142-
nv.validate_repeat(tuple(), dict(axis=axis))
143141
new_data = self._ndarray.repeat(repeats, axis=axis)
144142
return self._from_backing_data(new_data)
145143

pandas/core/arrays/interval.py

Lines changed: 73 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import operator
12
from operator import le, lt
23
import textwrap
34
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
@@ -48,7 +49,7 @@
4849
from pandas.core.construction import array, extract_array
4950
from pandas.core.indexers import check_array_indexer
5051
from pandas.core.indexes.base import ensure_index
51-
from pandas.core.ops import unpack_zerodim_and_defer
52+
from pandas.core.ops import invalid_comparison, unpack_zerodim_and_defer
5253

5354
if TYPE_CHECKING:
5455
from pandas import Index
@@ -520,16 +521,15 @@ def __setitem__(self, key, value):
520521
self._left[key] = value_left
521522
self._right[key] = value_right
522523

523-
@unpack_zerodim_and_defer("__eq__")
524-
def __eq__(self, other):
524+
def _cmp_method(self, other, op):
525525
# ensure pandas array for list-like and eliminate non-interval scalars
526526
if is_list_like(other):
527527
if len(self) != len(other):
528528
raise ValueError("Lengths must match to compare")
529529
other = array(other)
530530
elif not isinstance(other, Interval):
531531
# non-interval scalar -> no matches
532-
return np.zeros(len(self), dtype=bool)
532+
return invalid_comparison(self, other, op)
533533

534534
# determine the dtype of the elements we want to compare
535535
if isinstance(other, Interval):
@@ -543,33 +543,87 @@ def __eq__(self, other):
543543
# extract intervals if we have interval categories with matching closed
544544
if is_interval_dtype(other_dtype):
545545
if self.closed != other.categories.closed:
546-
return np.zeros(len(self), dtype=bool)
546+
return invalid_comparison(self, other, op)
547547
other = other.categories.take(other.codes)
548548

549549
# interval-like -> need same closed and matching endpoints
550550
if is_interval_dtype(other_dtype):
551551
if self.closed != other.closed:
552-
return np.zeros(len(self), dtype=bool)
553-
return (self._left == other.left) & (self._right == other.right)
552+
return invalid_comparison(self, other, op)
553+
if isinstance(other, Interval):
554+
other = type(self)._from_sequence([other])
555+
if self._combined.dtype.kind in ["m", "M"]:
556+
# Need to repeat bc we do not broadcast length-1
557+
# TODO: would be helpful to have a tile method to do
558+
# this without copies
559+
other = other.repeat(len(self))
560+
else:
561+
other = type(self)(other)
562+
563+
if op is operator.eq:
564+
return (self._combined[:, 0] == other._left) & (
565+
self._combined[:, 1] == other._right
566+
)
567+
elif op is operator.ne:
568+
return (self._combined[:, 0] != other._left) | (
569+
self._combined[:, 1] != other._right
570+
)
571+
elif op is operator.gt:
572+
return (self._combined[:, 0] > other._combined[:, 0]) | (
573+
(self._combined[:, 0] == other._left)
574+
& (self._combined[:, 1] > other._right)
575+
)
576+
elif op is operator.ge:
577+
return (self == other) | (self > other)
578+
elif op is operator.lt:
579+
return (self._combined[:, 0] < other._combined[:, 0]) | (
580+
(self._combined[:, 0] == other._left)
581+
& (self._combined[:, 1] < other._right)
582+
)
583+
else:
584+
# operator.lt
585+
return (self == other) | (self < other)
554586

555587
# non-interval/non-object dtype -> no matches
556588
if not is_object_dtype(other_dtype):
557-
return np.zeros(len(self), dtype=bool)
589+
return invalid_comparison(self, other, op)
558590

559591
# object dtype -> iteratively check for intervals
560-
result = np.zeros(len(self), dtype=bool)
561-
for i, obj in enumerate(other):
562-
# need object to be an Interval with same closed and endpoints
563-
if (
564-
isinstance(obj, Interval)
565-
and self.closed == obj.closed
566-
and self._left[i] == obj.left
567-
and self._right[i] == obj.right
568-
):
569-
result[i] = True
570-
592+
try:
593+
result = np.zeros(len(self), dtype=bool)
594+
for i, obj in enumerate(other):
595+
result[i] = op(self[i], obj)
596+
except TypeError:
597+
# pd.NA
598+
result = np.zeros(len(self), dtype=object)
599+
for i, obj in enumerate(other):
600+
result[i] = op(self[i], obj)
571601
return result
572602

603+
@unpack_zerodim_and_defer("__eq__")
604+
def __eq__(self, other):
605+
return self._cmp_method(other, operator.eq)
606+
607+
@unpack_zerodim_and_defer("__ne__")
608+
def __ne__(self, other):
609+
return self._cmp_method(other, operator.ne)
610+
611+
@unpack_zerodim_and_defer("__gt__")
612+
def __gt__(self, other):
613+
return self._cmp_method(other, operator.gt)
614+
615+
@unpack_zerodim_and_defer("__ge__")
616+
def __ge__(self, other):
617+
return self._cmp_method(other, operator.ge)
618+
619+
@unpack_zerodim_and_defer("__lt__")
620+
def __lt__(self, other):
621+
return self._cmp_method(other, operator.lt)
622+
623+
@unpack_zerodim_and_defer("__le__")
624+
def __le__(self, other):
625+
return self._cmp_method(other, operator.le)
626+
573627
def fillna(self, value=None, method=None, limit=None):
574628
"""
575629
Fill NA/NaN values using the specified method.

pandas/tests/arithmetic/test_interval.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,6 @@ def test_compare_list_like_nan(self, op, array, nulls_fixture, request):
216216
result = op(array, other)
217217
expected = self.elementwise_comparison(op, array, other)
218218

219-
if nulls_fixture is pd.NA and array.dtype.subtype != "i8":
220-
reason = "broken for non-integer IntervalArray; see GH 31882"
221-
mark = pytest.mark.xfail(reason=reason)
222-
request.node.add_marker(mark)
223-
224219
tm.assert_numpy_array_equal(result, expected)
225220

226221
@pytest.mark.parametrize(

pandas/tests/extension/base/methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def test_repeat(self, data, repeats, as_series, use_numpy):
443443
@pytest.mark.parametrize(
444444
"repeats, kwargs, error, msg",
445445
[
446-
(2, dict(axis=1), ValueError, "'axis"),
446+
(2, dict(axis=1), ValueError, "axis"),
447447
(-1, dict(), ValueError, "negative"),
448448
([1, 2], dict(), ValueError, "shape"),
449449
(2, dict(foo="bar"), TypeError, "'foo'"),

pandas/tests/indexes/interval/test_interval.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -579,9 +579,11 @@ def test_comparison(self):
579579
actual = self.index == self.index.left
580580
tm.assert_numpy_array_equal(actual, np.array([False, False]))
581581

582-
msg = (
583-
"not supported between instances of 'int' and "
584-
"'pandas._libs.interval.Interval'"
582+
msg = "|".join(
583+
[
584+
"not supported between instances of 'int' and '.*.Interval'",
585+
r"Invalid comparison between dtype=interval\[int64\] and ",
586+
]
585587
)
586588
with pytest.raises(TypeError, match=msg):
587589
self.index > 0

0 commit comments

Comments
 (0)