Skip to content

BUG: IntervalIndex is_monotonic, get_loc, get_indexer_for, contains with np.nan #41863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,7 @@ Interval
- Bug in :meth:`IntervalIndex.intersection` returning duplicates when at least one of the :class:`Index` objects have duplicates which are present in the other (:issue:`38743`)
- :meth:`IntervalIndex.union`, :meth:`IntervalIndex.intersection`, :meth:`IntervalIndex.difference`, and :meth:`IntervalIndex.symmetric_difference` now cast to the appropriate dtype instead of raising a ``TypeError`` when operating with another :class:`IntervalIndex` with incompatible dtype (:issue:`39267`)
- :meth:`PeriodIndex.union`, :meth:`PeriodIndex.intersection`, :meth:`PeriodIndex.symmetric_difference`, :meth:`PeriodIndex.difference` now cast to object dtype instead of raising ``IncompatibleFrequency`` when operating with another :class:`PeriodIndex` with incompatible dtype (:issue:`39306`)
- Bug in :meth:`IntervalIndex.is_monotonic`, :meth:`IntervalIndex.get_loc`, :meth:`IntervalIndex.get_indexer_for`, and :meth:`IntervalIndex.__contains__` when NA values are present (:issue:`41831`)

Indexing
^^^^^^^^
Expand Down
4 changes: 4 additions & 0 deletions pandas/_libs/intervaltree.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ cdef class IntervalTree(IntervalMixin):
object dtype
str closed
object _is_overlapping, _left_sorter, _right_sorter
Py_ssize_t _na_count

def __init__(self, left, right, closed='right', leaf_size=100):
"""
Expand Down Expand Up @@ -67,6 +68,7 @@ cdef class IntervalTree(IntervalMixin):

# GH 23352: ensure no nan in nodes
mask = ~np.isnan(self.left)
self._na_count = len(mask) - mask.sum()
self.left = self.left[mask]
self.right = self.right[mask]
indices = indices[mask]
Expand Down Expand Up @@ -116,6 +118,8 @@ cdef class IntervalTree(IntervalMixin):
Return True if the IntervalTree is monotonic increasing (only equal or
increasing values), else False
"""
if self._na_count > 0:
return False
values = [self.right, self.left]

sort_order = np.lexsort(values)
Expand Down
16 changes: 13 additions & 3 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
is_scalar,
)
from pandas.core.dtypes.dtypes import IntervalDtype
from pandas.core.dtypes.missing import is_valid_na_for_dtype

from pandas.core.algorithms import take_nd
from pandas.core.arrays.interval import (
Expand Down Expand Up @@ -343,6 +344,8 @@ def __contains__(self, key: Any) -> bool:
"""
hash(key)
if not isinstance(key, Interval):
if is_valid_na_for_dtype(key, self.dtype):
return self.hasnans
return False

try:
Expand Down Expand Up @@ -618,6 +621,8 @@ def get_loc(
if self.closed != key.closed:
raise KeyError(key)
mask = (self.left == key.left) & (self.right == key.right)
elif is_valid_na_for_dtype(key, self.dtype):
mask = self.isna()
else:
# assume scalar
op_left = le if self.closed_left else lt
Expand All @@ -633,7 +638,12 @@ def get_loc(
raise KeyError(key)
elif matches == 1:
return mask.argmax()
return lib.maybe_booleans_to_slice(mask.view("u1"))

res = lib.maybe_booleans_to_slice(mask.view("u1"))
if isinstance(res, slice) and res.stop is None:
# TODO: DO this in maybe_booleans_to_slice?
res = slice(res.start, len(self), res.step)
return res

def _get_indexer(
self,
Expand Down Expand Up @@ -721,9 +731,9 @@ def _get_indexer_pointwise(self, target: Index) -> tuple[np.ndarray, np.ndarray]
indexer = np.concatenate(indexer)
return ensure_platform_int(indexer), ensure_platform_int(missing)

@property
@cache_readonly
def _index_as_unique(self) -> bool:
return not self.is_overlapping
return not self.is_overlapping and self._engine._na_count < 2

_requires_unique_msg = (
"cannot handle overlapping indices; use IntervalIndex.get_indexer_non_unique"
Expand Down
27 changes: 27 additions & 0 deletions pandas/tests/indexes/interval/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from pandas.errors import InvalidIndexError

from pandas import (
NA,
CategoricalIndex,
Interval,
IntervalIndex,
NaT,
Timedelta,
date_range,
timedelta_range,
Expand Down Expand Up @@ -168,6 +170,20 @@ def test_get_loc_non_scalar_errors(self, key):
with pytest.raises(InvalidIndexError, match=msg):
idx.get_loc(key)

def test_get_indexer_with_nans(self):
# GH#41831
index = IntervalIndex([np.nan, Interval(1, 2), np.nan])

expected = np.array([True, False, True])
for key in [None, np.nan, NA]:
assert key in index
result = index.get_loc(key)
tm.assert_numpy_array_equal(result, expected)

for key in [NaT, np.timedelta64("NaT", "ns"), np.datetime64("NaT", "ns")]:
with pytest.raises(KeyError, match=str(key)):
index.get_loc(key)


class TestGetIndexer:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -326,6 +342,17 @@ def test_get_indexer_non_monotonic(self):
expected = np.array([1, 2], dtype=np.intp)
tm.assert_numpy_array_equal(result, expected)

def test_get_indexer_with_nans(self):
# GH#41831
index = IntervalIndex([np.nan, np.nan])
other = IntervalIndex([np.nan])

assert not index._index_as_unique

result = index.get_indexer_for(other)
expected = np.array([0, 1], dtype=np.intp)
tm.assert_numpy_array_equal(result, expected)


class TestSliceLocs:
def test_slice_locs_with_interval(self):
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/indexes/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,16 @@ def test_monotonic(self, closed):
assert idx.is_monotonic_decreasing is True
assert idx._is_strictly_monotonic_decreasing is True

def test_is_monotonic_with_nans(self):
# GH#41831
index = IntervalIndex([np.nan, np.nan])

assert not index.is_monotonic
assert not index._is_strictly_monotonic_increasing
assert not index.is_monotonic_increasing
assert not index._is_strictly_monotonic_decreasing
assert not index.is_monotonic_decreasing

def test_get_item(self, closed):
i = IntervalIndex.from_arrays((0, 1, np.nan), (1, 2, np.nan), closed=closed)
assert i[0] == Interval(0.0, 1.0, closed=closed)
Expand Down
15 changes: 15 additions & 0 deletions pandas/tests/indexing/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,21 @@ def test_loc_getitem_frame(self):
with pytest.raises(KeyError, match=r"\[10\] not in index"):
df.loc[[10, 4]]

def test_getitem_interval_with_nans(self, frame_or_series, indexer_sl):
# GH#41831

index = IntervalIndex([np.nan, np.nan])
key = index[:-1]

obj = frame_or_series(range(2), index=index)
if frame_or_series is DataFrame and indexer_sl is tm.setitem:
obj = obj.T

result = indexer_sl(obj)[key]
expected = obj

tm.assert_equal(result, expected)


class TestIntervalIndexInsideMultiIndex:
def test_mi_intervalindex_slicing_with_scalar(self):
Expand Down