From 6e0c1cabc24bc253d748a100566a801d7aa4afd5 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 5 Oct 2020 19:25:17 -0700 Subject: [PATCH 1/3] BUG: IntervalArray.__eq__ not deferring to Series --- pandas/core/arrays/interval.py | 2 ++ pandas/core/indexes/base.py | 18 +++++++++--------- pandas/tests/arithmetic/test_interval.py | 13 ++++++++----- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 413430942575d..6899f64ad5bba 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -44,6 +44,7 @@ from pandas.core.construction import array, extract_array from pandas.core.indexers import check_array_indexer from pandas.core.indexes.base import ensure_index +from pandas.core.ops import unpack_zerodim_and_defer _interval_shared_docs = {} @@ -566,6 +567,7 @@ def __setitem__(self, key, value): self._left[key] = value_left self._right[key] = value_right + @unpack_zerodim_and_defer("__eq__") def __eq__(self, other): # ensure pandas array for list-like and eliminate non-interval scalars if is_list_like(other): diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index ff3d8bf05f9a5..cac56887c4231 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -65,7 +65,6 @@ from pandas.core.dtypes.concat import concat_compat from pandas.core.dtypes.generic import ( ABCCategorical, - ABCDataFrame, ABCDatetimeIndex, ABCMultiIndex, ABCPandasArray, @@ -120,9 +119,12 @@ def _make_comparison_op(op, cls): + opname = f"__{op.__name__}__" + + @ops.unpack_zerodim_and_defer(opname) def cmp_method(self, other): if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)): - if other.ndim > 0 and len(self) != len(other): + if len(self) != len(other): raise ValueError("Lengths must match to compare") if is_object_dtype(self.dtype) and isinstance(other, ABCCategorical): @@ -150,15 +152,14 @@ def cmp_method(self, other): return result return ops.invalid_comparison(self, other, op) - name = f"__{op.__name__}__" - return set_function_name(cmp_method, name, cls) + return set_function_name(cmp_method, opname, cls) def _make_arithmetic_op(op, cls): - def index_arithmetic_method(self, other): - if isinstance(other, (ABCSeries, ABCDataFrame, ABCTimedeltaIndex)): - return NotImplemented + opname = f"__{op.__name__}__" + @ops.unpack_zerodim_and_defer(opname) + def index_arithmetic_method(self, other): from pandas import Series result = op(Series(self), other) @@ -166,8 +167,7 @@ def index_arithmetic_method(self, other): return (Index(result[0]), Index(result[1])) return Index(result) - name = f"__{op.__name__}__" - return set_function_name(index_arithmetic_method, name, cls) + return set_function_name(index_arithmetic_method, opname, cls) _o_dtype = np.dtype(object) diff --git a/pandas/tests/arithmetic/test_interval.py b/pandas/tests/arithmetic/test_interval.py index 72ef7ea6bf8ca..03cc4fe2bdcb5 100644 --- a/pandas/tests/arithmetic/test_interval.py +++ b/pandas/tests/arithmetic/test_interval.py @@ -103,7 +103,10 @@ def elementwise_comparison(self, op, array, other): Helper that performs elementwise comparisons between `array` and `other` """ other = other if is_list_like(other) else [other] * len(array) - return np.array([op(x, y) for x, y in zip(array, other)]) + expected = np.array([op(x, y) for x, y in zip(array, other)]) + if isinstance(other, Series): + return Series(expected, index=other.index) + return expected def test_compare_scalar_interval(self, op, array): # matches first interval @@ -161,19 +164,19 @@ def test_compare_list_like_interval(self, op, array, interval_constructor): other = interval_constructor(array.left, array.right) result = op(array, other) expected = self.elementwise_comparison(op, array, other) - tm.assert_numpy_array_equal(result, expected) + tm.assert_equal(result, expected) # different endpoints other = interval_constructor(array.left[::-1], array.right[::-1]) result = op(array, other) expected = self.elementwise_comparison(op, array, other) - tm.assert_numpy_array_equal(result, expected) + tm.assert_equal(result, expected) # all nan endpoints other = interval_constructor([np.nan] * 4, [np.nan] * 4) result = op(array, other) expected = self.elementwise_comparison(op, array, other) - tm.assert_numpy_array_equal(result, expected) + tm.assert_equal(result, expected) def test_compare_list_like_interval_mixed_closed( self, op, interval_constructor, closed, other_closed @@ -183,7 +186,7 @@ def test_compare_list_like_interval_mixed_closed( result = op(array, other) expected = self.elementwise_comparison(op, array, other) - tm.assert_numpy_array_equal(result, expected) + tm.assert_equal(result, expected) @pytest.mark.parametrize( "other", From 113b58a86fcf04c528fce51b0188c95748ab9324 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 6 Oct 2020 08:22:22 -0700 Subject: [PATCH 2/3] whatsnew --- doc/source/whatsnew/v1.2.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 47ebd962b367c..f4fc6cd5a548d 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -334,6 +334,7 @@ Numeric - Bug in :class:`Series` where two :class:`Series` each have a :class:`DatetimeIndex` with different timezones having those indexes incorrectly changed when performing arithmetic operations (:issue:`33671`) - Bug in :meth:`pd._testing.assert_almost_equal` was incorrect for complex numeric types (:issue:`28235`) - Bug in :meth:`DataFrame.__rmatmul__` error handling reporting transposed shapes (:issue:`21581`) +- Bug in :class:`IntervalArray` comparisons with :class:`Series` not returning :class:`Series` (:issue:`36908`) Conversion ^^^^^^^^^^ From 76e847c1a0d1f2b0c067af6ee851924d6760d6c0 Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 11 Oct 2020 15:05:00 -0700 Subject: [PATCH 3/3] post-rebase fixup --- pandas/core/indexes/base.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 21edaf80878d0..697e6b43b6daa 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -27,7 +27,6 @@ from pandas._libs.tslibs.period import IncompatibleFrequency from pandas._libs.tslibs.timezones import tz_compare from pandas._typing import AnyArrayLike, Dtype, DtypeObj, Label -from pandas.compat import set_function_name from pandas.compat.numpy import function as nv from pandas.errors import DuplicateLabelError, InvalidIndexError from pandas.util._decorators import Appender, cache_readonly, doc @@ -121,21 +120,6 @@ str_t = str -def _make_arithmetic_op(op, cls): - opname = f"__{op.__name__}__" - - @ops.unpack_zerodim_and_defer(opname) - def index_arithmetic_method(self, other): - from pandas import Series - - result = op(Series(self), other) - if isinstance(result, tuple): - return (Index(result[0]), Index(result[1])) - return Index(result) - - return set_function_name(index_arithmetic_method, opname, cls) - - _o_dtype = np.dtype(object) _Identity = object