Skip to content

Commit c65c124

Browse files
jschendelTomAugspurger
authored andcommitted
BUG: Allow IntervalIndex to be constructed from categorical data with appropriate dtype (pandas-dev#21254)
(cherry picked from commit 686f604)
1 parent 8350429 commit c65c124

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

pandas/core/indexes/interval.py

+4
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ def maybe_convert_platform_interval(values):
112112
-------
113113
array
114114
"""
115+
if is_categorical_dtype(values):
116+
# GH 21243/21253
117+
values = np.array(values)
118+
115119
if isinstance(values, (list, tuple)) and len(values) == 0:
116120
# GH 19016
117121
# empty lists/tuples get object dtype by default, but this is not

pandas/tests/indexes/interval/test_construction.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66

77
from pandas import (
88
Interval, IntervalIndex, Index, Int64Index, Float64Index, Categorical,
9-
date_range, timedelta_range, period_range, notna)
9+
CategoricalIndex, date_range, timedelta_range, period_range, notna)
1010
from pandas.compat import lzip
11+
from pandas.core.dtypes.common import is_categorical_dtype
1112
from pandas.core.dtypes.dtypes import IntervalDtype
1213
import pandas.core.common as com
1314
import pandas.util.testing as tm
@@ -111,6 +112,22 @@ def test_constructor_string(self, constructor, breaks):
111112
with tm.assert_raises_regex(TypeError, msg):
112113
constructor(**self.get_kwargs_from_breaks(breaks))
113114

115+
@pytest.mark.parametrize('cat_constructor', [
116+
Categorical, CategoricalIndex])
117+
def test_constructor_categorical_valid(self, constructor, cat_constructor):
118+
# GH 21243/21253
119+
if isinstance(constructor, partial) and constructor.func is Index:
120+
# Index is defined to create CategoricalIndex from categorical data
121+
pytest.skip()
122+
123+
breaks = np.arange(10, dtype='int64')
124+
expected = IntervalIndex.from_breaks(breaks)
125+
126+
cat_breaks = cat_constructor(breaks)
127+
result_kwargs = self.get_kwargs_from_breaks(cat_breaks)
128+
result = constructor(**result_kwargs)
129+
tm.assert_index_equal(result, expected)
130+
114131
def test_generic_errors(self, constructor):
115132
# filler input data to be used when supplying invalid kwargs
116133
filler = self.get_kwargs_from_breaks(range(10))
@@ -238,6 +255,8 @@ def get_kwargs_from_breaks(self, breaks, closed='right'):
238255
tuples = lzip(breaks[:-1], breaks[1:])
239256
if isinstance(breaks, (list, tuple)):
240257
return {'data': tuples}
258+
elif is_categorical_dtype(breaks):
259+
return {'data': breaks._constructor(tuples)}
241260
return {'data': com._asarray_tuplesafe(tuples)}
242261

243262
def test_constructor_errors(self):
@@ -286,6 +305,8 @@ def get_kwargs_from_breaks(self, breaks, closed='right'):
286305

287306
if isinstance(breaks, list):
288307
return {'data': ivs}
308+
elif is_categorical_dtype(breaks):
309+
return {'data': breaks._constructor(ivs)}
289310
return {'data': np.array(ivs, dtype=object)}
290311

291312
def test_generic_errors(self, constructor):

0 commit comments

Comments
 (0)