|
6 | 6 |
|
7 | 7 | from pandas import (
|
8 | 8 | Interval, IntervalIndex, Index, Int64Index, Float64Index, Categorical,
|
9 |
| - date_range, timedelta_range, period_range, notna) |
| 9 | + CategoricalIndex, date_range, timedelta_range, period_range, notna) |
10 | 10 | from pandas.compat import lzip
|
| 11 | +from pandas.core.dtypes.common import is_categorical_dtype |
11 | 12 | from pandas.core.dtypes.dtypes import IntervalDtype
|
12 | 13 | import pandas.core.common as com
|
13 | 14 | import pandas.util.testing as tm
|
@@ -111,6 +112,22 @@ def test_constructor_string(self, constructor, breaks):
|
111 | 112 | with tm.assert_raises_regex(TypeError, msg):
|
112 | 113 | constructor(**self.get_kwargs_from_breaks(breaks))
|
113 | 114 |
|
| 115 | + @pytest.mark.parametrize('cat_constructor', [ |
| 116 | + Categorical, CategoricalIndex]) |
| 117 | + def test_constructor_categorical_valid(self, constructor, cat_constructor): |
| 118 | + # GH 21243/21253 |
| 119 | + if isinstance(constructor, partial) and constructor.func is Index: |
| 120 | + # Index is defined to create CategoricalIndex from categorical data |
| 121 | + pytest.skip() |
| 122 | + |
| 123 | + breaks = np.arange(10, dtype='int64') |
| 124 | + expected = IntervalIndex.from_breaks(breaks) |
| 125 | + |
| 126 | + cat_breaks = cat_constructor(breaks) |
| 127 | + result_kwargs = self.get_kwargs_from_breaks(cat_breaks) |
| 128 | + result = constructor(**result_kwargs) |
| 129 | + tm.assert_index_equal(result, expected) |
| 130 | + |
114 | 131 | def test_generic_errors(self, constructor):
|
115 | 132 | # filler input data to be used when supplying invalid kwargs
|
116 | 133 | filler = self.get_kwargs_from_breaks(range(10))
|
@@ -238,6 +255,8 @@ def get_kwargs_from_breaks(self, breaks, closed='right'):
|
238 | 255 | tuples = lzip(breaks[:-1], breaks[1:])
|
239 | 256 | if isinstance(breaks, (list, tuple)):
|
240 | 257 | return {'data': tuples}
|
| 258 | + elif is_categorical_dtype(breaks): |
| 259 | + return {'data': breaks._constructor(tuples)} |
241 | 260 | return {'data': com._asarray_tuplesafe(tuples)}
|
242 | 261 |
|
243 | 262 | def test_constructor_errors(self):
|
@@ -286,6 +305,8 @@ def get_kwargs_from_breaks(self, breaks, closed='right'):
|
286 | 305 |
|
287 | 306 | if isinstance(breaks, list):
|
288 | 307 | return {'data': ivs}
|
| 308 | + elif is_categorical_dtype(breaks): |
| 309 | + return {'data': breaks._constructor(ivs)} |
289 | 310 | return {'data': np.array(ivs, dtype=object)}
|
290 | 311 |
|
291 | 312 | def test_generic_errors(self, constructor):
|
|
0 commit comments