|
| 1 | +import numpy as np |
| 2 | + |
| 3 | +from pandas.core.base import PandasObject, IndexOpsMixin |
| 4 | +from pandas.core.common import _values_from_object |
| 5 | +from pandas.core.index import Index |
| 6 | +from pandas.util.decorators import cache_readonly |
| 7 | + |
| 8 | + |
| 9 | +_VALID_CLOSED = set(['left', 'right', 'both', 'neither']) |
| 10 | + |
| 11 | + |
| 12 | +class IntervalMixin(object): |
| 13 | + @property |
| 14 | + def left(self): |
| 15 | + return self._left |
| 16 | + |
| 17 | + @property |
| 18 | + def right(self): |
| 19 | + return self._right |
| 20 | + |
| 21 | + @property |
| 22 | + def closed(self): |
| 23 | + return self._closed |
| 24 | + |
| 25 | + @cache_readonly |
| 26 | + def closed_left(self): |
| 27 | + return self.closed == 'left' or self.closed == 'both' |
| 28 | + |
| 29 | + @cache_readonly |
| 30 | + def closed_right(self): |
| 31 | + return self.closed == 'right' or self.closed == 'both' |
| 32 | + |
| 33 | + @property |
| 34 | + def open_left(self): |
| 35 | + return not self.closed_left |
| 36 | + |
| 37 | + @property |
| 38 | + def open_right(self): |
| 39 | + return not self.closed_right |
| 40 | + |
| 41 | + @cache_readonly |
| 42 | + def mid(self): |
| 43 | + # TODO: figure out how to do add/sub as arithemtic even on Index |
| 44 | + # objects. Is there a work around while we have deprecated +/- as |
| 45 | + # union/difference? Possibly need to add `add` and `sub` methods. |
| 46 | + try: |
| 47 | + return 0.5 * (self.left + self.right) |
| 48 | + except TypeError: |
| 49 | + # datetime safe version |
| 50 | + return self.left + 0.5 * (self.right - self.left) |
| 51 | + |
| 52 | + def _validate(self): |
| 53 | + # TODO: exclude periods? |
| 54 | + if self.closed not in _VALID_CLOSED: |
| 55 | + raise ValueError("invalid options for 'closed': %s" % self.closed) |
| 56 | + |
| 57 | + |
| 58 | +# TODO: cythonize this whole class? |
| 59 | +class Interval(PandasObject, IntervalMixin): |
| 60 | + def __init__(self, left, right, closed='right'): |
| 61 | + """Object representing an interval |
| 62 | + """ |
| 63 | + self._left = left |
| 64 | + self._right = right |
| 65 | + self._closed = closed |
| 66 | + self._validate() |
| 67 | + |
| 68 | + def __hash__(self): |
| 69 | + return hash((self.left, self.right, self.closed)) |
| 70 | + |
| 71 | + def __eq__(self, other): |
| 72 | + try: |
| 73 | + return (self.left == other.left |
| 74 | + and self.right == other.right |
| 75 | + and self.closed == other.closed) |
| 76 | + except AttributeError: |
| 77 | + return False |
| 78 | + |
| 79 | + def __ne__(self, other): |
| 80 | + return not self == other |
| 81 | + |
| 82 | + def __lt__(self, other): |
| 83 | + other_left = getattr(other, 'left', other) |
| 84 | + if self.open_right or getattr(other, 'open_left', False): |
| 85 | + return self.right <= other_left |
| 86 | + return self.right < other_left |
| 87 | + |
| 88 | + def __le__(self, other): |
| 89 | + return NotImplementedError |
| 90 | + |
| 91 | + def __gt__(self, other): |
| 92 | + return NotImplementedError |
| 93 | + |
| 94 | + def __ge__(self, other): |
| 95 | + return NotImplementedError |
| 96 | + |
| 97 | + # TODO: finish comparisons |
| 98 | + # TODO: add arithmetic operations |
| 99 | + |
| 100 | + def __str__(self): |
| 101 | + start_symbol = '[' if self.closed_left else '(' |
| 102 | + end_symbol = ']' if self.closed_right else ')' |
| 103 | + return '%s%s, %s%s' % (start_symbol, self.left, self.right, end_symbol) |
| 104 | + |
| 105 | + def __repr__(self): |
| 106 | + return ('%s(%r, %r, closed=%r)' % |
| 107 | + (type(self).__name__, self.left, |
| 108 | + self.right, self.closed)) |
| 109 | + |
| 110 | + |
| 111 | + |
| 112 | +class IntervalIndex(Index, IntervalMixin): |
| 113 | + def __new__(cls, left, right, closed='right', name=None): |
| 114 | + # TODO: validation |
| 115 | + result = object.__new__(cls) |
| 116 | + result._left = Index(left) |
| 117 | + result._right = Index(right) |
| 118 | + result._closed = closed |
| 119 | + result.name = name |
| 120 | + result._validate() |
| 121 | + result._reset_identity() |
| 122 | + return result |
| 123 | + |
| 124 | + def _simple_new(cls, values, name=None, **kwargs): |
| 125 | + # ensure we don't end up here (this is a superclass method) |
| 126 | + raise NotImplementedError |
| 127 | + |
| 128 | + @property |
| 129 | + def _constructor(self): |
| 130 | + return type(self).from_intervals |
| 131 | + |
| 132 | + @classmethod |
| 133 | + def from_breaks(cls, breaks, closed='right', name=None): |
| 134 | + return cls(breaks[:-1], breaks[1:], closed, name) |
| 135 | + |
| 136 | + @classmethod |
| 137 | + def from_intervals(cls, data, name=None): |
| 138 | + # TODO: cythonize (including validation for closed) |
| 139 | + left = [i.left for i in data] |
| 140 | + right = [i.right for i in data] |
| 141 | + closed = data[0].closed |
| 142 | + return cls(left, right, closed, name) |
| 143 | + |
| 144 | + @cache_readonly |
| 145 | + def _data(self): |
| 146 | + # TODO: cythonize |
| 147 | + zipped = zip(self.left, self.right) |
| 148 | + items = [Interval(l, r, self.closed) for l, r in zipped] |
| 149 | + return np.array(items, dtype=object) |
| 150 | + |
| 151 | + @cache_readonly |
| 152 | + def dtype(self): |
| 153 | + return np.dtype('O') |
| 154 | + |
| 155 | + def get_loc(self, key): |
| 156 | + if isinstance(key, Interval): |
| 157 | + # TODO: fall back to something like slice_locs if key not found |
| 158 | + return self._engine.get_loc(_values_from_object(key)) |
| 159 | + else: |
| 160 | + # TODO: handle decreasing monotonic intervals |
| 161 | + if not self.left.is_monotonic and self.right.is_monotonic: |
| 162 | + raise KeyError("cannot lookup values on a non-monotonic " |
| 163 | + "IntervalIndex") |
| 164 | + |
| 165 | + side_start = 'left' if self.closed_right else 'right' |
| 166 | + start = self.right.searchsorted(key, side=side_start) |
| 167 | + |
| 168 | + side_end = 'right' if self.closed_left else 'left' |
| 169 | + end = self.left.searchsorted(key, side=side_end) |
| 170 | + |
| 171 | + if start == end: |
| 172 | + raise KeyError(key) |
| 173 | + |
| 174 | + if start + 1 == end: |
| 175 | + return start |
| 176 | + else: |
| 177 | + return slice(start, end) |
| 178 | + |
| 179 | + def get_indexer(self, key): |
| 180 | + # should reuse the core of get_loc |
| 181 | + # if the key consists of intervals, needs unique values to give |
| 182 | + # sensible results (like DatetimeIndex) |
| 183 | + # if the key consists of scalars, the index's intervals must also be |
| 184 | + # non-overlapping |
| 185 | + raise NotImplementedError |
| 186 | + |
| 187 | + def slice_locs(self, start, end): |
| 188 | + # should be more efficient than directly calling the superclass method, |
| 189 | + # which calls get_loc (we don't need to do binary search twice for each |
| 190 | + # key) |
| 191 | + raise NotImplementedError |
| 192 | + |
| 193 | + def __contains__(self, key): |
| 194 | + try: |
| 195 | + self.get_loc(key) |
| 196 | + return True |
| 197 | + except KeyError: |
| 198 | + return False |
| 199 | + |
| 200 | + def __getitem__(self, value): |
| 201 | + left = self.left[value] |
| 202 | + right = self.right[value] |
| 203 | + if not isinstance(left, Index): |
| 204 | + return Interval(left, right, self.closed) |
| 205 | + else: |
| 206 | + return type(self)(left, right, self.closed) |
| 207 | + |
| 208 | + def __repr__(self): |
| 209 | + lines = [repr(type(self))] |
| 210 | + lines.extend(str(interval) for interval in self) |
| 211 | + lines.append('Length: %s, Closed: %r' % |
| 212 | + (len(self), self.closed)) |
| 213 | + return '\n'.join(lines) |
| 214 | + |
| 215 | + def equals(self, other): |
| 216 | + if self.is_(other): |
| 217 | + return True |
| 218 | + try: |
| 219 | + return (self.left.equals(other.left) |
| 220 | + and self.right.equals(other.right) |
| 221 | + and self.closed == other.closed) |
| 222 | + except AttributeError: |
| 223 | + return False |
| 224 | + |
| 225 | + # TODO: add comparisons and arithmetic operations |
0 commit comments