Skip to content

Commit 373f9d4

Browse files
committed
BUG: Ensure Index.astype('category') returns a CategoricalIndex
1 parent 695e893 commit 373f9d4

File tree

12 files changed

+76
-8
lines changed

12 files changed

+76
-8
lines changed

doc/source/whatsnew/v0.22.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ Conversion
249249
- Fixed a bug where creating a Series from an array that contains both tz-naive and tz-aware values will result in a Series whose dtype is tz-aware instead of object (:issue:`16406`)
250250
- Adding a ``Period`` object to a ``datetime`` or ``Timestamp`` object will now correctly raise a ``TypeError`` (:issue:`17983`)
251251
- Fixed a bug where ``FY5253`` date offsets could incorrectly raise an ``AssertionError`` in arithmetic operatons (:issue:`14774`)
252+
- Bug in :meth:`Index.astype` with a categorical dtype where the resultant index would not be converted to a :class:`CategoricalIndex` for all types of index (:issue:`18630`)
252253

253254

254255
Indexing

pandas/core/dtypes/common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1934,7 +1934,7 @@ def pandas_dtype(dtype):
19341934
except TypeError:
19351935
pass
19361936

1937-
elif dtype.startswith('interval[') or dtype.startswith('Interval['):
1937+
elif dtype.startswith('interval') or dtype.startswith('Interval'):
19381938
try:
19391939
return IntervalDtype.construct_from_string(dtype)
19401940
except TypeError:

pandas/core/indexes/base.py

+4
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,10 @@ def _to_embed(self, keep_tz=False, dtype=None):
10531053

10541054
@Appender(_index_shared_docs['astype'])
10551055
def astype(self, dtype, copy=True):
1056+
if is_categorical_dtype(dtype):
1057+
from pandas.core.indexes.category import CategoricalIndex
1058+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
1059+
copy=copy)
10561060
return Index(self.values.astype(dtype, copy=copy), name=self.name,
10571061
dtype=dtype)
10581062

pandas/core/indexes/category.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
_ensure_platform_int,
1010
is_list_like,
1111
is_interval_dtype,
12-
is_scalar)
12+
is_scalar,
13+
pandas_dtype)
1314
from pandas.core.common import (_asarray_tuplesafe,
1415
_values_from_object)
1516
from pandas.core.dtypes.missing import array_equivalent, isna
@@ -341,9 +342,13 @@ def __array__(self, dtype=None):
341342

342343
@Appender(_index_shared_docs['astype'])
343344
def astype(self, dtype, copy=True):
345+
dtype = pandas_dtype(dtype)
344346
if is_interval_dtype(dtype):
345347
from pandas import IntervalIndex
346348
return IntervalIndex.from_intervals(np.array(self))
349+
elif is_categorical_dtype(dtype) and (dtype == self.dtype):
350+
# fastpath if dtype is the same current
351+
return self.copy() if copy else self
347352
return super(CategoricalIndex, self).astype(dtype=dtype, copy=copy)
348353

349354
@cache_readonly

pandas/core/indexes/datetimes.py

+5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
is_period_dtype,
2121
is_bool_dtype,
2222
is_string_dtype,
23+
is_categorical_dtype,
2324
is_string_like,
2425
is_list_like,
2526
is_scalar,
@@ -916,6 +917,10 @@ def astype(self, dtype, copy=True):
916917
elif copy is True:
917918
return self.copy()
918919
return self
920+
elif is_categorical_dtype(dtype):
921+
from pandas.core.indexes.category import CategoricalIndex
922+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
923+
copy=copy)
919924
elif is_string_dtype(dtype):
920925
return Index(self.format(), name=self.name, dtype=object)
921926
elif is_period_dtype(dtype):

pandas/core/indexes/multi.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
is_object_dtype,
1818
is_iterator,
1919
is_list_like,
20+
pandas_dtype,
2021
is_scalar)
2122
from pandas.core.dtypes.missing import isna, array_equivalent
2223
from pandas.errors import PerformanceWarning, UnsortedIndexError
@@ -2715,7 +2716,7 @@ def difference(self, other):
27152716

27162717
@Appender(_index_shared_docs['astype'])
27172718
def astype(self, dtype, copy=True):
2718-
if not is_object_dtype(np.dtype(dtype)):
2719+
if not is_object_dtype(pandas_dtype(dtype)):
27192720
raise TypeError('Setting %s dtype to anything other than object '
27202721
'is not supported' % self.__class__)
27212722
elif copy is True:

pandas/core/indexes/numeric.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
is_float_dtype,
88
is_object_dtype,
99
is_integer_dtype,
10+
is_categorical_dtype,
1011
is_bool,
1112
is_bool_dtype,
1213
is_scalar)
@@ -313,10 +314,14 @@ def astype(self, dtype, copy=True):
313314
values = self._values.astype(dtype, copy=copy)
314315
elif is_object_dtype(dtype):
315316
values = self._values.astype('object', copy=copy)
317+
elif is_categorical_dtype(dtype):
318+
from pandas.core.indexes.category import CategoricalIndex
319+
return CategoricalIndex(self, name=self.name, dtype=dtype,
320+
copy=copy)
316321
else:
317-
raise TypeError('Setting %s dtype to anything other than '
318-
'float64 or object is not supported' %
319-
self.__class__)
322+
raise TypeError('Setting {cls} dtype to anything other than '
323+
'float64, object, or category is not supported'
324+
.format(cls=self.__class__))
320325
return Index(values, name=self.name, dtype=dtype)
321326

322327
@Appender(_index_shared_docs['_convert_scalar_indexer'])

pandas/core/indexes/period.py

+5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
is_timedelta64_dtype,
1717
is_period_dtype,
1818
is_bool_dtype,
19+
is_categorical_dtype,
1920
pandas_dtype,
2021
_ensure_object)
2122
from pandas.core.dtypes.dtypes import PeriodDtype
@@ -517,6 +518,10 @@ def astype(self, dtype, copy=True, how='start'):
517518
return self.to_timestamp(how=how).tz_localize(dtype.tz)
518519
elif is_period_dtype(dtype):
519520
return self.asfreq(freq=dtype.freq)
521+
elif is_categorical_dtype(dtype):
522+
from pandas.core.indexes.category import CategoricalIndex
523+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
524+
copy=copy)
520525
raise TypeError('Cannot cast PeriodIndex to dtype %s' % dtype)
521526

522527
@Substitution(klass='PeriodIndex')

pandas/core/indexes/timedeltas.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
is_object_dtype,
1313
is_timedelta64_dtype,
1414
is_timedelta64_ns_dtype,
15+
is_categorical_dtype,
16+
pandas_dtype,
1517
_ensure_int64)
1618
from pandas.core.dtypes.missing import isna
1719
from pandas.core.dtypes.generic import ABCSeries
@@ -479,7 +481,7 @@ def to_pytimedelta(self):
479481

480482
@Appender(_index_shared_docs['astype'])
481483
def astype(self, dtype, copy=True):
482-
dtype = np.dtype(dtype)
484+
dtype = pandas_dtype(dtype)
483485

484486
if is_object_dtype(dtype):
485487
return self._box_values_as_index()
@@ -498,6 +500,10 @@ def astype(self, dtype, copy=True):
498500
elif is_integer_dtype(dtype):
499501
return Index(self.values.astype('i8', copy=copy), dtype='i8',
500502
name=self.name)
503+
elif is_categorical_dtype(dtype):
504+
from pandas.core.indexes.category import CategoricalIndex
505+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
506+
copy=copy)
501507
raise TypeError('Cannot cast TimedeltaIndex to dtype %s' % dtype)
502508

503509
def union(self, other):

pandas/tests/indexes/common.py

+18
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pandas.core.indexes.base import InvalidIndexError
1414
from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin
1515
from pandas.core.dtypes.common import needs_i8_conversion
16+
from pandas.core.dtypes.dtypes import CategoricalDtype
1617
from pandas._libs.tslib import iNaT
1718

1819
import pandas.util.testing as tm
@@ -1058,3 +1059,20 @@ def test_putmask_with_wrong_mask(self):
10581059

10591060
with pytest.raises(ValueError):
10601061
index.putmask('foo', 1)
1062+
1063+
def test_astype_category(self):
1064+
# GH 18630
1065+
index = self.create_index()
1066+
1067+
expected = CategoricalIndex(index.values)
1068+
result = index.astype('category', copy=True)
1069+
tm.assert_index_equal(result, expected)
1070+
1071+
expected = CategoricalIndex(index.values, name='foo')
1072+
result = index.rename('foo').astype('category', copy=False)
1073+
tm.assert_index_equal(result, expected)
1074+
1075+
dtype = CategoricalDtype(index.unique()[:-1], ordered=True)
1076+
expected = CategoricalIndex(index.values, dtype=dtype)
1077+
result = index.astype(dtype)
1078+
tm.assert_index_equal(result, expected)

pandas/tests/indexes/test_interval.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pandas import (
77
Interval, IntervalIndex, Index, isna, notna, interval_range, Timestamp,
88
Timedelta, compat, date_range, timedelta_range, DateOffset)
9+
from pandas.core.dtypes.dtypes import CategoricalDtype
910
from pandas.compat import lzip
1011
from pandas.tseries.offsets import Day
1112
from pandas._libs.interval import IntervalTree
@@ -362,8 +363,15 @@ def test_astype(self, closed):
362363
tm.assert_index_equal(result, idx)
363364
assert result.equals(idx)
364365

365-
result = idx.astype('category')
366+
def test_astype_category(self, closed):
367+
# GH 18630
368+
idx = self.create_index(closed=closed)
366369
expected = pd.Categorical(idx, ordered=True)
370+
371+
result = idx.astype('category')
372+
tm.assert_categorical_equal(result, expected)
373+
374+
result = idx.astype(CategoricalDtype())
367375
tm.assert_categorical_equal(result, expected)
368376

369377
@pytest.mark.parametrize('klass', [list, tuple, np.array, pd.Series])

pandas/tests/indexes/test_multi.py

+10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
compat, date_range, period_range)
1717
from pandas.compat import PY3, long, lrange, lzip, range, u, PYPY
1818
from pandas.errors import PerformanceWarning, UnsortedIndexError
19+
from pandas.core.dtypes.dtypes import CategoricalDtype
1920
from pandas.core.indexes.base import InvalidIndexError
2021
from pandas._libs import lib
2122
from pandas._libs.lib import Timestamp
@@ -554,6 +555,15 @@ def test_astype(self):
554555
with tm.assert_raises_regex(TypeError, "^Setting.*dtype.*object"):
555556
self.index.astype(np.dtype(int))
556557

558+
def test_astype_category(self):
559+
# GH 18630
560+
msg = 'Setting .* dtype to anything other than object is not supported'
561+
with tm.assert_raises_regex(TypeError, msg):
562+
self.index.astype('category')
563+
564+
with tm.assert_raises_regex(TypeError, msg):
565+
self.index.astype(CategoricalDtype())
566+
557567
def test_constructor_single_level(self):
558568
result = MultiIndex(levels=[['foo', 'bar', 'baz', 'qux']],
559569
labels=[[0, 1, 2, 3]], names=['first'])

0 commit comments

Comments
 (0)