Skip to content

Commit 02f19f8

Browse files
committed
BUG: Perform i8 conversion for datetimelike IntervalTree queries
1 parent c282e31 commit 02f19f8

File tree

3 files changed

+186
-8
lines changed

3 files changed

+186
-8
lines changed

doc/source/whatsnew/v0.24.0.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,7 @@ Interval
755755
- Bug in the :class:`IntervalIndex` constructor where the ``closed`` parameter did not always override the inferred ``closed`` (:issue:`19370`)
756756
- Bug in the ``IntervalIndex`` repr where a trailing comma was missing after the list of intervals (:issue:`20611`)
757757
- Bug in :class:`Interval` where scalar arithmetic operations did not retain the ``closed`` value (:issue:`22313`)
758-
-
758+
- Bug in :class:`IntervalIndex` where indexing with datetime-like values raised a ``KeyError`` (:issue:`20636`)
759759

760760
Indexing
761761
^^^^^^^^

pandas/core/indexes/interval.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,23 @@
66

77
from pandas.compat import add_metaclass
88
from pandas.core.dtypes.missing import isna
9-
from pandas.core.dtypes.cast import find_common_type, maybe_downcast_to_dtype
9+
from pandas.core.dtypes.cast import (
10+
find_common_type, maybe_downcast_to_dtype, infer_dtype_from_scalar)
1011
from pandas.core.dtypes.common import (
1112
ensure_platform_int,
1213
is_list_like,
1314
is_datetime_or_timedelta_dtype,
1415
is_datetime64tz_dtype,
16+
is_dtype_equal,
1517
is_integer_dtype,
1618
is_float_dtype,
1719
is_interval_dtype,
1820
is_object_dtype,
1921
is_scalar,
2022
is_float,
2123
is_number,
22-
is_integer)
24+
is_integer,
25+
needs_i8_conversion)
2326
from pandas.core.indexes.base import (
2427
Index, ensure_index,
2528
default_pprint, _index_shared_docs)
@@ -29,8 +32,8 @@
2932
Interval, IntervalMixin, IntervalTree,
3033
)
3134

32-
from pandas.core.indexes.datetimes import date_range
33-
from pandas.core.indexes.timedeltas import timedelta_range
35+
from pandas.core.indexes.datetimes import date_range, DatetimeIndex
36+
from pandas.core.indexes.timedeltas import timedelta_range, TimedeltaIndex
3437
from pandas.core.indexes.multi import MultiIndex
3538
import pandas.core.common as com
3639
from pandas.util._decorators import cache_readonly, Appender
@@ -192,7 +195,9 @@ def _isnan(self):
192195

193196
@cache_readonly
194197
def _engine(self):
195-
return IntervalTree(self.left, self.right, closed=self.closed)
198+
left = self._maybe_convert_i8(self.left)
199+
right = self._maybe_convert_i8(self.right)
200+
return IntervalTree(left, right, closed=self.closed)
196201

197202
def __contains__(self, key):
198203
"""
@@ -514,6 +519,41 @@ def _maybe_cast_indexed(self, key):
514519

515520
return key
516521

522+
def _maybe_convert_i8(self, key):
523+
if isinstance(key, Interval):
524+
if not isinstance(key.left, (Timestamp, Timedelta)):
525+
return key
526+
left = self._maybe_convert_i8(key.left)
527+
right = self._maybe_convert_i8(key.right)
528+
return Interval(left, right, key.closed)
529+
elif isinstance(key, (IntervalIndex, IntervalArray)):
530+
if not needs_i8_conversion(key.left):
531+
return key
532+
left = self._maybe_convert_i8(key.left)
533+
right = self._maybe_convert_i8(key.right)
534+
return IntervalIndex.from_arrays(left, right, key.closed)
535+
elif is_list_like(key) and not isinstance(key, Index):
536+
result = self._maybe_convert_i8(ensure_index(key))
537+
if result[0] == key[0]:
538+
# return the list-like key if no conversion
539+
return key
540+
return result
541+
542+
subtype = self.dtype.subtype
543+
msg = ('Cannot index an IntervalIndex of subtype {subtype} with '
544+
'values of dtype {other}')
545+
if isinstance(key, (Timestamp, Timedelta)):
546+
key_dtype, key_i8 = infer_dtype_from_scalar(key, pandas_dtype=True)
547+
if not is_dtype_equal(subtype, key_dtype):
548+
raise ValueError(msg.format(subtype=subtype, other=key_dtype))
549+
return key_i8
550+
elif isinstance(key, (DatetimeIndex, TimedeltaIndex)):
551+
if not is_dtype_equal(subtype, key.dtype):
552+
raise ValueError(msg.format(subtype=subtype, other=key.dtype))
553+
return Index(key.asi8)
554+
555+
return key
556+
517557
def _check_method(self, method):
518558
if method is None:
519559
return
@@ -648,6 +688,7 @@ def get_loc(self, key, method=None):
648688

649689
else:
650690
# use the interval tree
691+
key = self._maybe_convert_i8(key)
651692
if isinstance(key, Interval):
652693
left, right = _get_interval_closed_bounds(key)
653694
return self._engine.get_loc_interval(left, right)
@@ -711,8 +752,10 @@ def _get_reindexer(self, target):
711752
"""
712753

713754
# find the left and right indexers
714-
lindexer = self._engine.get_indexer(target.left.values)
715-
rindexer = self._engine.get_indexer(target.right.values)
755+
left = self._maybe_convert_i8(target.left)
756+
right = self._maybe_convert_i8(target.right)
757+
lindexer = self._engine.get_indexer(left.values)
758+
rindexer = self._engine.get_indexer(right.values)
716759

717760
# we want to return an indexer on the intervals
718761
# however, our keys could provide overlapping of multiple

pandas/tests/indexes/interval/test_interval.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import division
22

3+
from itertools import permutations
34
import pytest
45
import numpy as np
6+
import re
57
from pandas import (
68
Interval, IntervalIndex, Index, isna, notna, interval_range, Timestamp,
79
Timedelta, date_range, timedelta_range)
@@ -498,6 +500,48 @@ def test_get_loc_length_one(self, item, closed):
498500
result = index.get_loc(item)
499501
assert result == 0
500502

503+
# Make consistent with test_interval_new.py (see #16316, #16386)
504+
@pytest.mark.parametrize('breaks', [
505+
date_range('20180101', periods=4),
506+
date_range('20180101', periods=4, tz='US/Eastern'),
507+
timedelta_range('0 days', periods=4)], ids=lambda x: str(x.dtype))
508+
def test_get_loc_datetimelike_nonoverlapping(self, breaks):
509+
# GH 20636
510+
# nonoverlapping = IntervalIndex method and no i8 conversion
511+
index = IntervalIndex.from_breaks(breaks)
512+
513+
value = index[0].mid
514+
result = index.get_loc(value)
515+
expected = 0
516+
assert result == expected
517+
518+
interval = Interval(index[0].left, index[1].right)
519+
result = index.get_loc(interval)
520+
expected = slice(0, 2)
521+
assert result == expected
522+
523+
# Make consistent with test_interval_new.py (see #16316, #16386)
524+
@pytest.mark.parametrize('arrays', [
525+
(date_range('20180101', periods=4), date_range('20180103', periods=4)),
526+
(date_range('20180101', periods=4, tz='US/Eastern'),
527+
date_range('20180103', periods=4, tz='US/Eastern')),
528+
(timedelta_range('0 days', periods=4),
529+
timedelta_range('2 days', periods=4))], ids=lambda x: str(x[0].dtype))
530+
def test_get_loc_datetimelike_overlapping(self, arrays):
531+
# GH 20636
532+
# overlapping = IntervalTree method with i8 conversion
533+
index = IntervalIndex.from_arrays(*arrays)
534+
535+
value = index[0].mid + Timedelta('12 hours')
536+
result = np.sort(index.get_loc(value))
537+
expected = np.array([0, 1], dtype='int64')
538+
assert tm.assert_numpy_array_equal(result, expected)
539+
540+
interval = Interval(index[0].left, index[1].right)
541+
result = np.sort(index.get_loc(interval))
542+
expected = np.array([0, 1, 2], dtype='int64')
543+
assert tm.assert_numpy_array_equal(result, expected)
544+
501545
# To be removed, replaced by test_interval_new.py (see #16316, #16386)
502546
def test_get_indexer(self):
503547
actual = self.index.get_indexer([-1, 0, 0.5, 1, 1.5, 2, 3])
@@ -555,6 +599,97 @@ def test_get_indexer_length_one(self, item, closed):
555599
expected = np.array([0] * len(item), dtype='intp')
556600
tm.assert_numpy_array_equal(result, expected)
557601

602+
# Make consistent with test_interval_new.py (see #16316, #16386)
603+
@pytest.mark.parametrize('arrays', [
604+
(date_range('20180101', periods=4), date_range('20180103', periods=4)),
605+
(date_range('20180101', periods=4, tz='US/Eastern'),
606+
date_range('20180103', periods=4, tz='US/Eastern')),
607+
(timedelta_range('0 days', periods=4),
608+
timedelta_range('2 days', periods=4))], ids=lambda x: str(x[0].dtype))
609+
def test_get_reindexer_datetimelike(self, arrays):
610+
# GH 20636
611+
index = IntervalIndex.from_arrays(*arrays)
612+
tuples = [(index[0].left, index[0].left + pd.Timedelta('12H')),
613+
(index[-1].right - pd.Timedelta('12H'), index[-1].right)]
614+
target = IntervalIndex.from_tuples(tuples)
615+
616+
result = index._get_reindexer(target)
617+
expected = np.array([0, 3], dtype='int64')
618+
tm.assert_numpy_array_equal(result, expected)
619+
620+
@pytest.mark.parametrize('breaks', [
621+
date_range('20180101', periods=4),
622+
date_range('20180101', periods=4, tz='US/Eastern'),
623+
timedelta_range('0 days', periods=4)], ids=lambda x: str(x.dtype))
624+
def test_maybe_convert_i8(self, breaks):
625+
# GH 20636
626+
index = IntervalIndex.from_breaks(breaks)
627+
628+
# intervalindex
629+
result = index._maybe_convert_i8(index)
630+
expected = IntervalIndex.from_breaks(breaks.asi8)
631+
tm.assert_index_equal(result, expected)
632+
633+
# interval
634+
interval = Interval(breaks[0], breaks[1])
635+
result = index._maybe_convert_i8(interval)
636+
expected = Interval(breaks[0].value, breaks[1].value)
637+
assert result == expected
638+
639+
# datetimelike index
640+
result = index._maybe_convert_i8(breaks)
641+
expected = Index(breaks.asi8)
642+
tm.assert_index_equal(result, expected)
643+
644+
# datetimelike scalar
645+
result = index._maybe_convert_i8(breaks[0])
646+
expected = breaks[0].value
647+
assert result == expected
648+
649+
# list-like of datetimelike scalars
650+
result = index._maybe_convert_i8(list(breaks))
651+
expected = Index(breaks.asi8)
652+
tm.assert_index_equal(result, expected)
653+
654+
@pytest.mark.parametrize('breaks', [
655+
np.arange(5, dtype='int64'),
656+
np.arange(5, dtype='float64')], ids=lambda x: str(x.dtype))
657+
def test_maybe_convert_i8_numeric(self, breaks):
658+
# GH 20636
659+
index = IntervalIndex.from_breaks(breaks)
660+
numeric_keys = [
661+
IntervalIndex.from_breaks(breaks),
662+
Interval(breaks[0], breaks[1]),
663+
breaks,
664+
breaks[0],
665+
list(breaks)]
666+
667+
# no conversion occurs for numeric
668+
for key in numeric_keys:
669+
result = index._maybe_convert_i8(key)
670+
assert result is key
671+
672+
@pytest.mark.parametrize('breaks1, breaks2', permutations([
673+
date_range('20180101', periods=4),
674+
date_range('20180101', periods=4, tz='US/Eastern'),
675+
timedelta_range('0 days', periods=4)], 2), ids=lambda x: str(x.dtype))
676+
def test_maybe_convert_i8_errors(self, breaks1, breaks2):
677+
# GH 20636
678+
index = IntervalIndex.from_breaks(breaks1)
679+
invalid_keys = [
680+
IntervalIndex.from_breaks(breaks2),
681+
Interval(breaks2[0], breaks2[1]),
682+
breaks2,
683+
breaks2[0],
684+
list(breaks2)]
685+
686+
msg = ('Cannot index an IntervalIndex of subtype {dtype1} with '
687+
'values of dtype {dtype2}')
688+
msg = re.escape(msg.format(dtype1=breaks1.dtype, dtype2=breaks2.dtype))
689+
for key in invalid_keys:
690+
with tm.assert_raises_regex(ValueError, msg):
691+
index._maybe_convert_i8(key)
692+
558693
# To be removed, replaced by test_interval_new.py (see #16316, #16386)
559694
def test_contains(self):
560695
# Only endpoints are valid.

0 commit comments

Comments
 (0)