Skip to content

Commit 0229538

Browse files
jschendeljreback
authored andcommitted
ENH: Support TZ Aware IntervalIndex (#18558)
1 parent f7eb4ae commit 0229538

File tree

5 files changed

+141
-58
lines changed

5 files changed

+141
-58
lines changed

doc/source/whatsnew/v0.22.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ Other Enhancements
134134
- :func:`pandas.read_clipboard` updated to use qtpy, falling back to PyQt5 and then PyQt4, adding compatibility with Python3 and multiple python-qt bindings (:issue:`17722`)
135135
- Improved wording of ``ValueError`` raised in :func:`read_csv` when the ``usecols`` argument cannot match all columns. (:issue:`17301`)
136136
- :func:`DataFrame.corrwith` now silently drops non-numeric columns when passed a Series. Before, an exception was raised (:issue:`18570`).
137+
- :class:`IntervalIndex` now supports time zone aware ``Interval`` objects (:issue:`18537`, :issue:`18538`)
137138

138139

139140
.. _whatsnew_0220.api_breaking:

pandas/_libs/interval.pyx

+8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ cimport cython
66
import cython
77
from numpy cimport ndarray
88
from tslib import Timestamp
9+
from tslibs.timezones cimport get_timezone
910

1011
from cpython.object cimport (Py_EQ, Py_NE, Py_GT, Py_LT, Py_GE, Py_LE,
1112
PyObject_RichCompare)
@@ -119,6 +120,13 @@ cdef class Interval(IntervalMixin):
119120
raise ValueError(msg)
120121
if not left <= right:
121122
raise ValueError('left side of interval must be <= right side')
123+
if (isinstance(left, Timestamp) and
124+
get_timezone(left.tzinfo) != get_timezone(right.tzinfo)):
125+
# GH 18538
126+
msg = ("left and right must have the same time zone, got "
127+
"'{left_tz}' and '{right_tz}'")
128+
raise ValueError(msg.format(left_tz=left.tzinfo,
129+
right_tz=right.tzinfo))
122130
self.left = left
123131
self.right = right
124132
self.closed = closed

pandas/core/indexes/interval.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import numpy as np
44

55
from pandas.core.dtypes.missing import notna, isna
6-
from pandas.core.dtypes.generic import ABCPeriodIndex
6+
from pandas.core.dtypes.generic import ABCDatetimeIndex, ABCPeriodIndex
77
from pandas.core.dtypes.dtypes import IntervalDtype
88
from pandas.core.dtypes.cast import maybe_convert_platform
99
from pandas.core.dtypes.common import (
1010
_ensure_platform_int,
1111
is_list_like,
1212
is_datetime_or_timedelta_dtype,
13+
is_datetime64tz_dtype,
1314
is_integer_dtype,
1415
is_object_dtype,
1516
is_categorical_dtype,
@@ -54,7 +55,7 @@ def _get_next_label(label):
5455
dtype = getattr(label, 'dtype', type(label))
5556
if isinstance(label, (Timestamp, Timedelta)):
5657
dtype = 'datetime64'
57-
if is_datetime_or_timedelta_dtype(dtype):
58+
if is_datetime_or_timedelta_dtype(dtype) or is_datetime64tz_dtype(dtype):
5859
return label + np.timedelta64(1, 'ns')
5960
elif is_integer_dtype(dtype):
6061
return label + 1
@@ -69,7 +70,7 @@ def _get_prev_label(label):
6970
dtype = getattr(label, 'dtype', type(label))
7071
if isinstance(label, (Timestamp, Timedelta)):
7172
dtype = 'datetime64'
72-
if is_datetime_or_timedelta_dtype(dtype):
73+
if is_datetime_or_timedelta_dtype(dtype) or is_datetime64tz_dtype(dtype):
7374
return label - np.timedelta64(1, 'ns')
7475
elif is_integer_dtype(dtype):
7576
return label - 1
@@ -227,17 +228,22 @@ def _simple_new(cls, left, right, closed=None, name=None,
227228
# coerce dtypes to match if needed
228229
if is_float_dtype(left) and is_integer_dtype(right):
229230
right = right.astype(left.dtype)
230-
if is_float_dtype(right) and is_integer_dtype(left):
231+
elif is_float_dtype(right) and is_integer_dtype(left):
231232
left = left.astype(right.dtype)
232233

233234
if type(left) != type(right):
234-
raise ValueError("must not have differing left [{left}] "
235-
"and right [{right}] types"
236-
.format(left=type(left), right=type(right)))
237-
238-
if isinstance(left, ABCPeriodIndex):
239-
raise ValueError("Period dtypes are not supported, "
240-
"use a PeriodIndex instead")
235+
msg = ('must not have differing left [{ltype}] and right '
236+
'[{rtype}] types')
237+
raise ValueError(msg.format(ltype=type(left).__name__,
238+
rtype=type(right).__name__))
239+
elif isinstance(left, ABCPeriodIndex):
240+
msg = 'Period dtypes are not supported, use a PeriodIndex instead'
241+
raise ValueError(msg)
242+
elif (isinstance(left, ABCDatetimeIndex) and
243+
str(left.tz) != str(right.tz)):
244+
msg = ("left and right must have the same time zone, got "
245+
"'{left_tz}' and '{right_tz}'")
246+
raise ValueError(msg.format(left_tz=left.tz, right_tz=right.tz))
241247

242248
result._left = left
243249
result._right = right
@@ -657,8 +663,8 @@ def mid(self):
657663
return Index(0.5 * (self.left.values + self.right.values))
658664
except TypeError:
659665
# datetime safe version
660-
delta = self.right.values - self.left.values
661-
return Index(self.left.values + 0.5 * delta)
666+
delta = self.right - self.left
667+
return self.left + 0.5 * delta
662668

663669
@cache_readonly
664670
def is_monotonic(self):

pandas/tests/indexes/test_interval.py

+92-44
Original file line numberDiff line numberDiff line change
@@ -42,24 +42,37 @@ def create_index_with_nan(self, closed='right'):
4242
np.where(mask, np.arange(10), np.nan),
4343
np.where(mask, np.arange(1, 11), np.nan), closed=closed)
4444

45-
def test_constructors(self, closed, name):
46-
left, right = Index([0, 1, 2, 3]), Index([1, 2, 3, 4])
45+
@pytest.mark.parametrize('data', [
46+
Index([0, 1, 2, 3, 4]),
47+
Index(list('abcde')),
48+
date_range('2017-01-01', periods=5),
49+
date_range('2017-01-01', periods=5, tz='US/Eastern'),
50+
timedelta_range('1 day', periods=5)])
51+
def test_constructors(self, data, closed, name):
52+
left, right = data[:-1], data[1:]
4753
ivs = [Interval(l, r, closed=closed) for l, r in lzip(left, right)]
4854
expected = IntervalIndex._simple_new(
4955
left=left, right=right, closed=closed, name=name)
5056

57+
# validate expected
58+
assert expected.closed == closed
59+
assert expected.name == name
60+
assert expected.dtype.subtype == data.dtype
61+
tm.assert_index_equal(expected.left, data[:-1])
62+
tm.assert_index_equal(expected.right, data[1:])
63+
64+
# validated constructors
5165
result = IntervalIndex(ivs, name=name)
5266
tm.assert_index_equal(result, expected)
5367

5468
result = IntervalIndex.from_intervals(ivs, name=name)
5569
tm.assert_index_equal(result, expected)
5670

57-
result = IntervalIndex.from_breaks(
58-
np.arange(5), closed=closed, name=name)
71+
result = IntervalIndex.from_breaks(data, closed=closed, name=name)
5972
tm.assert_index_equal(result, expected)
6073

6174
result = IntervalIndex.from_arrays(
62-
left.values, right.values, closed=closed, name=name)
75+
left, right, closed=closed, name=name)
6376
tm.assert_index_equal(result, expected)
6477

6578
result = IntervalIndex.from_tuples(
@@ -186,6 +199,9 @@ def test_constructors_errors(self):
186199
IntervalIndex.from_intervals([Interval(0, 1),
187200
Interval(1, 2, closed='left')])
188201

202+
with tm.assert_raises_regex(ValueError, msg):
203+
IntervalIndex([Interval(0, 1), Interval(2, 3, closed='left')])
204+
189205
with tm.assert_raises_regex(ValueError, msg):
190206
Index([Interval(0, 1), Interval(2, 3, closed='left')])
191207

@@ -209,26 +225,24 @@ def test_constructors_errors(self):
209225
with tm.assert_raises_regex(ValueError, msg):
210226
IntervalIndex.from_arrays(range(10, -1, -1), range(9, -2, -1))
211227

212-
def test_constructors_datetimelike(self, closed):
228+
@pytest.mark.parametrize('tz_left, tz_right', [
229+
(None, 'UTC'), ('UTC', None), ('UTC', 'US/Eastern')])
230+
def test_constructors_errors_tz(self, tz_left, tz_right):
231+
# GH 18537
232+
left = date_range('2017-01-01', periods=4, tz=tz_left)
233+
right = date_range('2017-01-02', periods=4, tz=tz_right)
213234

214-
# DTI / TDI
215-
for idx in [pd.date_range('20130101', periods=5),
216-
pd.timedelta_range('1 day', periods=5)]:
217-
result = IntervalIndex.from_breaks(idx, closed=closed)
218-
expected = IntervalIndex.from_breaks(idx.values, closed=closed)
219-
tm.assert_index_equal(result, expected)
220-
221-
expected_scalar_type = type(idx[0])
222-
i = result[0]
223-
assert isinstance(i.left, expected_scalar_type)
224-
assert isinstance(i.right, expected_scalar_type)
235+
# don't need to check IntervalIndex(...) or from_intervals, since
236+
# mixed tz are disallowed at the Interval level
237+
with pytest.raises(ValueError):
238+
IntervalIndex.from_arrays(left, right)
225239

226-
def test_constructors_error(self):
240+
with pytest.raises(ValueError):
241+
IntervalIndex.from_tuples(lzip(left, right))
227242

228-
# non-intervals
229-
def f():
230-
IntervalIndex.from_intervals([0.997, 4.0])
231-
pytest.raises(TypeError, f)
243+
with pytest.raises(ValueError):
244+
breaks = left.tolist() + [right[-1]]
245+
IntervalIndex.from_breaks(breaks)
232246

233247
def test_properties(self, closed):
234248
index = self.create_index(closed=closed)
@@ -964,23 +978,46 @@ def test_sort_values(self, closed):
964978
expected = IntervalIndex([np.nan, Interval(1, 2), Interval(0, 1)])
965979
tm.assert_index_equal(result, expected)
966980

967-
def test_datetime(self):
968-
dates = date_range('2000', periods=3)
969-
idx = IntervalIndex.from_breaks(dates)
970-
971-
tm.assert_index_equal(idx.left, dates[:2])
972-
tm.assert_index_equal(idx.right, dates[-2:])
973-
974-
expected = date_range('2000-01-01T12:00', periods=2)
975-
tm.assert_index_equal(idx.mid, expected)
976-
977-
assert Timestamp('2000-01-01T12') not in idx
978-
assert Timestamp('2000-01-01T12') not in idx
979-
980-
target = date_range('1999-12-31T12:00', periods=7, freq='12H')
981-
actual = idx.get_indexer(target)
981+
@pytest.mark.parametrize('tz', [None, 'US/Eastern'])
982+
def test_datetime(self, tz):
983+
start = Timestamp('2000-01-01', tz=tz)
984+
dates = date_range(start=start, periods=10)
985+
index = IntervalIndex.from_breaks(dates)
986+
987+
# test mid
988+
start = Timestamp('2000-01-01T12:00', tz=tz)
989+
expected = date_range(start=start, periods=9)
990+
tm.assert_index_equal(index.mid, expected)
991+
992+
# __contains__ doesn't check individual points
993+
assert Timestamp('2000-01-01', tz=tz) not in index
994+
assert Timestamp('2000-01-01T12', tz=tz) not in index
995+
assert Timestamp('2000-01-02', tz=tz) not in index
996+
iv_true = Interval(Timestamp('2000-01-01T08', tz=tz),
997+
Timestamp('2000-01-01T18', tz=tz))
998+
iv_false = Interval(Timestamp('1999-12-31', tz=tz),
999+
Timestamp('2000-01-01', tz=tz))
1000+
assert iv_true in index
1001+
assert iv_false not in index
1002+
1003+
# .contains does check individual points
1004+
assert not index.contains(Timestamp('2000-01-01', tz=tz))
1005+
assert index.contains(Timestamp('2000-01-01T12', tz=tz))
1006+
assert index.contains(Timestamp('2000-01-02', tz=tz))
1007+
assert index.contains(iv_true)
1008+
assert not index.contains(iv_false)
1009+
1010+
# test get_indexer
1011+
start = Timestamp('1999-12-31T12:00', tz=tz)
1012+
target = date_range(start=start, periods=7, freq='12H')
1013+
actual = index.get_indexer(target)
1014+
expected = np.array([-1, -1, 0, 0, 1, 1, 2], dtype='intp')
1015+
tm.assert_numpy_array_equal(actual, expected)
9821016

983-
expected = np.array([-1, -1, 0, 0, 1, 1, -1], dtype='intp')
1017+
start = Timestamp('2000-01-08T18:00', tz=tz)
1018+
target = date_range(start=start, periods=7, freq='6H')
1019+
actual = index.get_indexer(target)
1020+
expected = np.array([7, 7, 8, 8, 8, 8, -1], dtype='intp')
9841021
tm.assert_numpy_array_equal(actual, expected)
9851022

9861023
def test_append(self, closed):
@@ -1079,9 +1116,11 @@ def test_construction_from_numeric(self, closed, name):
10791116
closed=closed)
10801117
tm.assert_index_equal(result, expected)
10811118

1082-
def test_construction_from_timestamp(self, closed, name):
1119+
@pytest.mark.parametrize('tz', [None, 'US/Eastern'])
1120+
def test_construction_from_timestamp(self, closed, name, tz):
10831121
# combinations of start/end/periods without freq
1084-
start, end = Timestamp('2017-01-01'), Timestamp('2017-01-06')
1122+
start = Timestamp('2017-01-01', tz=tz)
1123+
end = Timestamp('2017-01-06', tz=tz)
10851124
breaks = date_range(start=start, end=end)
10861125
expected = IntervalIndex.from_breaks(breaks, name=name, closed=closed)
10871126

@@ -1099,7 +1138,8 @@ def test_construction_from_timestamp(self, closed, name):
10991138

11001139
# combinations of start/end/periods with fixed freq
11011140
freq = '2D'
1102-
start, end = Timestamp('2017-01-01'), Timestamp('2017-01-07')
1141+
start = Timestamp('2017-01-01', tz=tz)
1142+
end = Timestamp('2017-01-07', tz=tz)
11031143
breaks = date_range(start=start, end=end, freq=freq)
11041144
expected = IntervalIndex.from_breaks(breaks, name=name, closed=closed)
11051145

@@ -1116,14 +1156,15 @@ def test_construction_from_timestamp(self, closed, name):
11161156
tm.assert_index_equal(result, expected)
11171157

11181158
# output truncates early if freq causes end to be skipped.
1119-
end = Timestamp('2017-01-08')
1159+
end = Timestamp('2017-01-08', tz=tz)
11201160
result = interval_range(start=start, end=end, freq=freq, name=name,
11211161
closed=closed)
11221162
tm.assert_index_equal(result, expected)
11231163

11241164
# combinations of start/end/periods with non-fixed freq
11251165
freq = 'M'
1126-
start, end = Timestamp('2017-01-01'), Timestamp('2017-12-31')
1166+
start = Timestamp('2017-01-01', tz=tz)
1167+
end = Timestamp('2017-12-31', tz=tz)
11271168
breaks = date_range(start=start, end=end, freq=freq)
11281169
expected = IntervalIndex.from_breaks(breaks, name=name, closed=closed)
11291170

@@ -1140,7 +1181,7 @@ def test_construction_from_timestamp(self, closed, name):
11401181
tm.assert_index_equal(result, expected)
11411182

11421183
# output truncates early if freq causes end to be skipped.
1143-
end = Timestamp('2018-01-15')
1184+
end = Timestamp('2018-01-15', tz=tz)
11441185
result = interval_range(start=start, end=end, freq=freq, name=name,
11451186
closed=closed)
11461187
tm.assert_index_equal(result, expected)
@@ -1308,6 +1349,13 @@ def test_errors(self):
13081349
with tm.assert_raises_regex(ValueError, msg):
13091350
interval_range(end=Timedelta('1 day'), periods=10, freq='foo')
13101351

1352+
# mixed tz
1353+
start = Timestamp('2017-01-01', tz='US/Eastern')
1354+
end = Timestamp('2017-01-07', tz='US/Pacific')
1355+
msg = 'Start and end cannot both be tz-aware with different timezones'
1356+
with tm.assert_raises_regex(TypeError, msg):
1357+
interval_range(start=start, end=end)
1358+
13111359

13121360
class TestIntervalTree(object):
13131361
def setup_method(self, method):

pandas/tests/scalar/test_interval.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import division
22

3-
from pandas import Interval
3+
from pandas import Interval, Timestamp
4+
from pandas.core.common import _any_none
45

56
import pytest
67
import pandas.util.testing as tm
@@ -137,3 +138,22 @@ def test_math_div(self, interval):
137138

138139
with tm.assert_raises_regex(TypeError, msg):
139140
interval / 'foo'
141+
142+
def test_constructor_errors(self):
143+
msg = "invalid option for 'closed': foo"
144+
with tm.assert_raises_regex(ValueError, msg):
145+
Interval(0, 1, closed='foo')
146+
147+
msg = 'left side of interval must be <= right side'
148+
with tm.assert_raises_regex(ValueError, msg):
149+
Interval(1, 0)
150+
151+
@pytest.mark.parametrize('tz_left, tz_right', [
152+
(None, 'UTC'), ('UTC', None), ('UTC', 'US/Eastern')])
153+
def test_constructor_errors_tz(self, tz_left, tz_right):
154+
# GH 18538
155+
left = Timestamp('2017-01-01', tz=tz_left)
156+
right = Timestamp('2017-01-02', tz=tz_right)
157+
error = TypeError if _any_none(tz_left, tz_right) else ValueError
158+
with pytest.raises(error):
159+
Interval(left, right)

0 commit comments

Comments
 (0)