Skip to content

Commit 3821040

Browse files
jschendeljreback
authored andcommitted
BUG: Ensure Index.astype('category') returns a CategoricalIndex (#18677)
1 parent c753e1e commit 3821040

18 files changed

+201
-30
lines changed

doc/source/whatsnew/v0.22.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ Conversion
259259
- Fixed a bug where creating a Series from an array that contains both tz-naive and tz-aware values will result in a Series whose dtype is tz-aware instead of object (:issue:`16406`)
260260
- Adding a ``Period`` object to a ``datetime`` or ``Timestamp`` object will now correctly raise a ``TypeError`` (:issue:`17983`)
261261
- Fixed a bug where ``FY5253`` date offsets could incorrectly raise an ``AssertionError`` in arithmetic operatons (:issue:`14774`)
262+
- Bug in :meth:`Index.astype` with a categorical dtype where the resultant index is not converted to a :class:`CategoricalIndex` for all types of index (:issue:`18630`)
262263

263264

264265
Indexing

pandas/core/dtypes/common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1934,7 +1934,7 @@ def pandas_dtype(dtype):
19341934
except TypeError:
19351935
pass
19361936

1937-
elif dtype.startswith('interval[') or dtype.startswith('Interval['):
1937+
elif dtype.startswith('interval') or dtype.startswith('Interval'):
19381938
try:
19391939
return IntervalDtype.construct_from_string(dtype)
19401940
except TypeError:

pandas/core/dtypes/dtypes.py

+27
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,33 @@ def _validate_categories(categories, fastpath=False):
340340

341341
return categories
342342

343+
def _update_dtype(self, dtype):
344+
"""
345+
Returns a CategoricalDtype with categories and ordered taken from dtype
346+
if specified, otherwise falling back to self if unspecified
347+
348+
Parameters
349+
----------
350+
dtype : CategoricalDtype
351+
352+
Returns
353+
-------
354+
new_dtype : CategoricalDtype
355+
"""
356+
if isinstance(dtype, compat.string_types) and dtype == 'category':
357+
# dtype='category' should not change anything
358+
return self
359+
elif not self.is_dtype(dtype):
360+
msg = ('a CategoricalDtype must be passed to perform an update, '
361+
'got {dtype!r}').format(dtype=dtype)
362+
raise ValueError(msg)
363+
364+
# dtype is CDT: keep current categories if None (ordered can't be None)
365+
new_categories = dtype.categories
366+
if new_categories is None:
367+
new_categories = self.categories
368+
return CategoricalDtype(new_categories, dtype.ordered)
369+
343370
@property
344371
def categories(self):
345372
"""

pandas/core/indexes/base.py

+4
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,10 @@ def _to_embed(self, keep_tz=False, dtype=None):
10531053

10541054
@Appender(_index_shared_docs['astype'])
10551055
def astype(self, dtype, copy=True):
1056+
if is_categorical_dtype(dtype):
1057+
from .category import CategoricalIndex
1058+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
1059+
copy=copy)
10561060
return Index(self.values.astype(dtype, copy=copy), name=self.name,
10571061
dtype=dtype)
10581062

pandas/core/indexes/category.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pandas import compat
55
from pandas.compat.numpy import function as nv
66
from pandas.core.dtypes.generic import ABCCategorical, ABCSeries
7+
from pandas.core.dtypes.dtypes import CategoricalDtype
78
from pandas.core.dtypes.common import (
89
is_categorical_dtype,
910
_ensure_platform_int,
@@ -165,8 +166,6 @@ def _create_categorical(self, data, categories=None, ordered=None,
165166
data = Categorical(data, categories=categories, ordered=ordered,
166167
dtype=dtype)
167168
else:
168-
from pandas.core.dtypes.dtypes import CategoricalDtype
169-
170169
if categories is not None:
171170
data = data.set_categories(categories, ordered=ordered)
172171
elif ordered is not None and ordered != data.ordered:
@@ -344,6 +343,12 @@ def astype(self, dtype, copy=True):
344343
if is_interval_dtype(dtype):
345344
from pandas import IntervalIndex
346345
return IntervalIndex.from_intervals(np.array(self))
346+
elif is_categorical_dtype(dtype):
347+
# GH 18630
348+
dtype = self.dtype._update_dtype(dtype)
349+
if dtype == self.dtype:
350+
return self.copy() if copy else self
351+
347352
return super(CategoricalIndex, self).astype(dtype=dtype, copy=copy)
348353

349354
@cache_readonly

pandas/core/indexes/datetimes.py

+5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
is_period_dtype,
2121
is_bool_dtype,
2222
is_string_dtype,
23+
is_categorical_dtype,
2324
is_string_like,
2425
is_list_like,
2526
is_scalar,
@@ -35,6 +36,7 @@
3536
from pandas.core.algorithms import checked_add_with_arr
3637

3738
from pandas.core.indexes.base import Index, _index_shared_docs
39+
from pandas.core.indexes.category import CategoricalIndex
3840
from pandas.core.indexes.numeric import Int64Index, Float64Index
3941
import pandas.compat as compat
4042
from pandas.tseries.frequencies import (
@@ -915,6 +917,9 @@ def astype(self, dtype, copy=True):
915917
elif copy is True:
916918
return self.copy()
917919
return self
920+
elif is_categorical_dtype(dtype):
921+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
922+
copy=copy)
918923
elif is_string_dtype(dtype):
919924
return Index(self.format(), name=self.name, dtype=object)
920925
elif is_period_dtype(dtype):

pandas/core/indexes/interval.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
Interval, IntervalMixin, IntervalTree,
3030
intervals_to_interval_bounds)
3131

32+
from pandas.core.indexes.category import CategoricalIndex
3233
from pandas.core.indexes.datetimes import date_range
3334
from pandas.core.indexes.timedeltas import timedelta_range
3435
from pandas.core.indexes.multi import MultiIndex
@@ -632,8 +633,8 @@ def astype(self, dtype, copy=True):
632633
elif is_object_dtype(dtype):
633634
return Index(self.values, dtype=object)
634635
elif is_categorical_dtype(dtype):
635-
from pandas import Categorical
636-
return Categorical(self, ordered=True)
636+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
637+
copy=copy)
637638
raise ValueError('Cannot cast IntervalIndex to dtype {dtype}'
638639
.format(dtype=dtype))
639640

pandas/core/indexes/multi.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
from pandas.core.dtypes.common import (
1515
_ensure_int64,
1616
_ensure_platform_int,
17+
is_categorical_dtype,
1718
is_object_dtype,
1819
is_iterator,
1920
is_list_like,
21+
pandas_dtype,
2022
is_scalar)
2123
from pandas.core.dtypes.missing import isna, array_equivalent
2224
from pandas.errors import PerformanceWarning, UnsortedIndexError
@@ -2715,9 +2717,14 @@ def difference(self, other):
27152717

27162718
@Appender(_index_shared_docs['astype'])
27172719
def astype(self, dtype, copy=True):
2718-
if not is_object_dtype(np.dtype(dtype)):
2719-
raise TypeError('Setting %s dtype to anything other than object '
2720-
'is not supported' % self.__class__)
2720+
dtype = pandas_dtype(dtype)
2721+
if is_categorical_dtype(dtype):
2722+
msg = '> 1 ndim Categorical are not supported at this time'
2723+
raise NotImplementedError(msg)
2724+
elif not is_object_dtype(dtype):
2725+
msg = ('Setting {cls} dtype to anything other than object '
2726+
'is not supported').format(cls=self.__class__)
2727+
raise TypeError(msg)
27212728
elif copy is True:
27222729
return self._shallow_copy()
27232730
return self

pandas/core/indexes/numeric.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
is_float_dtype,
88
is_object_dtype,
99
is_integer_dtype,
10+
is_categorical_dtype,
1011
is_bool,
1112
is_bool_dtype,
1213
is_scalar)
@@ -16,6 +17,7 @@
1617
from pandas.core import algorithms
1718
from pandas.core.indexes.base import (
1819
Index, InvalidIndexError, _index_shared_docs)
20+
from pandas.core.indexes.category import CategoricalIndex
1921
from pandas.util._decorators import Appender, cache_readonly
2022
import pandas.core.dtypes.concat as _concat
2123
import pandas.core.indexes.base as ibase
@@ -321,10 +323,13 @@ def astype(self, dtype, copy=True):
321323
values = self._values.astype(dtype, copy=copy)
322324
elif is_object_dtype(dtype):
323325
values = self._values.astype('object', copy=copy)
326+
elif is_categorical_dtype(dtype):
327+
return CategoricalIndex(self, name=self.name, dtype=dtype,
328+
copy=copy)
324329
else:
325-
raise TypeError('Setting %s dtype to anything other than '
326-
'float64 or object is not supported' %
327-
self.__class__)
330+
raise TypeError('Setting {cls} dtype to anything other than '
331+
'float64, object, or category is not supported'
332+
.format(cls=self.__class__))
328333
return Index(values, name=self.name, dtype=dtype)
329334

330335
@Appender(_index_shared_docs['_convert_scalar_indexer'])

pandas/core/indexes/period.py

+5
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
is_timedelta64_dtype,
1717
is_period_dtype,
1818
is_bool_dtype,
19+
is_categorical_dtype,
1920
pandas_dtype,
2021
_ensure_object)
2122
from pandas.core.dtypes.dtypes import PeriodDtype
2223
from pandas.core.dtypes.generic import ABCSeries
2324

2425
import pandas.tseries.frequencies as frequencies
2526
from pandas.tseries.frequencies import get_freq_code as _gfc
27+
from pandas.core.indexes.category import CategoricalIndex
2628
from pandas.core.indexes.datetimes import DatetimeIndex, Int64Index, Index
2729
from pandas.core.indexes.timedeltas import TimedeltaIndex
2830
from pandas.core.indexes.datetimelike import DatelikeOps, DatetimeIndexOpsMixin
@@ -517,6 +519,9 @@ def astype(self, dtype, copy=True, how='start'):
517519
return self.to_timestamp(how=how).tz_localize(dtype.tz)
518520
elif is_period_dtype(dtype):
519521
return self.asfreq(freq=dtype.freq)
522+
elif is_categorical_dtype(dtype):
523+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
524+
copy=copy)
520525
raise TypeError('Cannot cast PeriodIndex to dtype %s' % dtype)
521526

522527
@Substitution(klass='PeriodIndex')

pandas/core/indexes/timedeltas.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@
1212
is_object_dtype,
1313
is_timedelta64_dtype,
1414
is_timedelta64_ns_dtype,
15+
is_categorical_dtype,
16+
pandas_dtype,
1517
_ensure_int64)
1618
from pandas.core.dtypes.missing import isna
1719
from pandas.core.dtypes.generic import ABCSeries
1820
from pandas.core.common import _maybe_box, _values_from_object
1921

2022
from pandas.core.indexes.base import Index
23+
from pandas.core.indexes.category import CategoricalIndex
2124
from pandas.core.indexes.numeric import Int64Index
2225
import pandas.compat as compat
2326
from pandas.compat import u
@@ -479,7 +482,7 @@ def to_pytimedelta(self):
479482

480483
@Appender(_index_shared_docs['astype'])
481484
def astype(self, dtype, copy=True):
482-
dtype = np.dtype(dtype)
485+
dtype = pandas_dtype(dtype)
483486

484487
if is_object_dtype(dtype):
485488
return self._box_values_as_index()
@@ -498,6 +501,9 @@ def astype(self, dtype, copy=True):
498501
elif is_integer_dtype(dtype):
499502
return Index(self.values.astype('i8', copy=copy), dtype='i8',
500503
name=self.name)
504+
elif is_categorical_dtype(dtype):
505+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
506+
copy=copy)
501507
raise TypeError('Cannot cast TimedeltaIndex to dtype %s' % dtype)
502508

503509
def union(self, other):

pandas/tests/dtypes/test_dtypes.py

+36
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pandas import (
1010
Series, Categorical, CategoricalIndex, IntervalIndex, date_range)
1111

12+
from pandas.compat import string_types
1213
from pandas.core.dtypes.dtypes import (
1314
DatetimeTZDtype, PeriodDtype,
1415
IntervalDtype, CategoricalDtype)
@@ -123,6 +124,41 @@ def test_tuple_categories(self):
123124
result = CategoricalDtype(categories)
124125
assert all(result.categories == categories)
125126

127+
@pytest.mark.parametrize('dtype', [
128+
CategoricalDtype(list('abc'), False),
129+
CategoricalDtype(list('abc'), True)])
130+
@pytest.mark.parametrize('new_dtype', [
131+
'category',
132+
CategoricalDtype(None, False),
133+
CategoricalDtype(None, True),
134+
CategoricalDtype(list('abc'), False),
135+
CategoricalDtype(list('abc'), True),
136+
CategoricalDtype(list('cba'), False),
137+
CategoricalDtype(list('cba'), True),
138+
CategoricalDtype(list('wxyz'), False),
139+
CategoricalDtype(list('wxyz'), True)])
140+
def test_update_dtype(self, dtype, new_dtype):
141+
if isinstance(new_dtype, string_types) and new_dtype == 'category':
142+
expected_categories = dtype.categories
143+
expected_ordered = dtype.ordered
144+
else:
145+
expected_categories = new_dtype.categories
146+
if expected_categories is None:
147+
expected_categories = dtype.categories
148+
expected_ordered = new_dtype.ordered
149+
150+
result = dtype._update_dtype(new_dtype)
151+
tm.assert_index_equal(result.categories, expected_categories)
152+
assert result.ordered is expected_ordered
153+
154+
@pytest.mark.parametrize('bad_dtype', [
155+
'foo', object, np.int64, PeriodDtype('Q'), IntervalDtype(object)])
156+
def test_update_dtype_errors(self, bad_dtype):
157+
dtype = CategoricalDtype(list('abc'), False)
158+
msg = 'a CategoricalDtype must be passed to perform an update, '
159+
with tm.assert_raises_regex(ValueError, msg):
160+
dtype._update_dtype(bad_dtype)
161+
126162

127163
class TestDatetimeTZDtype(Base):
128164

pandas/tests/indexes/common.py

+28
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pandas.core.indexes.base import InvalidIndexError
1414
from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin
1515
from pandas.core.dtypes.common import needs_i8_conversion
16+
from pandas.core.dtypes.dtypes import CategoricalDtype
1617
from pandas._libs.tslib import iNaT
1718

1819
import pandas.util.testing as tm
@@ -1058,3 +1059,30 @@ def test_putmask_with_wrong_mask(self):
10581059

10591060
with pytest.raises(ValueError):
10601061
index.putmask('foo', 1)
1062+
1063+
@pytest.mark.parametrize('copy', [True, False])
1064+
@pytest.mark.parametrize('name', [None, 'foo'])
1065+
@pytest.mark.parametrize('ordered', [True, False])
1066+
def test_astype_category(self, copy, name, ordered):
1067+
# GH 18630
1068+
index = self.create_index()
1069+
if name:
1070+
index = index.rename(name)
1071+
1072+
# standard categories
1073+
dtype = CategoricalDtype(ordered=ordered)
1074+
result = index.astype(dtype, copy=copy)
1075+
expected = CategoricalIndex(index.values, name=name, ordered=ordered)
1076+
tm.assert_index_equal(result, expected)
1077+
1078+
# non-standard categories
1079+
dtype = CategoricalDtype(index.unique().tolist()[:-1], ordered)
1080+
result = index.astype(dtype, copy=copy)
1081+
expected = CategoricalIndex(index.values, name=name, dtype=dtype)
1082+
tm.assert_index_equal(result, expected)
1083+
1084+
if ordered is False:
1085+
# dtype='category' defaults to ordered=False, so only test once
1086+
result = index.astype('category', copy=copy)
1087+
expected = CategoricalIndex(index.values, name=name)
1088+
tm.assert_index_equal(result, expected)

pandas/tests/indexes/test_category.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -388,9 +388,6 @@ def test_delete(self):
388388
def test_astype(self):
389389

390390
ci = self.create_index()
391-
result = ci.astype('category')
392-
tm.assert_index_equal(result, ci, exact=True)
393-
394391
result = ci.astype(object)
395392
tm.assert_index_equal(result, Index(np.array(ci)))
396393

@@ -414,6 +411,37 @@ def test_astype(self):
414411
result = IntervalIndex.from_intervals(result.values)
415412
tm.assert_index_equal(result, expected)
416413

414+
@pytest.mark.parametrize('copy', [True, False])
415+
@pytest.mark.parametrize('name', [None, 'foo'])
416+
@pytest.mark.parametrize('dtype_ordered', [True, False])
417+
@pytest.mark.parametrize('index_ordered', [True, False])
418+
def test_astype_category(self, copy, name, dtype_ordered, index_ordered):
419+
# GH 18630
420+
index = self.create_index(ordered=index_ordered)
421+
if name:
422+
index = index.rename(name)
423+
424+
# standard categories
425+
dtype = CategoricalDtype(ordered=dtype_ordered)
426+
result = index.astype(dtype, copy=copy)
427+
expected = CategoricalIndex(index.tolist(),
428+
name=name,
429+
categories=index.categories,
430+
ordered=dtype_ordered)
431+
tm.assert_index_equal(result, expected)
432+
433+
# non-standard categories
434+
dtype = CategoricalDtype(index.unique().tolist()[:-1], dtype_ordered)
435+
result = index.astype(dtype, copy=copy)
436+
expected = CategoricalIndex(index.tolist(), name=name, dtype=dtype)
437+
tm.assert_index_equal(result, expected)
438+
439+
if dtype_ordered is False:
440+
# dtype='category' can't specify ordered, so only test once
441+
result = index.astype('category', copy=copy)
442+
expected = index
443+
tm.assert_index_equal(result, expected)
444+
417445
def test_reindex_base(self):
418446
# Determined by cat ordering.
419447
idx = CategoricalIndex(list("cab"), categories=list("cab"))

0 commit comments

Comments
 (0)