|
9 | 9 | from pandas import (
|
10 | 10 | Series, Categorical, CategoricalIndex, IntervalIndex, date_range)
|
11 | 11 |
|
| 12 | +from pandas.compat import string_types |
12 | 13 | from pandas.core.dtypes.dtypes import (
|
13 | 14 | DatetimeTZDtype, PeriodDtype,
|
14 | 15 | IntervalDtype, CategoricalDtype)
|
@@ -123,6 +124,41 @@ def test_tuple_categories(self):
|
123 | 124 | result = CategoricalDtype(categories)
|
124 | 125 | assert all(result.categories == categories)
|
125 | 126 |
|
| 127 | + @pytest.mark.parametrize('dtype', [ |
| 128 | + CategoricalDtype(list('abc'), False), |
| 129 | + CategoricalDtype(list('abc'), True)]) |
| 130 | + @pytest.mark.parametrize('new_dtype', [ |
| 131 | + 'category', |
| 132 | + CategoricalDtype(None, False), |
| 133 | + CategoricalDtype(None, True), |
| 134 | + CategoricalDtype(list('abc'), False), |
| 135 | + CategoricalDtype(list('abc'), True), |
| 136 | + CategoricalDtype(list('cba'), False), |
| 137 | + CategoricalDtype(list('cba'), True), |
| 138 | + CategoricalDtype(list('wxyz'), False), |
| 139 | + CategoricalDtype(list('wxyz'), True)]) |
| 140 | + def test_update_dtype(self, dtype, new_dtype): |
| 141 | + if isinstance(new_dtype, string_types) and new_dtype == 'category': |
| 142 | + expected_categories = dtype.categories |
| 143 | + expected_ordered = dtype.ordered |
| 144 | + else: |
| 145 | + expected_categories = new_dtype.categories |
| 146 | + if expected_categories is None: |
| 147 | + expected_categories = dtype.categories |
| 148 | + expected_ordered = new_dtype.ordered |
| 149 | + |
| 150 | + result = dtype._update_dtype(new_dtype) |
| 151 | + tm.assert_index_equal(result.categories, expected_categories) |
| 152 | + assert result.ordered is expected_ordered |
| 153 | + |
| 154 | + @pytest.mark.parametrize('bad_dtype', [ |
| 155 | + 'foo', object, np.int64, PeriodDtype('Q'), IntervalDtype(object)]) |
| 156 | + def test_update_dtype_errors(self, bad_dtype): |
| 157 | + dtype = CategoricalDtype(list('abc'), False) |
| 158 | + msg = 'a CategoricalDtype must be passed to perform an update, ' |
| 159 | + with tm.assert_raises_regex(ValueError, msg): |
| 160 | + dtype._update_dtype(bad_dtype) |
| 161 | + |
126 | 162 |
|
127 | 163 | class TestDatetimeTZDtype(Base):
|
128 | 164 |
|
|
0 commit comments