diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 74534bc371094..f77a55b95d567 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -368,6 +368,7 @@ Numeric - Bug in :class:`IntegerArray` multiplication with ``timedelta`` and ``np.timedelta64`` objects (:issue:`36870`) - Bug in :meth:`DataFrame.diff` with ``datetime64`` dtypes including ``NaT`` values failing to fill ``NaT`` results correctly (:issue:`32441`) - Bug in :class:`DataFrame` arithmetic ops incorrectly accepting keyword arguments (:issue:`36843`) +- Bug in :class:`IntervalArray` comparisons with :class:`Series` not returning :class:`Series` (:issue:`36908`) Conversion ^^^^^^^^^^ diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index d943fe3df88c5..09488b9576212 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -48,6 +48,7 @@ from pandas.core.construction import array, extract_array from pandas.core.indexers import check_array_indexer from pandas.core.indexes.base import ensure_index +from pandas.core.ops import unpack_zerodim_and_defer if TYPE_CHECKING: from pandas import Index @@ -519,6 +520,7 @@ def __setitem__(self, key, value): self._left[key] = value_left self._right[key] = value_right + @unpack_zerodim_and_defer("__eq__") def __eq__(self, other): # ensure pandas array for list-like and eliminate non-interval scalars if is_list_like(other): diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index b3f5fb6f0291a..7d3c2c2297d5d 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -27,7 +27,6 @@ from pandas._libs.tslibs.period import IncompatibleFrequency from pandas._libs.tslibs.timezones import tz_compare from pandas._typing import AnyArrayLike, Dtype, DtypeObj, Label -from pandas.compat import set_function_name from pandas.compat.numpy import function as nv from pandas.errors import DuplicateLabelError, InvalidIndexError from pandas.util._decorators import Appender, cache_readonly, doc @@ -68,7 +67,6 @@ from pandas.core.dtypes.concat import concat_compat from pandas.core.dtypes.generic import ( ABCCategorical, - ABCDataFrame, ABCDatetimeIndex, ABCMultiIndex, ABCPandasArray, @@ -122,22 +120,6 @@ str_t = str -def _make_arithmetic_op(op, cls): - def index_arithmetic_method(self, other): - if isinstance(other, (ABCSeries, ABCDataFrame, ABCTimedeltaIndex)): - return NotImplemented - - from pandas import Series - - result = op(Series(self), other) - if isinstance(result, tuple): - return (Index(result[0]), Index(result[1])) - return Index(result) - - name = f"__{op.__name__}__" - return set_function_name(index_arithmetic_method, name, cls) - - _o_dtype = np.dtype(object) _Identity = object @@ -5380,7 +5362,7 @@ def _cmp_method(self, other, op): Wrapper used to dispatch comparison operations. """ if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)): - if other.ndim > 0 and len(self) != len(other): + if len(self) != len(other): raise ValueError("Lengths must match to compare") if is_object_dtype(self.dtype) and isinstance(other, ABCCategorical): diff --git a/pandas/tests/arithmetic/test_interval.py b/pandas/tests/arithmetic/test_interval.py index 72ef7ea6bf8ca..03cc4fe2bdcb5 100644 --- a/pandas/tests/arithmetic/test_interval.py +++ b/pandas/tests/arithmetic/test_interval.py @@ -103,7 +103,10 @@ def elementwise_comparison(self, op, array, other): Helper that performs elementwise comparisons between `array` and `other` """ other = other if is_list_like(other) else [other] * len(array) - return np.array([op(x, y) for x, y in zip(array, other)]) + expected = np.array([op(x, y) for x, y in zip(array, other)]) + if isinstance(other, Series): + return Series(expected, index=other.index) + return expected def test_compare_scalar_interval(self, op, array): # matches first interval @@ -161,19 +164,19 @@ def test_compare_list_like_interval(self, op, array, interval_constructor): other = interval_constructor(array.left, array.right) result = op(array, other) expected = self.elementwise_comparison(op, array, other) - tm.assert_numpy_array_equal(result, expected) + tm.assert_equal(result, expected) # different endpoints other = interval_constructor(array.left[::-1], array.right[::-1]) result = op(array, other) expected = self.elementwise_comparison(op, array, other) - tm.assert_numpy_array_equal(result, expected) + tm.assert_equal(result, expected) # all nan endpoints other = interval_constructor([np.nan] * 4, [np.nan] * 4) result = op(array, other) expected = self.elementwise_comparison(op, array, other) - tm.assert_numpy_array_equal(result, expected) + tm.assert_equal(result, expected) def test_compare_list_like_interval_mixed_closed( self, op, interval_constructor, closed, other_closed @@ -183,7 +186,7 @@ def test_compare_list_like_interval_mixed_closed( result = op(array, other) expected = self.elementwise_comparison(op, array, other) - tm.assert_numpy_array_equal(result, expected) + tm.assert_equal(result, expected) @pytest.mark.parametrize( "other",