Skip to content

BUG: Fix IntervalArray equality comparisions #30640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,7 @@ Interval
- Bug in :meth:`IntervalIndex.get_indexer` where a :class:`Categorical` or :class:`CategoricalIndex` ``target`` would incorrectly raise a ``TypeError`` (:issue:`30063`)
- Bug in ``pandas.core.dtypes.cast.infer_dtype_from_scalar`` where passing ``pandas_dtype=True`` did not infer :class:`IntervalDtype` (:issue:`30337`)
- Bug in :class:`IntervalDtype` where the ``kind`` attribute was incorrectly set as ``None`` instead of ``"O"`` (:issue:`30568`)
- Bug in :class:`IntervalIndex`, :class:`~arrays.IntervalArray`, and :class:`Series` with interval data where equality comparisons were incorrect (:issue:`24112`)

Indexing
^^^^^^^^
Expand Down
55 changes: 55 additions & 0 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
is_integer_dtype,
is_interval,
is_interval_dtype,
is_list_like,
is_object_dtype,
is_scalar,
is_string_dtype,
is_timedelta64_dtype,
Expand All @@ -37,6 +39,7 @@
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs
from pandas.core.arrays.categorical import Categorical
import pandas.core.common as com
from pandas.core.construction import array
from pandas.core.indexes.base import ensure_index

_VALID_CLOSED = {"left", "right", "both", "neither"}
Expand Down Expand Up @@ -547,6 +550,58 @@ def __setitem__(self, key, value):
right.values[key] = value_right
self._right = right

def __eq__(self, other):
# ensure pandas array for list-like and eliminate non-interval scalars
if is_list_like(other):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this particular check can be pushed to the base class

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite follow. Whose the base class here? ExtensionArray?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

if len(self) != len(other):
raise ValueError("Lengths must match to compare")
other = array(other)
elif not isinstance(other, Interval):
# non-interval scalar -> no matches
return np.zeros(len(self), dtype=bool)

# determine the dtype of the elements we want to compare
if isinstance(other, Interval):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

at this point can’t u just wrap other in array()?

other_dtype = "interval"
elif not is_categorical_dtype(other):
other_dtype = other.dtype
else:
# for categorical defer to categories for dtype
other_dtype = other.categories.dtype

# extract intervals if we have interval categories with matching closed
if is_interval_dtype(other_dtype):
if self.closed != other.categories.closed:
return np.zeros(len(self), dtype=bool)
other = other.categories.take(other.codes)

# interval-like -> need same closed and matching endpoints
if is_interval_dtype(other_dtype):
if self.closed != other.closed:
return np.zeros(len(self), dtype=bool)
return (self.left == other.left) & (self.right == other.right)

# non-interval/non-object dtype -> no matches
if not is_object_dtype(other_dtype):
return np.zeros(len(self), dtype=bool)

# object dtype -> iteratively check for intervals
result = np.zeros(len(self), dtype=bool)
for i, obj in enumerate(other):
# need object to be an Interval with same closed and endpoints
if (
isinstance(obj, Interval)
and self.closed == obj.closed
and self.left[i] == obj.left
and self.right[i] == obj.right
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this check be just self[i] == obj?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could but there'd be a perf hit for actually materializing the Interval object.

):
result[i] = True

return result

def __ne__(self, other):
return ~self.__eq__(other)

def fillna(self, value=None, method=None, limit=None):
"""
Fill NA/NaN values using the specified method.
Expand Down
11 changes: 10 additions & 1 deletion pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ def func(intvidx_self, other, sort=False):
"__array__",
"overlaps",
"contains",
"__eq__",
"__len__",
"__ne__",
"set_closed",
"to_tuples",
],
Expand All @@ -224,7 +226,14 @@ class IntervalIndex(IntervalMixin, Index, accessor.PandasDelegate):
# Immutable, so we are able to cache computations like isna in '_mask'
_mask = None

_raw_inherit = {"_ndarray_values", "__array__", "overlaps", "contains"}
_raw_inherit = {
"_ndarray_values",
"__array__",
"overlaps",
"contains",
"__eq__",
"__ne__",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to be consistent with our other EAs, these need to be dispatched/wrapped the same way they are in datetimelike or categorical. im planning to move the relevant code to indexes.extension so this can re-use the existing code.

The relevant test will be arr == series --> Series, idx == series --> ndarray[bool]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should now be able to use indexes.extension.make_wrapped_comparison_op

}

# --------------------------------------------------------------------
# Constructors
Expand Down
235 changes: 235 additions & 0 deletions pandas/tests/arrays/interval/test_interval.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import operator

import numpy as np
import pytest

from pandas.core.dtypes.common import is_list_like

import pandas as pd
from pandas import (
Categorical,
Index,
Interval,
IntervalIndex,
Period,
Series,
Timedelta,
Timestamp,
date_range,
period_range,
timedelta_range,
)
from pandas.core.arrays import IntervalArray
Expand All @@ -35,6 +43,18 @@ def left_right_dtypes(request):
return request.param


def create_categorical_intervals(left, right, closed="right"):
return Categorical(IntervalIndex.from_arrays(left, right, closed))


def create_series_intervals(left, right, closed="right"):
return Series(IntervalArray.from_arrays(left, right, closed))


def create_series_categorical_intervals(left, right, closed="right"):
return Series(Categorical(IntervalIndex.from_arrays(left, right, closed)))


class TestAttributes:
@pytest.mark.parametrize(
"left, right",
Expand Down Expand Up @@ -93,6 +113,221 @@ def test_set_na(self, left_right_dtypes):
tm.assert_extension_array_equal(result, expected)


class TestComparison:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

possibly put in tests.arithmetic.test_interval and parametrize with box_with_array

@pytest.fixture(params=[operator.eq, operator.ne])
def op(self, request):
return request.param

@pytest.fixture
def array(self, left_right_dtypes):
"""
Fixture to generate an IntervalArray of various dtypes containing NA if possible
"""
left, right = left_right_dtypes
if left.dtype != "int64":
left, right = left.insert(4, np.nan), right.insert(4, np.nan)
else:
left, right = left.insert(4, 10), right.insert(4, 20)
return IntervalArray.from_arrays(left, right)

@pytest.fixture(
params=[
IntervalArray.from_arrays,
IntervalIndex.from_arrays,
create_categorical_intervals,
create_series_intervals,
create_series_categorical_intervals,
],
ids=[
"IntervalArray",
"IntervalIndex",
"Categorical[Interval]",
"Series[Interval]",
"Series[Categorical[Interval]]",
],
)
def interval_constructor(self, request):
"""
Fixture for all pandas native interval constructors.
To be used as the LHS of IntervalArray comparisons.
"""
return request.param

def elementwise_comparison(self, op, array, other):
"""
Helper that performs elementwise comparisions between `array` and `other`
"""
other = other if is_list_like(other) else [other] * len(array)
return np.array([op(x, y) for x, y in zip(array, other)])

def test_compare_scalar_interval(self, op, array):
# matches first interval
other = array[0]
result = op(array, other)
expected = self.elementwise_comparison(op, array, other)
tm.assert_numpy_array_equal(result, expected)

# matches on a single endpoint but not both
other = Interval(array.left[0], array.right[1])
result = op(array, other)
expected = self.elementwise_comparison(op, array, other)
tm.assert_numpy_array_equal(result, expected)

def test_compare_scalar_interval_mixed_closed(self, op, closed, other_closed):
array = IntervalArray.from_arrays(range(2), range(1, 3), closed=closed)
other = Interval(0, 1, closed=other_closed)

result = op(array, other)
expected = self.elementwise_comparison(op, array, other)
tm.assert_numpy_array_equal(result, expected)

def test_compare_scalar_na(self, op, array, nulls_fixture):
result = op(array, nulls_fixture)
expected = self.elementwise_comparison(op, array, nulls_fixture)
tm.assert_numpy_array_equal(result, expected)

@pytest.mark.parametrize(
"other",
[
0,
1.0,
True,
"foo",
Timestamp("2017-01-01"),
Timestamp("2017-01-01", tz="US/Eastern"),
Timedelta("0 days"),
Period("2017-01-01", "D"),
],
)
def test_compare_scalar_other(self, op, array, other):
result = op(array, other)
expected = self.elementwise_comparison(op, array, other)
tm.assert_numpy_array_equal(result, expected)

def test_compare_list_like_interval(
self, op, array, interval_constructor,
):
# same endpoints
other = interval_constructor(array.left, array.right)
result = op(array, other)
expected = self.elementwise_comparison(op, array, other)
tm.assert_numpy_array_equal(result, expected)

# different endpoints
other = interval_constructor(array.left[::-1], array.right[::-1])
result = op(array, other)
expected = self.elementwise_comparison(op, array, other)
tm.assert_numpy_array_equal(result, expected)

# all nan endpoints
other = interval_constructor([np.nan] * 4, [np.nan] * 4)
result = op(array, other)
expected = self.elementwise_comparison(op, array, other)
tm.assert_numpy_array_equal(result, expected)

def test_compare_list_like_interval_mixed_closed(
self, op, interval_constructor, closed, other_closed
):
array = IntervalArray.from_arrays(range(2), range(1, 3), closed=closed)
other = interval_constructor(range(2), range(1, 3), closed=other_closed)

result = op(array, other)
expected = self.elementwise_comparison(op, array, other)
tm.assert_numpy_array_equal(result, expected)

@pytest.mark.parametrize(
"other",
[
(
Interval(0, 1),
Interval(Timedelta("1 day"), Timedelta("2 days")),
Interval(4, 5, "both"),
Interval(10, 20, "neither"),
),
(0, 1.5, Timestamp("20170103"), np.nan),
(
Timestamp("20170102", tz="US/Eastern"),
Timedelta("2 days"),
"baz",
pd.NaT,
),
],
)
def test_compare_list_like_object(self, op, array, other):
result = op(array, other)
expected = self.elementwise_comparison(op, array, other)
tm.assert_numpy_array_equal(result, expected)

def test_compare_list_like_nan(self, op, array, nulls_fixture):
other = [nulls_fixture] * 4
result = op(array, other)
expected = self.elementwise_comparison(op, array, other)
tm.assert_numpy_array_equal(result, expected)

@pytest.mark.parametrize(
"other",
[
np.arange(4, dtype="int64"),
np.arange(4, dtype="float64"),
date_range("2017-01-01", periods=4),
date_range("2017-01-01", periods=4, tz="US/Eastern"),
timedelta_range("0 days", periods=4),
period_range("2017-01-01", periods=4, freq="D"),
Categorical(list("abab")),
Categorical(date_range("2017-01-01", periods=4)),
pd.array(list("abcd")),
pd.array(["foo", 3.14, None, object()]),
],
ids=lambda x: str(x.dtype),
)
def test_compare_list_like_other(self, op, array, other):
result = op(array, other)
expected = self.elementwise_comparison(op, array, other)
tm.assert_numpy_array_equal(result, expected)

@pytest.mark.parametrize("length", [1, 3, 5])
@pytest.mark.parametrize("other_constructor", [IntervalArray, list])
def test_compare_length_mismatch_errors(self, op, other_constructor, length):
array = IntervalArray.from_arrays(range(4), range(1, 5))
other = other_constructor([Interval(0, 1)] * length)
with pytest.raises(ValueError, match="Lengths must match to compare"):
op(array, other)

@pytest.mark.parametrize(
"constructor, expected_type, assert_func",
[
(IntervalIndex, np.array, tm.assert_numpy_array_equal),
(Series, Series, tm.assert_series_equal),
],
)
def test_index_series_compat(self, op, constructor, expected_type, assert_func):
# IntervalIndex/Series that rely on IntervalArray for comparisons
breaks = range(4)
index = constructor(IntervalIndex.from_breaks(breaks))

# scalar comparisons
other = index[0]
result = op(index, other)
expected = expected_type(self.elementwise_comparison(op, index, other))
assert_func(result, expected)

other = breaks[0]
result = op(index, other)
expected = expected_type(self.elementwise_comparison(op, index, other))
assert_func(result, expected)

# list-like comparisons
other = IntervalArray.from_breaks(breaks)
result = op(index, other)
expected = expected_type(self.elementwise_comparison(op, index, other))
assert_func(result, expected)

other = [index[0], breaks[0], "foo"]
result = op(index, other)
expected = expected_type(self.elementwise_comparison(op, index, other))
assert_func(result, expected)


def test_repr():
# GH 25022
arr = IntervalArray.from_tuples([(0, 1), (1, 2)])
Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/series/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,14 @@ def test_ser_cmp_result_names(self, names, op):
result = op(ser, tdi)
assert result.name == names[2]

# interval dtype
if op in [operator.eq, operator.ne]:
# interval dtype comparisons not yet implemented
ii = pd.interval_range(start=0, periods=5, name=names[0])
ser = Series(ii).rename(names[1])
result = op(ser, ii)
assert result.name == names[2]

# categorical
if op in [operator.eq, operator.ne]:
# categorical dtype comparisons raise for inequalities
Expand Down