Skip to content

Commit 787ab55

Browse files
jschendel authored and jreback committed
ENH: Add IntervalDtype support to IntervalIndex.astype (pandas-dev#19231)
1 parent 53be520 commit 787ab55

File tree

6 files changed

+230
-24
lines changed

6 files changed

+230
-24
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line number | Diff line number | Diff line change
@@ -202,6 +202,7 @@ Other Enhancements
202202
- ``Resampler`` objects now have a functioning :attr:`~pandas.core.resample.Resampler.pipe` method.
203203
Previously, calls to ``pipe`` were diverted to the ``mean`` method (:issue:`17905`).
204204
- :func:`~pandas.api.types.is_scalar` now returns ``True`` for ``DateOffset`` objects (:issue:`18943`).
205+
- ``IntervalIndex.astype`` now supports conversions between subtypes when passed an ``IntervalDtype`` (:issue:`19197`)
205206

206207
.. _whatsnew_0230.api_breaking:
207208

pandas/core/dtypes/dtypes.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -710,7 +710,8 @@ def __eq__(self, other):
710710
# None should match any subtype
711711
return True
712712
else:
713-
return self.subtype == other.subtype
713+
from pandas.core.dtypes.common import is_dtype_equal
714+
return is_dtype_equal(self.subtype, other.subtype)
714715

715716
@classmethod
716717
def is_dtype(cls, dtype):

pandas/core/indexes/interval.py

+12-3
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,8 @@
2020
is_scalar,
2121
is_float,
2222
is_number,
23-
is_integer)
23+
is_integer,
24+
pandas_dtype)
2425
from pandas.core.indexes.base import (
2526
Index, _ensure_index,
2627
default_pprint, _index_shared_docs)
@@ -699,8 +700,16 @@ def copy(self, deep=False, name=None):
699700

700701
@Appender(_index_shared_docs['astype'])
701702
def astype(self, dtype, copy=True):
702-
if is_interval_dtype(dtype):
703-
return self.copy() if copy else self
703+
dtype = pandas_dtype(dtype)
704+
if is_interval_dtype(dtype) and dtype != self.dtype:
705+
try:
706+
new_left = self.left.astype(dtype.subtype)
707+
new_right = self.right.astype(dtype.subtype)
708+
except TypeError:
709+
msg = ('Cannot convert {dtype} to {new_dtype}; subtypes are '
710+
'incompatible')
711+
raise TypeError(msg.format(dtype=self.dtype, new_dtype=dtype))
712+
return self._shallow_copy(new_left, new_right)
704713
return super(IntervalIndex, self).astype(dtype, copy=copy)
705714

706715
@cache_readonly

pandas/tests/dtypes/test_dtypes.py

+6
Original file line number | Diff line number | Diff line change
@@ -534,6 +534,12 @@ def test_equality(self):
534534
assert not is_dtype_equal(IntervalDtype('int64'),
535535
IntervalDtype('float64'))
536536

537+
# invalid subtype comparisons do not raise when directly compared
538+
dtype1 = IntervalDtype('float64')
539+
dtype2 = IntervalDtype('datetime64[ns, US/Eastern]')
540+
assert dtype1 != dtype2
541+
assert dtype2 != dtype1
542+
537543
@pytest.mark.parametrize('subtype', [
538544
None, 'interval', 'Interval', 'int64', 'uint64', 'float64',
539545
'complex128', 'datetime64', 'timedelta64', PeriodDtype('Q')])

pandas/tests/indexes/interval/test_astype.py

+209
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,209 @@
1+
from __future__ import division
2+
3+
import pytest
4+
import numpy as np
5+
from pandas import (
6+
Index,
7+
IntervalIndex,
8+
interval_range,
9+
CategoricalIndex,
10+
Timestamp,
11+
Timedelta,
12+
NaT)
13+
from pandas.core.dtypes.dtypes import CategoricalDtype, IntervalDtype
14+
import pandas.util.testing as tm
15+
16+
17+
class Base(object):
18+
"""Tests common to IntervalIndex with any subtype"""
19+
20+
def test_astype_idempotent(self, index):
21+
result = index.astype('interval')
22+
tm.assert_index_equal(result, index)
23+
24+
result = index.astype(index.dtype)
25+
tm.assert_index_equal(result, index)
26+
27+
def test_astype_object(self, index):
28+
result = index.astype(object)
29+
expected = Index(index.values, dtype='object')
30+
tm.assert_index_equal(result, expected)
31+
assert not result.equals(index)
32+
33+
def test_astype_category(self, index):
34+
result = index.astype('category')
35+
expected = CategoricalIndex(index.values)
36+
tm.assert_index_equal(result, expected)
37+
38+
result = index.astype(CategoricalDtype())
39+
tm.assert_index_equal(result, expected)
40+
41+
# non-default params
42+
categories = index.dropna().unique().values[:-1]
43+
dtype = CategoricalDtype(categories=categories, ordered=True)
44+
result = index.astype(dtype)
45+
expected = CategoricalIndex(
46+
index.values, categories=categories, ordered=True)
47+
tm.assert_index_equal(result, expected)
48+
49+
@pytest.mark.parametrize('dtype', [
50+
'int64', 'uint64', 'float64', 'complex128', 'period[M]',
51+
'timedelta64', 'timedelta64[ns]', 'datetime64', 'datetime64[ns]',
52+
'datetime64[ns, US/Eastern]'])
53+
def test_astype_cannot_cast(self, index, dtype):
54+
msg = 'Cannot cast IntervalIndex to dtype'
55+
with tm.assert_raises_regex(TypeError, msg):
56+
index.astype(dtype)
57+
58+
def test_astype_invalid_dtype(self, index):
59+
msg = 'data type "fake_dtype" not understood'
60+
with tm.assert_raises_regex(TypeError, msg):
61+
index.astype('fake_dtype')
62+
63+
64+
class TestIntSubtype(Base):
65+
"""Tests specific to IntervalIndex with integer-like subtype"""
66+
67+
indexes = [
68+
IntervalIndex.from_breaks(np.arange(-10, 11, dtype='int64')),
69+
IntervalIndex.from_breaks(
70+
np.arange(100, dtype='uint64'), closed='left'),
71+
]
72+
73+
@pytest.fixture(params=indexes)
74+
def index(self, request):
75+
return request.param
76+
77+
@pytest.mark.parametrize('subtype', [
78+
'float64', 'datetime64[ns]', 'timedelta64[ns]'])
79+
def test_subtype_conversion(self, index, subtype):
80+
dtype = IntervalDtype(subtype)
81+
result = index.astype(dtype)
82+
expected = IntervalIndex.from_arrays(index.left.astype(subtype),
83+
index.right.astype(subtype),
84+
closed=index.closed)
85+
tm.assert_index_equal(result, expected)
86+
87+
@pytest.mark.parametrize('subtype_start, subtype_end', [
88+
('int64', 'uint64'), ('uint64', 'int64')])
89+
def test_subtype_integer(self, subtype_start, subtype_end):
90+
index = IntervalIndex.from_breaks(np.arange(100, dtype=subtype_start))
91+
dtype = IntervalDtype(subtype_end)
92+
result = index.astype(dtype)
93+
expected = IntervalIndex.from_arrays(index.left.astype(subtype_end),
94+
index.right.astype(subtype_end),
95+
closed=index.closed)
96+
tm.assert_index_equal(result, expected)
97+
98+
@pytest.mark.xfail(reason='GH 15832')
99+
def test_subtype_integer_errors(self):
100+
# int64 -> uint64 fails with negative values
101+
index = interval_range(-10, 10)
102+
dtype = IntervalDtype('uint64')
103+
with pytest.raises(ValueError):
104+
index.astype(dtype)
105+
106+
107+
class TestFloatSubtype(Base):
108+
"""Tests specific to IntervalIndex with float subtype"""
109+
110+
indexes = [
111+
interval_range(-10.0, 10.0, closed='neither'),
112+
IntervalIndex.from_arrays([-1.5, np.nan, 0., 0., 1.5],
113+
[-0.5, np.nan, 1., 1., 3.],
114+
closed='both'),
115+
]
116+
117+
@pytest.fixture(params=indexes)
118+
def index(self, request):
119+
return request.param
120+
121+
@pytest.mark.parametrize('subtype', ['int64', 'uint64'])
122+
def test_subtype_integer(self, subtype):
123+
index = interval_range(0.0, 10.0)
124+
dtype = IntervalDtype(subtype)
125+
result = index.astype(dtype)
126+
expected = IntervalIndex.from_arrays(index.left.astype(subtype),
127+
index.right.astype(subtype),
128+
closed=index.closed)
129+
tm.assert_index_equal(result, expected)
130+
131+
# raises with NA
132+
msg = 'Cannot convert NA to integer'
133+
with tm.assert_raises_regex(ValueError, msg):
134+
index.insert(0, np.nan).astype(dtype)
135+
136+
@pytest.mark.xfail(reason='GH 15832')
137+
def test_subtype_integer_errors(self):
138+
# float64 -> uint64 fails with negative values
139+
index = interval_range(-10.0, 10.0)
140+
dtype = IntervalDtype('uint64')
141+
with pytest.raises(ValueError):
142+
index.astype(dtype)
143+
144+
# float64 -> integer-like fails with non-integer valued floats
145+
index = interval_range(0.0, 10.0, freq=0.25)
146+
dtype = IntervalDtype('int64')
147+
with pytest.raises(ValueError):
148+
index.astype(dtype)
149+
150+
dtype = IntervalDtype('uint64')
151+
with pytest.raises(ValueError):
152+
index.astype(dtype)
153+
154+
@pytest.mark.parametrize('subtype', ['datetime64[ns]', 'timedelta64[ns]'])
155+
def test_subtype_datetimelike(self, index, subtype):
156+
dtype = IntervalDtype(subtype)
157+
msg = 'Cannot convert .* to .*; subtypes are incompatible'
158+
with tm.assert_raises_regex(TypeError, msg):
159+
index.astype(dtype)
160+
161+
162+
class TestDatetimelikeSubtype(Base):
163+
"""Tests specific to IntervalIndex with datetime-like subtype"""
164+
165+
indexes = [
166+
interval_range(Timestamp('2018-01-01'), periods=10, closed='neither'),
167+
interval_range(Timestamp('2018-01-01'), periods=10).insert(2, NaT),
168+
interval_range(Timestamp('2018-01-01', tz='US/Eastern'), periods=10),
169+
interval_range(Timedelta('0 days'), periods=10, closed='both'),
170+
interval_range(Timedelta('0 days'), periods=10).insert(2, NaT),
171+
]
172+
173+
@pytest.fixture(params=indexes)
174+
def index(self, request):
175+
return request.param
176+
177+
@pytest.mark.parametrize('subtype', ['int64', 'uint64'])
178+
def test_subtype_integer(self, index, subtype):
179+
dtype = IntervalDtype(subtype)
180+
result = index.astype(dtype)
181+
expected = IntervalIndex.from_arrays(index.left.astype(subtype),
182+
index.right.astype(subtype),
183+
closed=index.closed)
184+
tm.assert_index_equal(result, expected)
185+
186+
def test_subtype_float(self, index):
187+
dtype = IntervalDtype('float64')
188+
msg = 'Cannot convert .* to .*; subtypes are incompatible'
189+
with tm.assert_raises_regex(TypeError, msg):
190+
index.astype(dtype)
191+
192+
def test_subtype_datetimelike(self):
193+
# datetime -> timedelta raises
194+
dtype = IntervalDtype('timedelta64[ns]')
195+
msg = 'Cannot convert .* to .*; subtypes are incompatible'
196+
197+
index = interval_range(Timestamp('2018-01-01'), periods=10)
198+
with tm.assert_raises_regex(TypeError, msg):
199+
index.astype(dtype)
200+
201+
index = interval_range(Timestamp('2018-01-01', tz='CET'), periods=10)
202+
with tm.assert_raises_regex(TypeError, msg):
203+
index.astype(dtype)
204+
205+
# timedelta -> datetime raises
206+
dtype = IntervalDtype('datetime64[ns]')
207+
index = interval_range(Timedelta('0 days'), periods=10)
208+
with tm.assert_raises_regex(TypeError, msg):
209+
index.astype(dtype)

pandas/tests/indexes/interval/test_interval.py

-20
Original file line number | Diff line number | Diff line change
@@ -415,26 +415,6 @@ def test_equals(self, closed):
415415
np.arange(5), closed=other_closed)
416416
assert not expected.equals(expected_other_closed)
417417

418-
def test_astype(self, closed):
419-
idx = self.create_index(closed=closed)
420-
result = idx.astype(object)
421-
tm.assert_index_equal(result, Index(idx.values, dtype='object'))
422-
assert not idx.equals(result)
423-
assert idx.equals(IntervalIndex.from_intervals(result))
424-
425-
result = idx.astype('interval')
426-
tm.assert_index_equal(result, idx)
427-
assert result.equals(idx)
428-
429-
@pytest.mark.parametrize('dtype', [
430-
np.int64, np.float64, 'period[M]', 'timedelta64', 'datetime64[ns]',
431-
'datetime64[ns, US/Eastern]'])
432-
def test_astype_errors(self, closed, dtype):
433-
idx = self.create_index(closed=closed)
434-
msg = 'Cannot cast IntervalIndex to dtype'
435-
with tm.assert_raises_regex(TypeError, msg):
436-
idx.astype(dtype)
437-
438418
@pytest.mark.parametrize('klass', [list, tuple, np.array, pd.Series])
439419
def test_where(self, closed, klass):
440420
idx = self.create_index(closed=closed)

0 commit comments

Comments
 (0)