Skip to content

Commit 9c202a1

Browse files
authored
BUG: IntervalArray.__eq__ not deferring to Series (#36908)
* BUG: IntervalArray.__eq__ not deferring to Series * whatsnew * post-rebase fixup
1 parent 9cb3723 commit 9c202a1

File tree

4 files changed

+12
-24
lines changed

4 files changed

+12
-24
lines changed

doc/source/whatsnew/v1.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ Numeric
368368
- Bug in :class:`IntegerArray` multiplication with ``timedelta`` and ``np.timedelta64`` objects (:issue:`36870`)
369369
- Bug in :meth:`DataFrame.diff` with ``datetime64`` dtypes including ``NaT`` values failing to fill ``NaT`` results correctly (:issue:`32441`)
370370
- Bug in :class:`DataFrame` arithmetic ops incorrectly accepting keyword arguments (:issue:`36843`)
371+
- Bug in :class:`IntervalArray` comparisons with :class:`Series` not returning :class:`Series` (:issue:`36908`)
371372

372373
Conversion
373374
^^^^^^^^^^

pandas/core/arrays/interval.py

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from pandas.core.construction import array, extract_array
4949
from pandas.core.indexers import check_array_indexer
5050
from pandas.core.indexes.base import ensure_index
51+
from pandas.core.ops import unpack_zerodim_and_defer
5152

5253
if TYPE_CHECKING:
5354
from pandas import Index
@@ -519,6 +520,7 @@ def __setitem__(self, key, value):
519520
self._left[key] = value_left
520521
self._right[key] = value_right
521522

523+
@unpack_zerodim_and_defer("__eq__")
522524
def __eq__(self, other):
523525
# ensure pandas array for list-like and eliminate non-interval scalars
524526
if is_list_like(other):

pandas/core/indexes/base.py

+1-19
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from pandas._libs.tslibs.period import IncompatibleFrequency
2828
from pandas._libs.tslibs.timezones import tz_compare
2929
from pandas._typing import AnyArrayLike, Dtype, DtypeObj, Label
30-
from pandas.compat import set_function_name
3130
from pandas.compat.numpy import function as nv
3231
from pandas.errors import DuplicateLabelError, InvalidIndexError
3332
from pandas.util._decorators import Appender, cache_readonly, doc
@@ -68,7 +67,6 @@
6867
from pandas.core.dtypes.concat import concat_compat
6968
from pandas.core.dtypes.generic import (
7069
ABCCategorical,
71-
ABCDataFrame,
7270
ABCDatetimeIndex,
7371
ABCMultiIndex,
7472
ABCPandasArray,
@@ -122,22 +120,6 @@
122120
str_t = str
123121

124122

125-
def _make_arithmetic_op(op, cls):
126-
def index_arithmetic_method(self, other):
127-
if isinstance(other, (ABCSeries, ABCDataFrame, ABCTimedeltaIndex)):
128-
return NotImplemented
129-
130-
from pandas import Series
131-
132-
result = op(Series(self), other)
133-
if isinstance(result, tuple):
134-
return (Index(result[0]), Index(result[1]))
135-
return Index(result)
136-
137-
name = f"__{op.__name__}__"
138-
return set_function_name(index_arithmetic_method, name, cls)
139-
140-
141123
_o_dtype = np.dtype(object)
142124
_Identity = object
143125

@@ -5380,7 +5362,7 @@ def _cmp_method(self, other, op):
53805362
Wrapper used to dispatch comparison operations.
53815363
"""
53825364
if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)):
5383-
if other.ndim > 0 and len(self) != len(other):
5365+
if len(self) != len(other):
53845366
raise ValueError("Lengths must match to compare")
53855367

53865368
if is_object_dtype(self.dtype) and isinstance(other, ABCCategorical):

pandas/tests/arithmetic/test_interval.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,10 @@ def elementwise_comparison(self, op, array, other):
103103
Helper that performs elementwise comparisons between `array` and `other`
104104
"""
105105
other = other if is_list_like(other) else [other] * len(array)
106-
return np.array([op(x, y) for x, y in zip(array, other)])
106+
expected = np.array([op(x, y) for x, y in zip(array, other)])
107+
if isinstance(other, Series):
108+
return Series(expected, index=other.index)
109+
return expected
107110

108111
def test_compare_scalar_interval(self, op, array):
109112
# matches first interval
@@ -161,19 +164,19 @@ def test_compare_list_like_interval(self, op, array, interval_constructor):
161164
other = interval_constructor(array.left, array.right)
162165
result = op(array, other)
163166
expected = self.elementwise_comparison(op, array, other)
164-
tm.assert_numpy_array_equal(result, expected)
167+
tm.assert_equal(result, expected)
165168

166169
# different endpoints
167170
other = interval_constructor(array.left[::-1], array.right[::-1])
168171
result = op(array, other)
169172
expected = self.elementwise_comparison(op, array, other)
170-
tm.assert_numpy_array_equal(result, expected)
173+
tm.assert_equal(result, expected)
171174

172175
# all nan endpoints
173176
other = interval_constructor([np.nan] * 4, [np.nan] * 4)
174177
result = op(array, other)
175178
expected = self.elementwise_comparison(op, array, other)
176-
tm.assert_numpy_array_equal(result, expected)
179+
tm.assert_equal(result, expected)
177180

178181
def test_compare_list_like_interval_mixed_closed(
179182
self, op, interval_constructor, closed, other_closed
@@ -183,7 +186,7 @@ def test_compare_list_like_interval_mixed_closed(
183186

184187
result = op(array, other)
185188
expected = self.elementwise_comparison(op, array, other)
186-
tm.assert_numpy_array_equal(result, expected)
189+
tm.assert_equal(result, expected)
187190

188191
@pytest.mark.parametrize(
189192
"other",

0 commit comments

Comments
 (0)