Skip to content

ENH: Support TZ Aware IntervalIndex #18558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 8, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.22.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ Other Enhancements
- :func:`pandas.read_clipboard` updated to use qtpy, falling back to PyQt5 and then PyQt4, adding compatibility with Python3 and multiple python-qt bindings (:issue:`17722`)
- Improved wording of ``ValueError`` raised in :func:`read_csv` when the ``usecols`` argument cannot match all columns. (:issue:`17301`)
- :func:`DataFrame.corrwith` now silently drops non-numeric columns when passed a Series. Before, an exception was raised (:issue:`18570`).
- :class:`IntervalIndex` now supports time zone aware ``Interval`` objects (:issue:`18537`, :issue:`18538`)


.. _whatsnew_0220.api_breaking:
Expand Down
8 changes: 8 additions & 0 deletions pandas/_libs/interval.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ cimport cython
import cython
from numpy cimport ndarray
from tslib import Timestamp
from tslibs.timezones cimport get_timezone

from cpython.object cimport (Py_EQ, Py_NE, Py_GT, Py_LT, Py_GE, Py_LE,
PyObject_RichCompare)
Expand Down Expand Up @@ -119,6 +120,13 @@ cdef class Interval(IntervalMixin):
raise ValueError(msg)
if not left <= right:
raise ValueError('left side of interval must be <= right side')
if (isinstance(left, Timestamp) and
get_timezone(left.tzinfo) != get_timezone(right.tzinfo)):
# GH 18538
msg = ("left and right must have the same time zone, got "
"'{left_tz}' and '{right_tz}'")
raise ValueError(msg.format(left_tz=left.tzinfo,
right_tz=right.tzinfo))
self.left = left
self.right = right
self.closed = closed
Expand Down
32 changes: 19 additions & 13 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import numpy as np

from pandas.core.dtypes.missing import notna, isna
from pandas.core.dtypes.generic import ABCPeriodIndex
from pandas.core.dtypes.generic import ABCDatetimeIndex, ABCPeriodIndex
from pandas.core.dtypes.dtypes import IntervalDtype
from pandas.core.dtypes.cast import maybe_convert_platform
from pandas.core.dtypes.common import (
_ensure_platform_int,
is_list_like,
is_datetime_or_timedelta_dtype,
is_datetime64tz_dtype,
is_integer_dtype,
is_object_dtype,
is_categorical_dtype,
Expand Down Expand Up @@ -54,7 +55,7 @@ def _get_next_label(label):
dtype = getattr(label, 'dtype', type(label))
if isinstance(label, (Timestamp, Timedelta)):
dtype = 'datetime64'
if is_datetime_or_timedelta_dtype(dtype):
if is_datetime_or_timedelta_dtype(dtype) or is_datetime64tz_dtype(dtype):
return label + np.timedelta64(1, 'ns')
elif is_integer_dtype(dtype):
return label + 1
Expand All @@ -69,7 +70,7 @@ def _get_prev_label(label):
dtype = getattr(label, 'dtype', type(label))
if isinstance(label, (Timestamp, Timedelta)):
dtype = 'datetime64'
if is_datetime_or_timedelta_dtype(dtype):
if is_datetime_or_timedelta_dtype(dtype) or is_datetime64tz_dtype(dtype):
return label - np.timedelta64(1, 'ns')
elif is_integer_dtype(dtype):
return label - 1
Expand Down Expand Up @@ -227,17 +228,22 @@ def _simple_new(cls, left, right, closed=None, name=None,
# coerce dtypes to match if needed
if is_float_dtype(left) and is_integer_dtype(right):
right = right.astype(left.dtype)
if is_float_dtype(right) and is_integer_dtype(left):
elif is_float_dtype(right) and is_integer_dtype(left):
left = left.astype(right.dtype)

if type(left) != type(right):
raise ValueError("must not have differing left [{left}] "
"and right [{right}] types"
.format(left=type(left), right=type(right)))

if isinstance(left, ABCPeriodIndex):
raise ValueError("Period dtypes are not supported, "
"use a PeriodIndex instead")
msg = ('must not have differing left [{ltype}] and right '
'[{rtype}] types')
raise ValueError(msg.format(ltype=type(left).__name__,
rtype=type(right).__name__))
elif isinstance(left, ABCPeriodIndex):
msg = 'Period dtypes are not supported, use a PeriodIndex instead'
raise ValueError(msg)
elif (isinstance(left, ABCDatetimeIndex) and
str(left.tz) != str(right.tz)):
msg = ("left and right must have the same time zone, got "
"'{left_tz}' and '{right_tz}'")
raise ValueError(msg.format(left_tz=left.tz, right_tz=right.tz))

result._left = left
result._right = right
Expand Down Expand Up @@ -657,8 +663,8 @@ def mid(self):
return Index(0.5 * (self.left.values + self.right.values))
except TypeError:
# datetime safe version
delta = self.right.values - self.left.values
return Index(self.left.values + 0.5 * delta)
delta = self.right - self.left
return self.left + 0.5 * delta

@cache_readonly
def is_monotonic(self):
Expand Down
136 changes: 92 additions & 44 deletions pandas/tests/indexes/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,37 @@ def create_index_with_nan(self, closed='right'):
np.where(mask, np.arange(10), np.nan),
np.where(mask, np.arange(1, 11), np.nan), closed=closed)

def test_constructors(self, closed, name):
left, right = Index([0, 1, 2, 3]), Index([1, 2, 3, 4])
@pytest.mark.parametrize('data', [
Index([0, 1, 2, 3, 4]),
Index(list('abcde')),
date_range('2017-01-01', periods=5),
date_range('2017-01-01', periods=5, tz='US/Eastern'),
timedelta_range('1 day', periods=5)])
def test_constructors(self, data, closed, name):
left, right = data[:-1], data[1:]
ivs = [Interval(l, r, closed=closed) for l, r in lzip(left, right)]
expected = IntervalIndex._simple_new(
left=left, right=right, closed=closed, name=name)

# validate expected
assert expected.closed == closed
assert expected.name == name
assert expected.dtype.subtype == data.dtype
tm.assert_index_equal(expected.left, data[:-1])
tm.assert_index_equal(expected.right, data[1:])

# validated constructors
result = IntervalIndex(ivs, name=name)
tm.assert_index_equal(result, expected)

result = IntervalIndex.from_intervals(ivs, name=name)
tm.assert_index_equal(result, expected)

result = IntervalIndex.from_breaks(
np.arange(5), closed=closed, name=name)
result = IntervalIndex.from_breaks(data, closed=closed, name=name)
tm.assert_index_equal(result, expected)

result = IntervalIndex.from_arrays(
left.values, right.values, closed=closed, name=name)
left, right, closed=closed, name=name)
tm.assert_index_equal(result, expected)

result = IntervalIndex.from_tuples(
Expand Down Expand Up @@ -186,6 +199,9 @@ def test_constructors_errors(self):
IntervalIndex.from_intervals([Interval(0, 1),
Interval(1, 2, closed='left')])

with tm.assert_raises_regex(ValueError, msg):
IntervalIndex([Interval(0, 1), Interval(2, 3, closed='left')])

with tm.assert_raises_regex(ValueError, msg):
Index([Interval(0, 1), Interval(2, 3, closed='left')])

Expand All @@ -209,26 +225,24 @@ def test_constructors_errors(self):
with tm.assert_raises_regex(ValueError, msg):
IntervalIndex.from_arrays(range(10, -1, -1), range(9, -2, -1))

def test_constructors_datetimelike(self, closed):
@pytest.mark.parametrize('tz_left, tz_right', [
(None, 'UTC'), ('UTC', None), ('UTC', 'US/Eastern')])
def test_constructors_errors_tz(self, tz_left, tz_right):
# GH 18537
left = date_range('2017-01-01', periods=4, tz=tz_left)
right = date_range('2017-01-02', periods=4, tz=tz_right)

# DTI / TDI
for idx in [pd.date_range('20130101', periods=5),
pd.timedelta_range('1 day', periods=5)]:
result = IntervalIndex.from_breaks(idx, closed=closed)
expected = IntervalIndex.from_breaks(idx.values, closed=closed)
tm.assert_index_equal(result, expected)

expected_scalar_type = type(idx[0])
i = result[0]
assert isinstance(i.left, expected_scalar_type)
assert isinstance(i.right, expected_scalar_type)
# don't need to check IntervalIndex(...) or from_intervals, since
# mixed tz are disallowed at the Interval level
with pytest.raises(ValueError):
IntervalIndex.from_arrays(left, right)

def test_constructors_error(self):
with pytest.raises(ValueError):
IntervalIndex.from_tuples(lzip(left, right))

# non-intervals
def f():
IntervalIndex.from_intervals([0.997, 4.0])
pytest.raises(TypeError, f)
with pytest.raises(ValueError):
breaks = left.tolist() + [right[-1]]
IntervalIndex.from_breaks(breaks)

def test_properties(self, closed):
index = self.create_index(closed=closed)
Expand Down Expand Up @@ -964,23 +978,46 @@ def test_sort_values(self, closed):
expected = IntervalIndex([np.nan, Interval(1, 2), Interval(0, 1)])
tm.assert_index_equal(result, expected)

def test_datetime(self):
dates = date_range('2000', periods=3)
idx = IntervalIndex.from_breaks(dates)

tm.assert_index_equal(idx.left, dates[:2])
tm.assert_index_equal(idx.right, dates[-2:])

expected = date_range('2000-01-01T12:00', periods=2)
tm.assert_index_equal(idx.mid, expected)

assert Timestamp('2000-01-01T12') not in idx
assert Timestamp('2000-01-01T12') not in idx

target = date_range('1999-12-31T12:00', periods=7, freq='12H')
actual = idx.get_indexer(target)
@pytest.mark.parametrize('tz', [None, 'US/Eastern'])
def test_datetime(self, tz):
start = Timestamp('2000-01-01', tz=tz)
dates = date_range(start=start, periods=10)
index = IntervalIndex.from_breaks(dates)

# test mid
start = Timestamp('2000-01-01T12:00', tz=tz)
expected = date_range(start=start, periods=9)
tm.assert_index_equal(index.mid, expected)

# __contains__ doesn't check individual points
assert Timestamp('2000-01-01', tz=tz) not in index
assert Timestamp('2000-01-01T12', tz=tz) not in index
assert Timestamp('2000-01-02', tz=tz) not in index
iv_true = Interval(Timestamp('2000-01-01T08', tz=tz),
Timestamp('2000-01-01T18', tz=tz))
iv_false = Interval(Timestamp('1999-12-31', tz=tz),
Timestamp('2000-01-01', tz=tz))
assert iv_true in index
assert iv_false not in index

# .contains does check individual points
assert not index.contains(Timestamp('2000-01-01', tz=tz))
assert index.contains(Timestamp('2000-01-01T12', tz=tz))
assert index.contains(Timestamp('2000-01-02', tz=tz))
assert index.contains(iv_true)
assert not index.contains(iv_false)

# test get_indexer
start = Timestamp('1999-12-31T12:00', tz=tz)
target = date_range(start=start, periods=7, freq='12H')
actual = index.get_indexer(target)
expected = np.array([-1, -1, 0, 0, 1, 1, 2], dtype='intp')
tm.assert_numpy_array_equal(actual, expected)

expected = np.array([-1, -1, 0, 0, 1, 1, -1], dtype='intp')
start = Timestamp('2000-01-08T18:00', tz=tz)
target = date_range(start=start, periods=7, freq='6H')
actual = index.get_indexer(target)
expected = np.array([7, 7, 8, 8, 8, 8, -1], dtype='intp')
tm.assert_numpy_array_equal(actual, expected)

def test_append(self, closed):
Expand Down Expand Up @@ -1079,9 +1116,11 @@ def test_construction_from_numeric(self, closed, name):
closed=closed)
tm.assert_index_equal(result, expected)

def test_construction_from_timestamp(self, closed, name):
@pytest.mark.parametrize('tz', [None, 'US/Eastern'])
def test_construction_from_timestamp(self, closed, name, tz):
# combinations of start/end/periods without freq
start, end = Timestamp('2017-01-01'), Timestamp('2017-01-06')
start = Timestamp('2017-01-01', tz=tz)
end = Timestamp('2017-01-06', tz=tz)
breaks = date_range(start=start, end=end)
expected = IntervalIndex.from_breaks(breaks, name=name, closed=closed)

Expand All @@ -1099,7 +1138,8 @@ def test_construction_from_timestamp(self, closed, name):

# combinations of start/end/periods with fixed freq
freq = '2D'
start, end = Timestamp('2017-01-01'), Timestamp('2017-01-07')
start = Timestamp('2017-01-01', tz=tz)
end = Timestamp('2017-01-07', tz=tz)
breaks = date_range(start=start, end=end, freq=freq)
expected = IntervalIndex.from_breaks(breaks, name=name, closed=closed)

Expand All @@ -1116,14 +1156,15 @@ def test_construction_from_timestamp(self, closed, name):
tm.assert_index_equal(result, expected)

# output truncates early if freq causes end to be skipped.
end = Timestamp('2017-01-08')
end = Timestamp('2017-01-08', tz=tz)
result = interval_range(start=start, end=end, freq=freq, name=name,
closed=closed)
tm.assert_index_equal(result, expected)

# combinations of start/end/periods with non-fixed freq
freq = 'M'
start, end = Timestamp('2017-01-01'), Timestamp('2017-12-31')
start = Timestamp('2017-01-01', tz=tz)
end = Timestamp('2017-12-31', tz=tz)
breaks = date_range(start=start, end=end, freq=freq)
expected = IntervalIndex.from_breaks(breaks, name=name, closed=closed)

Expand All @@ -1140,7 +1181,7 @@ def test_construction_from_timestamp(self, closed, name):
tm.assert_index_equal(result, expected)

# output truncates early if freq causes end to be skipped.
end = Timestamp('2018-01-15')
end = Timestamp('2018-01-15', tz=tz)
result = interval_range(start=start, end=end, freq=freq, name=name,
closed=closed)
tm.assert_index_equal(result, expected)
Expand Down Expand Up @@ -1308,6 +1349,13 @@ def test_errors(self):
with tm.assert_raises_regex(ValueError, msg):
interval_range(end=Timedelta('1 day'), periods=10, freq='foo')

# mixed tz
start = Timestamp('2017-01-01', tz='US/Eastern')
end = Timestamp('2017-01-07', tz='US/Pacific')
msg = 'Start and end cannot both be tz-aware with different timezones'
with tm.assert_raises_regex(TypeError, msg):
interval_range(start=start, end=end)


class TestIntervalTree(object):
def setup_method(self, method):
Expand Down
22 changes: 21 additions & 1 deletion pandas/tests/scalar/test_interval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import division

from pandas import Interval
from pandas import Interval, Timestamp
from pandas.core.common import _any_none

import pytest
import pandas.util.testing as tm
Expand Down Expand Up @@ -137,3 +138,22 @@ def test_math_div(self, interval):

with tm.assert_raises_regex(TypeError, msg):
interval / 'foo'

def test_constructor_errors(self):
msg = "invalid option for 'closed': foo"
with tm.assert_raises_regex(ValueError, msg):
Interval(0, 1, closed='foo')

msg = 'left side of interval must be <= right side'
with tm.assert_raises_regex(ValueError, msg):
Interval(1, 0)

@pytest.mark.parametrize('tz_left, tz_right', [
(None, 'UTC'), ('UTC', None), ('UTC', 'US/Eastern')])
def test_constructor_errors_tz(self, tz_left, tz_right):
# GH 18538
left = Timestamp('2017-01-01', tz=tz_left)
right = Timestamp('2017-01-02', tz=tz_right)
error = TypeError if _any_none(tz_left, tz_right) else ValueError
with pytest.raises(error):
Interval(left, right)