Skip to content

Commit 6aeea4c

Browse files
committed
WIP/API/ENH: IntervalIndex
Fixes pandas-dev#7640, pandas-dev#8625 This is a work in progress, but it's far enough along that I'd love to get some feedback. TODOs (more called out in the code): - [ ] documentation + docstrings - [ ] finish the index methods: - [ ] `get_loc` - [ ] `get_indexer` - [ ] `slice_locs` - [ ] comparison operations - [ ] fix `is_monotonic` (pending pandas-dev#8680) - [ ] ensure sorting works - [ ] arithmetic operations (not essential for MVP) - [ ] cythonize the bottlenecks: - [ ] `from_breaks` - [ ] `_data` - [ ] `Interval`? - [ ] `MultiIndex` - [ ] `Categorical`/`cut` - [ ] serialization - [ ] lots more tests CC @jreback @cpcloud @immerrr
1 parent 3579304 commit 6aeea4c

File tree

6 files changed

+489
-39
lines changed

6 files changed

+489
-39
lines changed

pandas/core/api.py

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pandas.core.groupby import Grouper
1010
from pandas.core.format import set_eng_float_format
1111
from pandas.core.index import Index, Int64Index, Float64Index, MultiIndex
12+
from pandas.core.interval import Interval, IntervalIndex
1213

1314
from pandas.core.series import Series, TimeSeries
1415
from pandas.core.frame import DataFrame

pandas/core/index.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1244,7 +1244,7 @@ def union(self, other):
12441244

12451245
def _wrap_union_result(self, other, result):
12461246
name = self.name if self.name == other.name else None
1247-
return self.__class__(data=result, name=name)
1247+
return self._constructor(data=result, name=name)
12481248

12491249
def intersection(self, other):
12501250
"""

pandas/core/interval.py

+225
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import numpy as np
2+
3+
from pandas.core.base import PandasObject, IndexOpsMixin
4+
from pandas.core.common import _values_from_object
5+
from pandas.core.index import Index
6+
from pandas.util.decorators import cache_readonly
7+
8+
9+
_VALID_CLOSED = set(['left', 'right', 'both', 'neither'])
10+
11+
12+
class IntervalMixin(object):
    """Endpoint and closedness accessors shared by interval-like objects.

    Classes using this mixin keep their state in the ``_left``, ``_right``
    and ``_closed`` attributes, which the properties below expose.
    """

    @property
    def left(self):
        """Left endpoint(s), as stored in ``_left``."""
        return self._left

    @property
    def right(self):
        """Right endpoint(s), as stored in ``_right``."""
        return self._right

    @property
    def closed(self):
        """Which side(s) are closed, as stored in ``_closed``."""
        return self._closed

    @cache_readonly
    def closed_left(self):
        """True when ``closed`` is 'left' or 'both'."""
        return self.closed in ('left', 'both')

    @cache_readonly
    def closed_right(self):
        """True when ``closed`` is 'right' or 'both'."""
        return self.closed in ('right', 'both')

    @property
    def open_left(self):
        """Negation of ``closed_left``."""
        return not self.closed_left

    @property
    def open_right(self):
        """Negation of ``closed_right``."""
        return not self.closed_right

    @cache_readonly
    def mid(self):
        """Midpoint between ``left`` and ``right``."""
        # TODO: figure out how to do add/sub as arithmetic even on Index
        # objects. Is there a workaround while we have deprecated +/- as
        # union/difference? Possibly need to add `add` and `sub` methods.
        lower, upper = self.left, self.right
        try:
            midpoint = 0.5 * (lower + upper)
        except TypeError:
            # datetime-safe version: endpoints that cannot be summed are
            # handled by offsetting the left endpoint by half the width
            midpoint = lower + 0.5 * (upper - lower)
        return midpoint

    def _validate(self):
        """Raise ValueError unless ``closed`` is listed in _VALID_CLOSED."""
        # TODO: exclude periods?
        closed = self.closed
        if closed in _VALID_CLOSED:
            return
        raise ValueError("invalid options for 'closed': %s" % closed)
56+
57+
58+
# TODO: cythonize this whole class?
59+
class Interval(PandasObject, IntervalMixin):
    """Scalar object representing an interval between two endpoints.

    Equality and hashing are based on the ``(left, right, closed)`` triple.
    """

    def __init__(self, left, right, closed='right'):
        """Object representing an interval

        Parameters
        ----------
        left, right : endpoints of the interval
        closed : one of 'left', 'right', 'both', 'neither'; default 'right'
            Checked by ``self._validate()`` right after assignment.
        """
        self._left = left
        self._right = right
        self._closed = closed
        self._validate()

    def __hash__(self):
        return hash((self.left, self.right, self.closed))

    def __eq__(self, other):
        # objects lacking left/right/closed attributes are never equal
        try:
            return (self.left == other.left
                    and self.right == other.right
                    and self.closed == other.closed)
        except AttributeError:
            return False

    def __ne__(self, other):
        return not self == other

    def __lt__(self, other):
        """Return whether this interval lies entirely to the left of `other`.

        `other` may be interval-like (its ``left`` attribute is used) or any
        other object, which is then compared against ``self.right`` directly.
        """
        other_left = getattr(other, 'left', other)
        if self.open_right or getattr(other, 'open_left', False):
            # the touching endpoint is excluded on at least one side, so
            # merely touching still counts as strictly less
            return self.right <= other_left
        return self.right < other_left

    # The remaining ordering comparisons are not implemented yet. They must
    # *raise*: returning the NotImplementedError class would hand callers a
    # truthy object, making e.g. ``a <= b`` silently evaluate as true.
    def __le__(self, other):
        raise NotImplementedError

    def __gt__(self, other):
        raise NotImplementedError

    def __ge__(self, other):
        raise NotImplementedError

    # TODO: finish comparisons
    # TODO: add arithmetic operations

    def __str__(self):
        # square bracket marks a closed endpoint, parenthesis an open one
        start_symbol = '[' if self.closed_left else '('
        end_symbol = ']' if self.closed_right else ')'
        return '%s%s, %s%s' % (start_symbol, self.left, self.right, end_symbol)

    def __repr__(self):
        return ('%s(%r, %r, closed=%r)' %
                (type(self).__name__, self.left,
                 self.right, self.closed))
109+
110+
111+
112+
class IntervalIndex(Index, IntervalMixin):
    """Index of intervals, stored as separate left and right endpoint indexes.

    Parameters
    ----------
    left, right : array-likes of endpoints, each wrapped in an ``Index``
    closed : one of 'left', 'right', 'both', 'neither'; default 'right'
    name : optional name for the index
    """

    def __new__(cls, left, right, closed='right', name=None):
        # TODO: validation
        result = object.__new__(cls)
        result._left = Index(left)
        result._right = Index(right)
        result._closed = closed
        result.name = name
        result._validate()
        result._reset_identity()
        return result

    @classmethod
    def _simple_new(cls, values, name=None, **kwargs):
        # ensure we don't end up here (this is a superclass method)
        raise NotImplementedError

    @property
    def _constructor(self):
        return type(self).from_intervals

    @classmethod
    def from_breaks(cls, breaks, closed='right', name=None):
        """Build an index whose intervals span consecutive pairs of breaks."""
        return cls(breaks[:-1], breaks[1:], closed, name)

    @classmethod
    def from_intervals(cls, data, name=None):
        """Build an index from an iterable of interval objects.

        Raises
        ------
        ValueError
            If the intervals do not all share the same ``closed`` value.
        """
        # TODO: cythonize
        intervals = list(data)
        left = [i.left for i in intervals]
        right = [i.right for i in intervals]
        if intervals:
            closed = intervals[0].closed
            if any(i.closed != closed for i in intervals):
                raise ValueError("intervals must all be closed on the same "
                                 "side")
        else:
            # nothing to inspect; fall back to the constructor's default
            closed = 'right'
        return cls(left, right, closed, name)

    @cache_readonly
    def _data(self):
        # TODO: cythonize
        zipped = zip(self.left, self.right)
        items = [Interval(l, r, self.closed) for l, r in zipped]
        return np.array(items, dtype=object)

    @cache_readonly
    def dtype(self):
        return np.dtype('O')

    def get_loc(self, key):
        """Locate `key` in the index.

        An ``Interval`` key is looked up through ``self._engine``. Any other
        key is treated as a point and located by binary search over the
        endpoints.

        Returns
        -------
        The position ``start`` when exactly one interval matches, otherwise
        ``slice(start, end)``.

        Raises
        ------
        KeyError
            If no interval contains the point, or if the endpoints are not
            monotonic (point lookups rely on sorted endpoints).
        """
        if isinstance(key, Interval):
            # TODO: fall back to something like slice_locs if key not found
            return self._engine.get_loc(_values_from_object(key))
        else:
            # TODO: handle decreasing monotonic intervals
            # both endpoint indexes must be monotonic for searchsorted
            if not (self.left.is_monotonic and self.right.is_monotonic):
                raise KeyError("cannot lookup values on a non-monotonic "
                               "IntervalIndex")

            # first interval whose right endpoint does not exclude the key
            side_start = 'left' if self.closed_right else 'right'
            start = self.right.searchsorted(key, side=side_start)

            # one past the last interval whose left endpoint admits the key
            side_end = 'right' if self.closed_left else 'left'
            end = self.left.searchsorted(key, side=side_end)

            if start == end:
                raise KeyError(key)

            if start + 1 == end:
                return start
            else:
                return slice(start, end)

    def get_indexer(self, key):
        # should reuse the core of get_loc
        # if the key consists of intervals, needs unique values to give
        # sensible results (like DatetimeIndex)
        # if the key consists of scalars, the index's intervals must also be
        # non-overlapping
        raise NotImplementedError

    def slice_locs(self, start, end):
        # should be more efficient than directly calling the superclass method,
        # which calls get_loc (we don't need to do binary search twice for each
        # key)
        raise NotImplementedError

    def __contains__(self, key):
        try:
            self.get_loc(key)
            return True
        except KeyError:
            return False

    def __getitem__(self, value):
        left = self.left[value]
        right = self.right[value]
        if not isinstance(left, Index):
            # scalar selection: hand back a single Interval
            return Interval(left, right, self.closed)
        else:
            # carry the name over so slicing does not silently drop it
            return type(self)(left, right, self.closed, self.name)

    def __repr__(self):
        lines = [repr(type(self))]
        lines.extend(str(interval) for interval in self)
        lines.append('Length: %s, Closed: %r' %
                     (len(self), self.closed))
        return '\n'.join(lines)

    def equals(self, other):
        """Return True when `other` has equal endpoints and closedness."""
        if self.is_(other):
            return True
        try:
            return (self.left.equals(other.left)
                    and self.right.equals(other.right)
                    and self.closed == other.closed)
        except AttributeError:
            return False

    # TODO: add comparisons and arithmetic operations

pandas/tests/test_indexing.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
import pandas.core.common as com
1515
from pandas import option_context
1616
from pandas.core.api import (DataFrame, Index, Series, Panel, isnull,
17-
MultiIndex, Float64Index, Timestamp, Timedelta)
17+
MultiIndex, Float64Index, IntervalIndex,
18+
Timestamp, Timedelta)
1819
from pandas.util.testing import (assert_almost_equal, assert_series_equal,
1920
assert_frame_equal, assert_panel_equal,
2021
assert_attr_equal)
@@ -3692,6 +3693,34 @@ def test_floating_index(self):
36923693
assert_series_equal(result1, result3)
36933694
assert_series_equal(result1, Series([1],index=[2.5]))
36943695

3696+
def test_interval_index(self):
    """Label-based (``.loc``) selection on Series indexed by IntervalIndex."""
    # The visible module-level import adds IntervalIndex but not Interval,
    # so import it here to avoid a NameError in the last assertion.
    from pandas.core.api import Interval

    # from_breaks defaults to closed='right'
    s = Series(np.arange(5), IntervalIndex.from_breaks(np.arange(6)))

    expected = s.iloc[:3]
    assert_series_equal(expected, s.loc[:3])
    assert_series_equal(expected, s.loc[:2.5])
    assert_series_equal(expected, s.loc[0.1:2.5])
    assert_series_equal(expected, s.loc[-1:3])

    expected = s.iloc[1:4]
    assert_series_equal(expected, s.loc[[1.5, 2.5, 3.5]])
    assert_series_equal(expected, s.loc[[2, 3, 4]])
    assert_series_equal(expected, s.loc[[1.5, 3, 4]])

    idx = IntervalIndex.from_breaks(np.arange(6), closed='left')
    s = Series(np.arange(5), idx)

    expected = s.iloc[:3]
    assert_series_equal(expected, s.loc[:2])

    expected = s.iloc[1:4]
    assert_series_equal(expected, s.loc[[1.5, 2, 3]])

    expected = 0
    self.assertEqual(expected, s.loc[s.loc[0.5]])
    # NOTE(review): with closed='left', 1 appears to fall in [1, 2), whose
    # value is 1, so this nested lookup may give 1 rather than 0 -- confirm.
    self.assertEqual(expected, s.loc[s.loc[1]])
    # NOTE(review): Interval defaults to closed='right' while idx uses
    # closed='left'; the key may need closed='left' to match -- confirm.
    self.assertEqual(expected, s.loc[Interval(0, 1)])
3723+
36953724
def test_scalar_indexer(self):
36963725
# float indexing checked above
36973726

0 commit comments

Comments
 (0)