Skip to content

Commit 5e488a0

Browse files
jschendeljreback
authored andcommitted
BUG: Fix IntervalArray equality comparisions (pandas-dev#30640)
1 parent 936a990 commit 5e488a0

File tree

5 files changed

+309
-1
lines changed

5 files changed

+309
-1
lines changed

doc/source/whatsnew/v1.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,7 @@ Interval
884884
- Bug in :meth:`IntervalIndex.get_indexer` where a :class:`Categorical` or :class:`CategoricalIndex` ``target`` would incorrectly raise a ``TypeError`` (:issue:`30063`)
885885
- Bug in ``pandas.core.dtypes.cast.infer_dtype_from_scalar`` where passing ``pandas_dtype=True`` did not infer :class:`IntervalDtype` (:issue:`30337`)
886886
- Bug in :class:`IntervalDtype` where the ``kind`` attribute was incorrectly set as ``None`` instead of ``"O"`` (:issue:`30568`)
887+
- Bug in :class:`IntervalIndex`, :class:`~arrays.IntervalArray`, and :class:`Series` with interval data where equality comparisons were incorrect (:issue:`24112`)
887888

888889
Indexing
889890
^^^^^^^^

pandas/core/arrays/interval.py

+55
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
is_integer_dtype,
1818
is_interval,
1919
is_interval_dtype,
20+
is_list_like,
21+
is_object_dtype,
2022
is_scalar,
2123
is_string_dtype,
2224
is_timedelta64_dtype,
@@ -37,6 +39,7 @@
3739
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs
3840
from pandas.core.arrays.categorical import Categorical
3941
import pandas.core.common as com
42+
from pandas.core.construction import array
4043
from pandas.core.indexes.base import ensure_index
4144

4245
_VALID_CLOSED = {"left", "right", "both", "neither"}
@@ -547,6 +550,58 @@ def __setitem__(self, key, value):
547550
right.values[key] = value_right
548551
self._right = right
549552

553+
def __eq__(self, other):
554+
# ensure pandas array for list-like and eliminate non-interval scalars
555+
if is_list_like(other):
556+
if len(self) != len(other):
557+
raise ValueError("Lengths must match to compare")
558+
other = array(other)
559+
elif not isinstance(other, Interval):
560+
# non-interval scalar -> no matches
561+
return np.zeros(len(self), dtype=bool)
562+
563+
# determine the dtype of the elements we want to compare
564+
if isinstance(other, Interval):
565+
other_dtype = "interval"
566+
elif not is_categorical_dtype(other):
567+
other_dtype = other.dtype
568+
else:
569+
# for categorical defer to categories for dtype
570+
other_dtype = other.categories.dtype
571+
572+
# extract intervals if we have interval categories with matching closed
573+
if is_interval_dtype(other_dtype):
574+
if self.closed != other.categories.closed:
575+
return np.zeros(len(self), dtype=bool)
576+
other = other.categories.take(other.codes)
577+
578+
# interval-like -> need same closed and matching endpoints
579+
if is_interval_dtype(other_dtype):
580+
if self.closed != other.closed:
581+
return np.zeros(len(self), dtype=bool)
582+
return (self.left == other.left) & (self.right == other.right)
583+
584+
# non-interval/non-object dtype -> no matches
585+
if not is_object_dtype(other_dtype):
586+
return np.zeros(len(self), dtype=bool)
587+
588+
# object dtype -> iteratively check for intervals
589+
result = np.zeros(len(self), dtype=bool)
590+
for i, obj in enumerate(other):
591+
# need object to be an Interval with same closed and endpoints
592+
if (
593+
isinstance(obj, Interval)
594+
and self.closed == obj.closed
595+
and self.left[i] == obj.left
596+
and self.right[i] == obj.right
597+
):
598+
result[i] = True
599+
600+
return result
601+
602+
def __ne__(self, other):
603+
return ~self.__eq__(other)
604+
550605
def fillna(self, value=None, method=None, limit=None):
551606
"""
552607
Fill NA/NaN values using the specified method.

pandas/core/indexes/interval.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ def func(intvidx_self, other, sort=False):
205205
"__array__",
206206
"overlaps",
207207
"contains",
208+
"__eq__",
208209
"__len__",
210+
"__ne__",
209211
"set_closed",
210212
"to_tuples",
211213
],
@@ -224,7 +226,14 @@ class IntervalIndex(IntervalMixin, ExtensionIndex, accessor.PandasDelegate):
224226
# Immutable, so we are able to cache computations like isna in '_mask'
225227
_mask = None
226228

227-
_raw_inherit = {"_ndarray_values", "__array__", "overlaps", "contains"}
229+
_raw_inherit = {
230+
"_ndarray_values",
231+
"__array__",
232+
"overlaps",
233+
"contains",
234+
"__eq__",
235+
"__ne__",
236+
}
228237

229238
# --------------------------------------------------------------------
230239
# Constructors

pandas/tests/arrays/interval/test_interval.py

+235
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
1+
import operator
2+
13
import numpy as np
24
import pytest
35

6+
from pandas.core.dtypes.common import is_list_like
7+
48
import pandas as pd
59
from pandas import (
10+
Categorical,
611
Index,
712
Interval,
813
IntervalIndex,
14+
Period,
15+
Series,
916
Timedelta,
1017
Timestamp,
1118
date_range,
19+
period_range,
1220
timedelta_range,
1321
)
1422
import pandas._testing as tm
@@ -35,6 +43,18 @@ def left_right_dtypes(request):
3543
return request.param
3644

3745

46+
def create_categorical_intervals(left, right, closed="right"):
47+
return Categorical(IntervalIndex.from_arrays(left, right, closed))
48+
49+
50+
def create_series_intervals(left, right, closed="right"):
51+
return Series(IntervalArray.from_arrays(left, right, closed))
52+
53+
54+
def create_series_categorical_intervals(left, right, closed="right"):
55+
return Series(Categorical(IntervalIndex.from_arrays(left, right, closed)))
56+
57+
3858
class TestAttributes:
3959
@pytest.mark.parametrize(
4060
"left, right",
@@ -93,6 +113,221 @@ def test_set_na(self, left_right_dtypes):
93113
tm.assert_extension_array_equal(result, expected)
94114

95115

116+
class TestComparison:
117+
@pytest.fixture(params=[operator.eq, operator.ne])
118+
def op(self, request):
119+
return request.param
120+
121+
@pytest.fixture
122+
def array(self, left_right_dtypes):
123+
"""
124+
Fixture to generate an IntervalArray of various dtypes containing NA if possible
125+
"""
126+
left, right = left_right_dtypes
127+
if left.dtype != "int64":
128+
left, right = left.insert(4, np.nan), right.insert(4, np.nan)
129+
else:
130+
left, right = left.insert(4, 10), right.insert(4, 20)
131+
return IntervalArray.from_arrays(left, right)
132+
133+
@pytest.fixture(
134+
params=[
135+
IntervalArray.from_arrays,
136+
IntervalIndex.from_arrays,
137+
create_categorical_intervals,
138+
create_series_intervals,
139+
create_series_categorical_intervals,
140+
],
141+
ids=[
142+
"IntervalArray",
143+
"IntervalIndex",
144+
"Categorical[Interval]",
145+
"Series[Interval]",
146+
"Series[Categorical[Interval]]",
147+
],
148+
)
149+
def interval_constructor(self, request):
150+
"""
151+
Fixture for all pandas native interval constructors.
152+
To be used as the LHS of IntervalArray comparisons.
153+
"""
154+
return request.param
155+
156+
def elementwise_comparison(self, op, array, other):
157+
"""
158+
Helper that performs elementwise comparisions between `array` and `other`
159+
"""
160+
other = other if is_list_like(other) else [other] * len(array)
161+
return np.array([op(x, y) for x, y in zip(array, other)])
162+
163+
def test_compare_scalar_interval(self, op, array):
164+
# matches first interval
165+
other = array[0]
166+
result = op(array, other)
167+
expected = self.elementwise_comparison(op, array, other)
168+
tm.assert_numpy_array_equal(result, expected)
169+
170+
# matches on a single endpoint but not both
171+
other = Interval(array.left[0], array.right[1])
172+
result = op(array, other)
173+
expected = self.elementwise_comparison(op, array, other)
174+
tm.assert_numpy_array_equal(result, expected)
175+
176+
def test_compare_scalar_interval_mixed_closed(self, op, closed, other_closed):
177+
array = IntervalArray.from_arrays(range(2), range(1, 3), closed=closed)
178+
other = Interval(0, 1, closed=other_closed)
179+
180+
result = op(array, other)
181+
expected = self.elementwise_comparison(op, array, other)
182+
tm.assert_numpy_array_equal(result, expected)
183+
184+
def test_compare_scalar_na(self, op, array, nulls_fixture):
185+
result = op(array, nulls_fixture)
186+
expected = self.elementwise_comparison(op, array, nulls_fixture)
187+
tm.assert_numpy_array_equal(result, expected)
188+
189+
@pytest.mark.parametrize(
190+
"other",
191+
[
192+
0,
193+
1.0,
194+
True,
195+
"foo",
196+
Timestamp("2017-01-01"),
197+
Timestamp("2017-01-01", tz="US/Eastern"),
198+
Timedelta("0 days"),
199+
Period("2017-01-01", "D"),
200+
],
201+
)
202+
def test_compare_scalar_other(self, op, array, other):
203+
result = op(array, other)
204+
expected = self.elementwise_comparison(op, array, other)
205+
tm.assert_numpy_array_equal(result, expected)
206+
207+
def test_compare_list_like_interval(
208+
self, op, array, interval_constructor,
209+
):
210+
# same endpoints
211+
other = interval_constructor(array.left, array.right)
212+
result = op(array, other)
213+
expected = self.elementwise_comparison(op, array, other)
214+
tm.assert_numpy_array_equal(result, expected)
215+
216+
# different endpoints
217+
other = interval_constructor(array.left[::-1], array.right[::-1])
218+
result = op(array, other)
219+
expected = self.elementwise_comparison(op, array, other)
220+
tm.assert_numpy_array_equal(result, expected)
221+
222+
# all nan endpoints
223+
other = interval_constructor([np.nan] * 4, [np.nan] * 4)
224+
result = op(array, other)
225+
expected = self.elementwise_comparison(op, array, other)
226+
tm.assert_numpy_array_equal(result, expected)
227+
228+
def test_compare_list_like_interval_mixed_closed(
229+
self, op, interval_constructor, closed, other_closed
230+
):
231+
array = IntervalArray.from_arrays(range(2), range(1, 3), closed=closed)
232+
other = interval_constructor(range(2), range(1, 3), closed=other_closed)
233+
234+
result = op(array, other)
235+
expected = self.elementwise_comparison(op, array, other)
236+
tm.assert_numpy_array_equal(result, expected)
237+
238+
@pytest.mark.parametrize(
239+
"other",
240+
[
241+
(
242+
Interval(0, 1),
243+
Interval(Timedelta("1 day"), Timedelta("2 days")),
244+
Interval(4, 5, "both"),
245+
Interval(10, 20, "neither"),
246+
),
247+
(0, 1.5, Timestamp("20170103"), np.nan),
248+
(
249+
Timestamp("20170102", tz="US/Eastern"),
250+
Timedelta("2 days"),
251+
"baz",
252+
pd.NaT,
253+
),
254+
],
255+
)
256+
def test_compare_list_like_object(self, op, array, other):
257+
result = op(array, other)
258+
expected = self.elementwise_comparison(op, array, other)
259+
tm.assert_numpy_array_equal(result, expected)
260+
261+
def test_compare_list_like_nan(self, op, array, nulls_fixture):
262+
other = [nulls_fixture] * 4
263+
result = op(array, other)
264+
expected = self.elementwise_comparison(op, array, other)
265+
tm.assert_numpy_array_equal(result, expected)
266+
267+
@pytest.mark.parametrize(
268+
"other",
269+
[
270+
np.arange(4, dtype="int64"),
271+
np.arange(4, dtype="float64"),
272+
date_range("2017-01-01", periods=4),
273+
date_range("2017-01-01", periods=4, tz="US/Eastern"),
274+
timedelta_range("0 days", periods=4),
275+
period_range("2017-01-01", periods=4, freq="D"),
276+
Categorical(list("abab")),
277+
Categorical(date_range("2017-01-01", periods=4)),
278+
pd.array(list("abcd")),
279+
pd.array(["foo", 3.14, None, object()]),
280+
],
281+
ids=lambda x: str(x.dtype),
282+
)
283+
def test_compare_list_like_other(self, op, array, other):
284+
result = op(array, other)
285+
expected = self.elementwise_comparison(op, array, other)
286+
tm.assert_numpy_array_equal(result, expected)
287+
288+
@pytest.mark.parametrize("length", [1, 3, 5])
289+
@pytest.mark.parametrize("other_constructor", [IntervalArray, list])
290+
def test_compare_length_mismatch_errors(self, op, other_constructor, length):
291+
array = IntervalArray.from_arrays(range(4), range(1, 5))
292+
other = other_constructor([Interval(0, 1)] * length)
293+
with pytest.raises(ValueError, match="Lengths must match to compare"):
294+
op(array, other)
295+
296+
@pytest.mark.parametrize(
297+
"constructor, expected_type, assert_func",
298+
[
299+
(IntervalIndex, np.array, tm.assert_numpy_array_equal),
300+
(Series, Series, tm.assert_series_equal),
301+
],
302+
)
303+
def test_index_series_compat(self, op, constructor, expected_type, assert_func):
304+
# IntervalIndex/Series that rely on IntervalArray for comparisons
305+
breaks = range(4)
306+
index = constructor(IntervalIndex.from_breaks(breaks))
307+
308+
# scalar comparisons
309+
other = index[0]
310+
result = op(index, other)
311+
expected = expected_type(self.elementwise_comparison(op, index, other))
312+
assert_func(result, expected)
313+
314+
other = breaks[0]
315+
result = op(index, other)
316+
expected = expected_type(self.elementwise_comparison(op, index, other))
317+
assert_func(result, expected)
318+
319+
# list-like comparisons
320+
other = IntervalArray.from_breaks(breaks)
321+
result = op(index, other)
322+
expected = expected_type(self.elementwise_comparison(op, index, other))
323+
assert_func(result, expected)
324+
325+
other = [index[0], breaks[0], "foo"]
326+
result = op(index, other)
327+
expected = expected_type(self.elementwise_comparison(op, index, other))
328+
assert_func(result, expected)
329+
330+
96331
def test_repr():
97332
# GH 25022
98333
arr = IntervalArray.from_tuples([(0, 1), (1, 2)])

pandas/tests/series/test_arithmetic.py

+8
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,14 @@ def test_ser_cmp_result_names(self, names, op):
172172
result = op(ser, tdi)
173173
assert result.name == names[2]
174174

175+
# interval dtype
176+
if op in [operator.eq, operator.ne]:
177+
# interval dtype comparisons not yet implemented
178+
ii = pd.interval_range(start=0, periods=5, name=names[0])
179+
ser = Series(ii).rename(names[1])
180+
result = op(ser, ii)
181+
assert result.name == names[2]
182+
175183
# categorical
176184
if op in [operator.eq, operator.ne]:
177185
# categorical dtype comparisons raise for inequalities

0 commit comments

Comments
 (0)