diff --git a/doc/source/whatsnew/v1.0.0.rst b/doc/source/whatsnew/v1.0.0.rst index 40690abe0c600..a5a583963f82d 100755 --- a/doc/source/whatsnew/v1.0.0.rst +++ b/doc/source/whatsnew/v1.0.0.rst @@ -881,6 +881,7 @@ Interval - Bug in :meth:`IntervalIndex.get_indexer` where a :class:`Categorical` or :class:`CategoricalIndex` ``target`` would incorrectly raise a ``TypeError`` (:issue:`30063`) - Bug in ``pandas.core.dtypes.cast.infer_dtype_from_scalar`` where passing ``pandas_dtype=True`` did not infer :class:`IntervalDtype` (:issue:`30337`) - Bug in :class:`IntervalDtype` where the ``kind`` attribute was incorrectly set as ``None`` instead of ``"O"`` (:issue:`30568`) +- Bug in :class:`IntervalIndex`, :class:`~arrays.IntervalArray`, and :class:`Series` with interval data where equality comparisons were incorrect (:issue:`24112`) Indexing ^^^^^^^^ diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index cea059fb22be1..7a12b1dcf436d 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -17,6 +17,8 @@ is_integer_dtype, is_interval, is_interval_dtype, + is_list_like, + is_object_dtype, is_scalar, is_string_dtype, is_timedelta64_dtype, @@ -37,6 +39,7 @@ from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs from pandas.core.arrays.categorical import Categorical import pandas.core.common as com +from pandas.core.construction import array from pandas.core.indexes.base import ensure_index _VALID_CLOSED = {"left", "right", "both", "neither"} @@ -547,6 +550,58 @@ def __setitem__(self, key, value): right.values[key] = value_right self._right = right + def __eq__(self, other): + # ensure pandas array for list-like and eliminate non-interval scalars + if is_list_like(other): + if len(self) != len(other): + raise ValueError("Lengths must match to compare") + other = array(other) + elif not isinstance(other, Interval): + # non-interval scalar -> no matches + return np.zeros(len(self), dtype=bool) + + # determine the dtype of the elements we want to compare + if isinstance(other, Interval): + other_dtype = "interval" + elif not is_categorical_dtype(other): + other_dtype = other.dtype + else: + # for categorical defer to categories for dtype + other_dtype = other.categories.dtype + + # extract intervals if we have interval categories with matching closed + if is_interval_dtype(other_dtype): + if self.closed != other.categories.closed: + return np.zeros(len(self), dtype=bool) + other = other.categories.take(other.codes) + + # interval-like -> need same closed and matching endpoints + if is_interval_dtype(other_dtype): + if self.closed != other.closed: + return np.zeros(len(self), dtype=bool) + return (self.left == other.left) & (self.right == other.right) + + # non-interval/non-object dtype -> no matches + if not is_object_dtype(other_dtype): + return np.zeros(len(self), dtype=bool) + + # object dtype -> iteratively check for intervals + result = np.zeros(len(self), dtype=bool) + for i, obj in enumerate(other): + # need object to be an Interval with same closed and endpoints + if ( + isinstance(obj, Interval) + and self.closed == obj.closed + and self.left[i] == obj.left + and self.right[i] == obj.right + ): + result[i] = True + + return result + + def __ne__(self, other): + return ~self.__eq__(other) + def fillna(self, value=None, method=None, limit=None): """ Fill NA/NaN values using the specified method. diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index cae9fa949f711..5708f8a73fe63 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -205,7 +205,9 @@ def func(intvidx_self, other, sort=False): "__array__", "overlaps", "contains", + "__eq__", "__len__", + "__ne__", "set_closed", "to_tuples", ], @@ -224,7 +226,14 @@ class IntervalIndex(IntervalMixin, Index, accessor.PandasDelegate): # Immutable, so we are able to cache computations like isna in '_mask' _mask = None - _raw_inherit = {"_ndarray_values", "__array__", "overlaps", "contains"} + _raw_inherit = { + "_ndarray_values", + "__array__", + "overlaps", + "contains", + "__eq__", + "__ne__", + } # -------------------------------------------------------------------- # Constructors diff --git a/pandas/tests/arrays/interval/test_interval.py b/pandas/tests/arrays/interval/test_interval.py index 655a6e717119b..3bdd68fa904fc 100644 --- a/pandas/tests/arrays/interval/test_interval.py +++ b/pandas/tests/arrays/interval/test_interval.py @@ -1,14 +1,22 @@ +import operator + import numpy as np import pytest +from pandas.core.dtypes.common import is_list_like + import pandas as pd from pandas import ( + Categorical, Index, Interval, IntervalIndex, + Period, + Series, Timedelta, Timestamp, date_range, + period_range, timedelta_range, ) from pandas.core.arrays import IntervalArray @@ -35,6 +43,18 @@ def left_right_dtypes(request): return request.param +def create_categorical_intervals(left, right, closed="right"): + return Categorical(IntervalIndex.from_arrays(left, right, closed)) + + +def create_series_intervals(left, right, closed="right"): + return Series(IntervalArray.from_arrays(left, right, closed)) + + +def create_series_categorical_intervals(left, right, closed="right"): + return Series(Categorical(IntervalIndex.from_arrays(left, right, closed))) + + class TestAttributes: @pytest.mark.parametrize( "left, right", @@ -93,6 +113,221 @@ def test_set_na(self, left_right_dtypes): tm.assert_extension_array_equal(result, expected) +class TestComparison: + @pytest.fixture(params=[operator.eq, operator.ne]) + def op(self, request): + return request.param + + @pytest.fixture + def array(self, left_right_dtypes): + """ + Fixture to generate an IntervalArray of various dtypes containing NA if possible + """ + left, right = left_right_dtypes + if left.dtype != "int64": + left, right = left.insert(4, np.nan), right.insert(4, np.nan) + else: + left, right = left.insert(4, 10), right.insert(4, 20) + return IntervalArray.from_arrays(left, right) + + @pytest.fixture( + params=[ + IntervalArray.from_arrays, + IntervalIndex.from_arrays, + create_categorical_intervals, + create_series_intervals, + create_series_categorical_intervals, + ], + ids=[ + "IntervalArray", + "IntervalIndex", + "Categorical[Interval]", + "Series[Interval]", + "Series[Categorical[Interval]]", + ], + ) + def interval_constructor(self, request): + """ + Fixture for all pandas native interval constructors. + To be used as the LHS of IntervalArray comparisons. + """ + return request.param + + def elementwise_comparison(self, op, array, other): + """ + Helper that performs elementwise comparisions between `array` and `other` + """ + other = other if is_list_like(other) else [other] * len(array) + return np.array([op(x, y) for x, y in zip(array, other)]) + + def test_compare_scalar_interval(self, op, array): + # matches first interval + other = array[0] + result = op(array, other) + expected = self.elementwise_comparison(op, array, other) + tm.assert_numpy_array_equal(result, expected) + + # matches on a single endpoint but not both + other = Interval(array.left[0], array.right[1]) + result = op(array, other) + expected = self.elementwise_comparison(op, array, other) + tm.assert_numpy_array_equal(result, expected) + + def test_compare_scalar_interval_mixed_closed(self, op, closed, other_closed): + array = IntervalArray.from_arrays(range(2), range(1, 3), closed=closed) + other = Interval(0, 1, closed=other_closed) + + result = op(array, other) + expected = self.elementwise_comparison(op, array, other) + tm.assert_numpy_array_equal(result, expected) + + def test_compare_scalar_na(self, op, array, nulls_fixture): + result = op(array, nulls_fixture) + expected = self.elementwise_comparison(op, array, nulls_fixture) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize( + "other", + [ + 0, + 1.0, + True, + "foo", + Timestamp("2017-01-01"), + Timestamp("2017-01-01", tz="US/Eastern"), + Timedelta("0 days"), + Period("2017-01-01", "D"), + ], + ) + def test_compare_scalar_other(self, op, array, other): + result = op(array, other) + expected = self.elementwise_comparison(op, array, other) + tm.assert_numpy_array_equal(result, expected) + + def test_compare_list_like_interval( + self, op, array, interval_constructor, + ): + # same endpoints + other = interval_constructor(array.left, array.right) + result = op(array, other) + expected = self.elementwise_comparison(op, array, other) + tm.assert_numpy_array_equal(result, expected) + + # different endpoints + other = interval_constructor(array.left[::-1], array.right[::-1]) + result = op(array, other) + expected = self.elementwise_comparison(op, array, other) + tm.assert_numpy_array_equal(result, expected) + + # all nan endpoints + other = interval_constructor([np.nan] * 4, [np.nan] * 4) + result = op(array, other) + expected = self.elementwise_comparison(op, array, other) + tm.assert_numpy_array_equal(result, expected) + + def test_compare_list_like_interval_mixed_closed( + self, op, interval_constructor, closed, other_closed + ): + array = IntervalArray.from_arrays(range(2), range(1, 3), closed=closed) + other = interval_constructor(range(2), range(1, 3), closed=other_closed) + + result = op(array, other) + expected = self.elementwise_comparison(op, array, other) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize( + "other", + [ + ( + Interval(0, 1), + Interval(Timedelta("1 day"), Timedelta("2 days")), + Interval(4, 5, "both"), + Interval(10, 20, "neither"), + ), + (0, 1.5, Timestamp("20170103"), np.nan), + ( + Timestamp("20170102", tz="US/Eastern"), + Timedelta("2 days"), + "baz", + pd.NaT, + ), + ], + ) + def test_compare_list_like_object(self, op, array, other): + result = op(array, other) + expected = self.elementwise_comparison(op, array, other) + tm.assert_numpy_array_equal(result, expected) + + def test_compare_list_like_nan(self, op, array, nulls_fixture): + other = [nulls_fixture] * 4 + result = op(array, other) + expected = self.elementwise_comparison(op, array, other) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize( + "other", + [ + np.arange(4, dtype="int64"), + np.arange(4, dtype="float64"), + date_range("2017-01-01", periods=4), + date_range("2017-01-01", periods=4, tz="US/Eastern"), + timedelta_range("0 days", periods=4), + period_range("2017-01-01", periods=4, freq="D"), + Categorical(list("abab")), + Categorical(date_range("2017-01-01", periods=4)), + pd.array(list("abcd")), + pd.array(["foo", 3.14, None, object()]), + ], + ids=lambda x: str(x.dtype), + ) + def test_compare_list_like_other(self, op, array, other): + result = op(array, other) + expected = self.elementwise_comparison(op, array, other) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("length", [1, 3, 5]) + @pytest.mark.parametrize("other_constructor", [IntervalArray, list]) + def test_compare_length_mismatch_errors(self, op, other_constructor, length): + array = IntervalArray.from_arrays(range(4), range(1, 5)) + other = other_constructor([Interval(0, 1)] * length) + with pytest.raises(ValueError, match="Lengths must match to compare"): + op(array, other) + + @pytest.mark.parametrize( + "constructor, expected_type, assert_func", + [ + (IntervalIndex, np.array, tm.assert_numpy_array_equal), + (Series, Series, tm.assert_series_equal), + ], + ) + def test_index_series_compat(self, op, constructor, expected_type, assert_func): + # IntervalIndex/Series that rely on IntervalArray for comparisons + breaks = range(4) + index = constructor(IntervalIndex.from_breaks(breaks)) + + # scalar comparisons + other = index[0] + result = op(index, other) + expected = expected_type(self.elementwise_comparison(op, index, other)) + assert_func(result, expected) + + other = breaks[0] + result = op(index, other) + expected = expected_type(self.elementwise_comparison(op, index, other)) + assert_func(result, expected) + + # list-like comparisons + other = IntervalArray.from_breaks(breaks) + result = op(index, other) + expected = expected_type(self.elementwise_comparison(op, index, other)) + assert_func(result, expected) + + other = [index[0], breaks[0], "foo"] + result = op(index, other) + expected = expected_type(self.elementwise_comparison(op, index, other)) + assert_func(result, expected) + + def test_repr(): # GH 25022 arr = IntervalArray.from_tuples([(0, 1), (1, 2)]) diff --git a/pandas/tests/series/test_arithmetic.py b/pandas/tests/series/test_arithmetic.py index 68d6169fa4f34..412bd1c63d140 100644 --- a/pandas/tests/series/test_arithmetic.py +++ b/pandas/tests/series/test_arithmetic.py @@ -171,6 +171,14 @@ def test_ser_cmp_result_names(self, names, op): result = op(ser, tdi) assert result.name == names[2] + # interval dtype + if op in [operator.eq, operator.ne]: + # interval dtype comparisons not yet implemented + ii = pd.interval_range(start=0, periods=5, name=names[0]) + ser = Series(ii).rename(names[1]) + result = op(ser, ii) + assert result.name == names[2] + # categorical if op in [operator.eq, operator.ne]: # categorical dtype comparisons raise for inequalities