-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
BUG: Fix IntervalArray equality comparisions #30640
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,8 @@ | |
is_integer_dtype, | ||
is_interval, | ||
is_interval_dtype, | ||
is_list_like, | ||
is_object_dtype, | ||
is_scalar, | ||
is_string_dtype, | ||
is_timedelta64_dtype, | ||
|
@@ -37,6 +39,7 @@ | |
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs | ||
from pandas.core.arrays.categorical import Categorical | ||
import pandas.core.common as com | ||
from pandas.core.construction import array | ||
from pandas.core.indexes.base import ensure_index | ||
|
||
_VALID_CLOSED = {"left", "right", "both", "neither"} | ||
|
@@ -547,6 +550,58 @@ def __setitem__(self, key, value): | |
right.values[key] = value_right | ||
self._right = right | ||
|
||
def __eq__(self, other): | ||
# ensure pandas array for list-like and eliminate non-interval scalars | ||
if is_list_like(other): | ||
if len(self) != len(other): | ||
raise ValueError("Lengths must match to compare") | ||
other = array(other) | ||
elif not isinstance(other, Interval): | ||
# non-interval scalar -> no matches | ||
return np.zeros(len(self), dtype=bool) | ||
|
||
# determine the dtype of the elements we want to compare | ||
if isinstance(other, Interval): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. at this point can’t u just wrap other in array()? |
||
other_dtype = "interval" | ||
elif not is_categorical_dtype(other): | ||
other_dtype = other.dtype | ||
else: | ||
# for categorical defer to categories for dtype | ||
other_dtype = other.categories.dtype | ||
|
||
# extract intervals if we have interval categories with matching closed | ||
if is_interval_dtype(other_dtype): | ||
if self.closed != other.categories.closed: | ||
return np.zeros(len(self), dtype=bool) | ||
other = other.categories.take(other.codes) | ||
|
||
# interval-like -> need same closed and matching endpoints | ||
if is_interval_dtype(other_dtype): | ||
if self.closed != other.closed: | ||
return np.zeros(len(self), dtype=bool) | ||
return (self.left == other.left) & (self.right == other.right) | ||
|
||
# non-interval/non-object dtype -> no matches | ||
if not is_object_dtype(other_dtype): | ||
return np.zeros(len(self), dtype=bool) | ||
|
||
# object dtype -> iteratively check for intervals | ||
result = np.zeros(len(self), dtype=bool) | ||
for i, obj in enumerate(other): | ||
# need object to be an Interval with same closed and endpoints | ||
if ( | ||
isinstance(obj, Interval) | ||
and self.closed == obj.closed | ||
and self.left[i] == obj.left | ||
and self.right[i] == obj.right | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can this check be just There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It could but there'd be a perf hit for actually materializing the |
||
): | ||
result[i] = True | ||
|
||
return result | ||
|
||
def __ne__(self, other): | ||
return ~self.__eq__(other) | ||
|
||
def fillna(self, value=None, method=None, limit=None): | ||
""" | ||
Fill NA/NaN values using the specified method. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -205,7 +205,9 @@ def func(intvidx_self, other, sort=False): | |
"__array__", | ||
"overlaps", | ||
"contains", | ||
"__eq__", | ||
"__len__", | ||
"__ne__", | ||
"set_closed", | ||
"to_tuples", | ||
], | ||
|
@@ -224,7 +226,14 @@ class IntervalIndex(IntervalMixin, Index, accessor.PandasDelegate): | |
# Immutable, so we are able to cache computations like isna in '_mask' | ||
_mask = None | ||
|
||
_raw_inherit = {"_ndarray_values", "__array__", "overlaps", "contains"} | ||
_raw_inherit = { | ||
"_ndarray_values", | ||
"__array__", | ||
"overlaps", | ||
"contains", | ||
"__eq__", | ||
"__ne__", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to be consistent with our other EAs, these need to be dispatched/wrapped the same way they are in datetimelike or categorical. im planning to move the relevant code to indexes.extension so this can re-use the existing code. The relevant test will be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should now be able to use indexes.extension.make_wrapped_comparison_op |
||
} | ||
|
||
# -------------------------------------------------------------------- | ||
# Constructors | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,22 @@ | ||
import operator | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from pandas.core.dtypes.common import is_list_like | ||
|
||
import pandas as pd | ||
from pandas import ( | ||
Categorical, | ||
Index, | ||
Interval, | ||
IntervalIndex, | ||
Period, | ||
Series, | ||
Timedelta, | ||
Timestamp, | ||
date_range, | ||
period_range, | ||
timedelta_range, | ||
) | ||
from pandas.core.arrays import IntervalArray | ||
|
@@ -35,6 +43,18 @@ def left_right_dtypes(request): | |
return request.param | ||
|
||
|
||
def create_categorical_intervals(left, right, closed="right"): | ||
return Categorical(IntervalIndex.from_arrays(left, right, closed)) | ||
|
||
|
||
def create_series_intervals(left, right, closed="right"): | ||
return Series(IntervalArray.from_arrays(left, right, closed)) | ||
|
||
|
||
def create_series_categorical_intervals(left, right, closed="right"): | ||
return Series(Categorical(IntervalIndex.from_arrays(left, right, closed))) | ||
|
||
|
||
class TestAttributes: | ||
@pytest.mark.parametrize( | ||
"left, right", | ||
|
@@ -93,6 +113,221 @@ def test_set_na(self, left_right_dtypes): | |
tm.assert_extension_array_equal(result, expected) | ||
|
||
|
||
class TestComparison: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. possibly put in tests.arithmetic.test_interval and parametrize with box_with_array |
||
@pytest.fixture(params=[operator.eq, operator.ne]) | ||
def op(self, request): | ||
return request.param | ||
|
||
@pytest.fixture | ||
def array(self, left_right_dtypes): | ||
""" | ||
Fixture to generate an IntervalArray of various dtypes containing NA if possible | ||
""" | ||
left, right = left_right_dtypes | ||
if left.dtype != "int64": | ||
left, right = left.insert(4, np.nan), right.insert(4, np.nan) | ||
else: | ||
left, right = left.insert(4, 10), right.insert(4, 20) | ||
return IntervalArray.from_arrays(left, right) | ||
|
||
@pytest.fixture( | ||
params=[ | ||
IntervalArray.from_arrays, | ||
IntervalIndex.from_arrays, | ||
create_categorical_intervals, | ||
create_series_intervals, | ||
create_series_categorical_intervals, | ||
], | ||
ids=[ | ||
"IntervalArray", | ||
"IntervalIndex", | ||
"Categorical[Interval]", | ||
"Series[Interval]", | ||
"Series[Categorical[Interval]]", | ||
], | ||
) | ||
def interval_constructor(self, request): | ||
""" | ||
Fixture for all pandas native interval constructors. | ||
To be used as the LHS of IntervalArray comparisons. | ||
""" | ||
return request.param | ||
|
||
def elementwise_comparison(self, op, array, other): | ||
""" | ||
Helper that performs elementwise comparisions between `array` and `other` | ||
""" | ||
other = other if is_list_like(other) else [other] * len(array) | ||
return np.array([op(x, y) for x, y in zip(array, other)]) | ||
|
||
def test_compare_scalar_interval(self, op, array): | ||
# matches first interval | ||
other = array[0] | ||
result = op(array, other) | ||
expected = self.elementwise_comparison(op, array, other) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
# matches on a single endpoint but not both | ||
other = Interval(array.left[0], array.right[1]) | ||
result = op(array, other) | ||
expected = self.elementwise_comparison(op, array, other) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
def test_compare_scalar_interval_mixed_closed(self, op, closed, other_closed): | ||
array = IntervalArray.from_arrays(range(2), range(1, 3), closed=closed) | ||
other = Interval(0, 1, closed=other_closed) | ||
|
||
result = op(array, other) | ||
expected = self.elementwise_comparison(op, array, other) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
def test_compare_scalar_na(self, op, array, nulls_fixture): | ||
result = op(array, nulls_fixture) | ||
expected = self.elementwise_comparison(op, array, nulls_fixture) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
@pytest.mark.parametrize( | ||
"other", | ||
[ | ||
0, | ||
1.0, | ||
True, | ||
"foo", | ||
Timestamp("2017-01-01"), | ||
Timestamp("2017-01-01", tz="US/Eastern"), | ||
Timedelta("0 days"), | ||
Period("2017-01-01", "D"), | ||
], | ||
) | ||
def test_compare_scalar_other(self, op, array, other): | ||
result = op(array, other) | ||
expected = self.elementwise_comparison(op, array, other) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
def test_compare_list_like_interval( | ||
self, op, array, interval_constructor, | ||
): | ||
# same endpoints | ||
other = interval_constructor(array.left, array.right) | ||
result = op(array, other) | ||
expected = self.elementwise_comparison(op, array, other) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
# different endpoints | ||
other = interval_constructor(array.left[::-1], array.right[::-1]) | ||
result = op(array, other) | ||
expected = self.elementwise_comparison(op, array, other) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
# all nan endpoints | ||
other = interval_constructor([np.nan] * 4, [np.nan] * 4) | ||
result = op(array, other) | ||
expected = self.elementwise_comparison(op, array, other) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
def test_compare_list_like_interval_mixed_closed( | ||
self, op, interval_constructor, closed, other_closed | ||
): | ||
array = IntervalArray.from_arrays(range(2), range(1, 3), closed=closed) | ||
other = interval_constructor(range(2), range(1, 3), closed=other_closed) | ||
|
||
result = op(array, other) | ||
expected = self.elementwise_comparison(op, array, other) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
@pytest.mark.parametrize( | ||
"other", | ||
[ | ||
( | ||
Interval(0, 1), | ||
Interval(Timedelta("1 day"), Timedelta("2 days")), | ||
Interval(4, 5, "both"), | ||
Interval(10, 20, "neither"), | ||
), | ||
(0, 1.5, Timestamp("20170103"), np.nan), | ||
( | ||
Timestamp("20170102", tz="US/Eastern"), | ||
Timedelta("2 days"), | ||
"baz", | ||
pd.NaT, | ||
), | ||
], | ||
) | ||
def test_compare_list_like_object(self, op, array, other): | ||
result = op(array, other) | ||
expected = self.elementwise_comparison(op, array, other) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
def test_compare_list_like_nan(self, op, array, nulls_fixture): | ||
other = [nulls_fixture] * 4 | ||
result = op(array, other) | ||
expected = self.elementwise_comparison(op, array, other) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
@pytest.mark.parametrize( | ||
"other", | ||
[ | ||
np.arange(4, dtype="int64"), | ||
np.arange(4, dtype="float64"), | ||
date_range("2017-01-01", periods=4), | ||
date_range("2017-01-01", periods=4, tz="US/Eastern"), | ||
timedelta_range("0 days", periods=4), | ||
period_range("2017-01-01", periods=4, freq="D"), | ||
Categorical(list("abab")), | ||
Categorical(date_range("2017-01-01", periods=4)), | ||
pd.array(list("abcd")), | ||
pd.array(["foo", 3.14, None, object()]), | ||
], | ||
ids=lambda x: str(x.dtype), | ||
) | ||
def test_compare_list_like_other(self, op, array, other): | ||
result = op(array, other) | ||
expected = self.elementwise_comparison(op, array, other) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
@pytest.mark.parametrize("length", [1, 3, 5]) | ||
@pytest.mark.parametrize("other_constructor", [IntervalArray, list]) | ||
def test_compare_length_mismatch_errors(self, op, other_constructor, length): | ||
array = IntervalArray.from_arrays(range(4), range(1, 5)) | ||
other = other_constructor([Interval(0, 1)] * length) | ||
with pytest.raises(ValueError, match="Lengths must match to compare"): | ||
op(array, other) | ||
|
||
@pytest.mark.parametrize( | ||
"constructor, expected_type, assert_func", | ||
[ | ||
(IntervalIndex, np.array, tm.assert_numpy_array_equal), | ||
(Series, Series, tm.assert_series_equal), | ||
], | ||
) | ||
def test_index_series_compat(self, op, constructor, expected_type, assert_func): | ||
# IntervalIndex/Series that rely on IntervalArray for comparisons | ||
breaks = range(4) | ||
index = constructor(IntervalIndex.from_breaks(breaks)) | ||
|
||
# scalar comparisons | ||
other = index[0] | ||
result = op(index, other) | ||
expected = expected_type(self.elementwise_comparison(op, index, other)) | ||
assert_func(result, expected) | ||
|
||
other = breaks[0] | ||
result = op(index, other) | ||
expected = expected_type(self.elementwise_comparison(op, index, other)) | ||
assert_func(result, expected) | ||
|
||
# list-like comparisons | ||
other = IntervalArray.from_breaks(breaks) | ||
result = op(index, other) | ||
expected = expected_type(self.elementwise_comparison(op, index, other)) | ||
assert_func(result, expected) | ||
|
||
other = [index[0], breaks[0], "foo"] | ||
result = op(index, other) | ||
expected = expected_type(self.elementwise_comparison(op, index, other)) | ||
assert_func(result, expected) | ||
|
||
|
||
def test_repr(): | ||
# GH 25022 | ||
arr = IntervalArray.from_tuples([(0, 1), (1, 2)]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this particular check can be pushed to the base class
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite follow. Whose the base class here? ExtensionArray?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes