Skip to content

Commit 6042131

Browse files
committed
refactor dtype update
1 parent afcc50a commit 6042131

File tree

3 files changed

+66
-15
lines changed

3 files changed

+66
-15
lines changed

pandas/core/dtypes/dtypes.py

+27
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,33 @@ def _validate_categories(categories, fastpath=False):
340340

341341
return categories
342342

343+
def _update_dtype(self, dtype):
344+
"""
345+
Returns a CategoricalDtype with categories and ordered taken from dtype
346+
if specified, otherwise falling back to self if unspecified
347+
348+
Parameters
349+
----------
350+
dtype : CategoricalDtype
351+
352+
Returns
353+
-------
354+
new_dtype : CategoricalDtype
355+
"""
356+
if isinstance(dtype, compat.string_types) and dtype == 'category':
357+
# dtype='category' should not change anything
358+
return self
359+
elif not self.is_dtype(dtype):
360+
msg = ('a CategoricalDtype must be passed to perform an update, '
361+
'got {dtype!r}').format(dtype=dtype)
362+
raise ValueError(msg)
363+
364+
# dtype is CDT: keep current categories if None (ordered can't be None)
365+
new_categories = dtype.categories
366+
if new_categories is None:
367+
new_categories = self.categories
368+
return CategoricalDtype(new_categories, dtype.ordered)
369+
343370
@property
344371
def categories(self):
345372
"""

pandas/core/indexes/category.py

+3-15
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
_ensure_platform_int,
1111
is_list_like,
1212
is_interval_dtype,
13-
is_scalar,
14-
pandas_dtype)
13+
is_scalar)
1514
from pandas.core.common import (_asarray_tuplesafe,
1615
_values_from_object)
1716
from pandas.core.dtypes.missing import array_equivalent, isna
@@ -341,23 +340,12 @@ def __array__(self, dtype=None):
341340

342341
@Appender(_index_shared_docs['astype'])
343342
def astype(self, dtype, copy=True):
344-
if isinstance(dtype, compat.string_types) and dtype == 'category':
345-
# GH 18630: CI.astype('category') should not change anything
346-
return self.copy() if copy else self
347-
348-
dtype = pandas_dtype(dtype)
349343
if is_interval_dtype(dtype):
350344
from pandas import IntervalIndex
351345
return IntervalIndex.from_intervals(np.array(self))
352346
elif is_categorical_dtype(dtype):
353-
# GH 18630: keep current categories if None (ordered can't be None)
354-
if dtype.categories is None:
355-
new_categories = self.categories
356-
else:
357-
new_categories = dtype.categories
358-
dtype = CategoricalDtype(new_categories, dtype.ordered)
359-
360-
# fastpath if dtypes are equal
347+
# GH 18630
348+
dtype = self.dtype._update_dtype(dtype)
361349
if dtype == self.dtype:
362350
return self.copy() if copy else self
363351

pandas/tests/dtypes/test_dtypes.py

+36
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pandas import (
1010
Series, Categorical, CategoricalIndex, IntervalIndex, date_range)
1111

12+
from pandas.compat import string_types
1213
from pandas.core.dtypes.dtypes import (
1314
DatetimeTZDtype, PeriodDtype,
1415
IntervalDtype, CategoricalDtype)
@@ -123,6 +124,41 @@ def test_tuple_categories(self):
123124
result = CategoricalDtype(categories)
124125
assert all(result.categories == categories)
125126

127+
@pytest.mark.parametrize('dtype', [
128+
CategoricalDtype(list('abc'), False),
129+
CategoricalDtype(list('abc'), True)])
130+
@pytest.mark.parametrize('new_dtype', [
131+
'category',
132+
CategoricalDtype(None, False),
133+
CategoricalDtype(None, True),
134+
CategoricalDtype(list('abc'), False),
135+
CategoricalDtype(list('abc'), True),
136+
CategoricalDtype(list('cba'), False),
137+
CategoricalDtype(list('cba'), True),
138+
CategoricalDtype(list('wxyz'), False),
139+
CategoricalDtype(list('wxyz'), True)])
140+
def test_update_dtype(self, dtype, new_dtype):
141+
if isinstance(new_dtype, string_types) and new_dtype == 'category':
142+
expected_categories = dtype.categories
143+
expected_ordered = dtype.ordered
144+
else:
145+
expected_categories = new_dtype.categories
146+
if expected_categories is None:
147+
expected_categories = dtype.categories
148+
expected_ordered = new_dtype.ordered
149+
150+
result = dtype._update_dtype(new_dtype)
151+
tm.assert_index_equal(result.categories, expected_categories)
152+
assert result.ordered is expected_ordered
153+
154+
@pytest.mark.parametrize('bad_dtype', [
155+
'foo', object, np.int64, PeriodDtype('Q'), IntervalDtype(object)])
156+
def test_update_dtype_errors(self, bad_dtype):
157+
dtype = CategoricalDtype(list('abc'), False)
158+
msg = 'a CategoricalDtype must be passed to perform an update, '
159+
with tm.assert_raises_regex(ValueError, msg):
160+
dtype._update_dtype(bad_dtype)
161+
126162

127163
class TestDatetimeTZDtype(Base):
128164

0 commit comments

Comments
 (0)