Skip to content

BUG: Ensure Index.astype('category') returns a CategoricalIndex #18677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.22.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ Conversion
- Fixed a bug where creating a Series from an array that contains both tz-naive and tz-aware values will result in a Series whose dtype is tz-aware instead of object (:issue:`16406`)
- Adding a ``Period`` object to a ``datetime`` or ``Timestamp`` object will now correctly raise a ``TypeError`` (:issue:`17983`)
- Fixed a bug where ``FY5253`` date offsets could incorrectly raise an ``AssertionError`` in arithmetic operatons (:issue:`14774`)
- Bug in :meth:`Index.astype` with a categorical dtype where the resultant index is not converted to a :class:`CategoricalIndex` for all types of index (:issue:`18630`)


Indexing
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1934,7 +1934,7 @@ def pandas_dtype(dtype):
except TypeError:
pass

elif dtype.startswith('interval[') or dtype.startswith('Interval['):
elif dtype.startswith('interval') or dtype.startswith('Interval'):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed this because switching to pandas_dtype caused a test to break since it was passing 'interval' as the dtype, which appears to be valid:

In [1]: from pandas.core.dtypes.common import is_interval_dtype

In [2]: is_interval_dtype('interval')
Out[2]: True

try:
return IntervalDtype.construct_from_string(dtype)
except TypeError:
Expand Down
27 changes: 27 additions & 0 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,33 @@ def _validate_categories(categories, fastpath=False):

return categories

def _update_dtype(self, dtype):
"""
Returns a CategoricalDtype with categories and ordered taken from dtype
if specified, otherwise falling back to self if unspecified

Parameters
----------
dtype : CategoricalDtype

Returns
-------
new_dtype : CategoricalDtype
"""
if isinstance(dtype, compat.string_types) and dtype == 'category':
# dtype='category' should not change anything
return self
elif not self.is_dtype(dtype):
msg = ('a CategoricalDtype must be passed to perform an update, '
'got {dtype!r}').format(dtype=dtype)
raise ValueError(msg)

# dtype is CDT: keep current categories if None (ordered can't be None)
new_categories = dtype.categories
if new_categories is None:
new_categories = self.categories
return CategoricalDtype(new_categories, dtype.ordered)

@property
def categories(self):
"""
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,10 @@ def _to_embed(self, keep_tz=False, dtype=None):

@Appender(_index_shared_docs['astype'])
def astype(self, dtype, copy=True):
if is_categorical_dtype(dtype):
from .category import CategoricalIndex
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we have an import issue if we import this at the top (with the fully qualified path)?

return CategoricalIndex(self.values, name=self.name, dtype=dtype,
copy=copy)
return Index(self.values.astype(dtype, copy=copy), name=self.name,
dtype=dtype)

Expand Down
9 changes: 7 additions & 2 deletions pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pandas import compat
from pandas.compat.numpy import function as nv
from pandas.core.dtypes.generic import ABCCategorical, ABCSeries
from pandas.core.dtypes.dtypes import CategoricalDtype
from pandas.core.dtypes.common import (
is_categorical_dtype,
_ensure_platform_int,
Expand Down Expand Up @@ -165,8 +166,6 @@ def _create_categorical(self, data, categories=None, ordered=None,
data = Categorical(data, categories=categories, ordered=ordered,
dtype=dtype)
else:
from pandas.core.dtypes.dtypes import CategoricalDtype

if categories is not None:
data = data.set_categories(categories, ordered=ordered)
elif ordered is not None and ordered != data.ordered:
Expand Down Expand Up @@ -344,6 +343,12 @@ def astype(self, dtype, copy=True):
if is_interval_dtype(dtype):
from pandas import IntervalIndex
return IntervalIndex.from_intervals(np.array(self))
elif is_categorical_dtype(dtype):
# GH 18630
dtype = self.dtype._update_dtype(dtype)
if dtype == self.dtype:
return self.copy() if copy else self

return super(CategoricalIndex, self).astype(dtype=dtype, copy=copy)

@cache_readonly
Expand Down
5 changes: 5 additions & 0 deletions pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
is_period_dtype,
is_bool_dtype,
is_string_dtype,
is_categorical_dtype,
is_string_like,
is_list_like,
is_scalar,
Expand All @@ -35,6 +36,7 @@
from pandas.core.algorithms import checked_add_with_arr

from pandas.core.indexes.base import Index, _index_shared_docs
from pandas.core.indexes.category import CategoricalIndex
from pandas.core.indexes.numeric import Int64Index, Float64Index
import pandas.compat as compat
from pandas.tseries.frequencies import (
Expand Down Expand Up @@ -915,6 +917,9 @@ def astype(self, dtype, copy=True):
elif copy is True:
return self.copy()
return self
elif is_categorical_dtype(dtype):
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
copy=copy)
elif is_string_dtype(dtype):
return Index(self.format(), name=self.name, dtype=object)
elif is_period_dtype(dtype):
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Interval, IntervalMixin, IntervalTree,
intervals_to_interval_bounds)

from pandas.core.indexes.category import CategoricalIndex
from pandas.core.indexes.datetimes import date_range
from pandas.core.indexes.timedeltas import timedelta_range
from pandas.core.indexes.multi import MultiIndex
Expand Down Expand Up @@ -632,8 +633,8 @@ def astype(self, dtype, copy=True):
elif is_object_dtype(dtype):
return Index(self.values, dtype=object)
elif is_categorical_dtype(dtype):
from pandas import Categorical
return Categorical(self, ordered=True)
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
copy=copy)
raise ValueError('Cannot cast IntervalIndex to dtype {dtype}'
.format(dtype=dtype))

Expand Down
13 changes: 10 additions & 3 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
from pandas.core.dtypes.common import (
_ensure_int64,
_ensure_platform_int,
is_categorical_dtype,
is_object_dtype,
is_iterator,
is_list_like,
pandas_dtype,
is_scalar)
from pandas.core.dtypes.missing import isna, array_equivalent
from pandas.errors import PerformanceWarning, UnsortedIndexError
Expand Down Expand Up @@ -2715,9 +2717,14 @@ def difference(self, other):

@Appender(_index_shared_docs['astype'])
def astype(self, dtype, copy=True):
if not is_object_dtype(np.dtype(dtype)):
raise TypeError('Setting %s dtype to anything other than object '
'is not supported' % self.__class__)
dtype = pandas_dtype(dtype)
if is_categorical_dtype(dtype):
msg = '> 1 ndim Categorical are not supported at this time'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test fo this?

Copy link
Member Author

@jschendel jschendel Dec 9, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise NotImplementedError(msg)
elif not is_object_dtype(dtype):
msg = ('Setting {cls} dtype to anything other than object '
'is not supported').format(cls=self.__class__)
raise TypeError(msg)
elif copy is True:
return self._shallow_copy()
return self
Expand Down
11 changes: 8 additions & 3 deletions pandas/core/indexes/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
is_float_dtype,
is_object_dtype,
is_integer_dtype,
is_categorical_dtype,
is_bool,
is_bool_dtype,
is_scalar)
Expand All @@ -16,6 +17,7 @@
from pandas.core import algorithms
from pandas.core.indexes.base import (
Index, InvalidIndexError, _index_shared_docs)
from pandas.core.indexes.category import CategoricalIndex
from pandas.util._decorators import Appender, cache_readonly
import pandas.core.dtypes.concat as _concat
import pandas.core.indexes.base as ibase
Expand Down Expand Up @@ -321,10 +323,13 @@ def astype(self, dtype, copy=True):
values = self._values.astype(dtype, copy=copy)
elif is_object_dtype(dtype):
values = self._values.astype('object', copy=copy)
elif is_categorical_dtype(dtype):
return CategoricalIndex(self, name=self.name, dtype=dtype,
copy=copy)
else:
raise TypeError('Setting %s dtype to anything other than '
'float64 or object is not supported' %
self.__class__)
raise TypeError('Setting {cls} dtype to anything other than '
'float64, object, or category is not supported'
.format(cls=self.__class__))
return Index(values, name=self.name, dtype=dtype)

@Appender(_index_shared_docs['_convert_scalar_indexer'])
Expand Down
5 changes: 5 additions & 0 deletions pandas/core/indexes/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
is_timedelta64_dtype,
is_period_dtype,
is_bool_dtype,
is_categorical_dtype,
pandas_dtype,
_ensure_object)
from pandas.core.dtypes.dtypes import PeriodDtype
from pandas.core.dtypes.generic import ABCSeries

import pandas.tseries.frequencies as frequencies
from pandas.tseries.frequencies import get_freq_code as _gfc
from pandas.core.indexes.category import CategoricalIndex
from pandas.core.indexes.datetimes import DatetimeIndex, Int64Index, Index
from pandas.core.indexes.timedeltas import TimedeltaIndex
from pandas.core.indexes.datetimelike import DatelikeOps, DatetimeIndexOpsMixin
Expand Down Expand Up @@ -517,6 +519,9 @@ def astype(self, dtype, copy=True, how='start'):
return self.to_timestamp(how=how).tz_localize(dtype.tz)
elif is_period_dtype(dtype):
return self.asfreq(freq=dtype.freq)
elif is_categorical_dtype(dtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

side thing, I think that we could make a more generic astype in indexes.base and remove some boiler plate maybe (of course separate PR), you can make an issue if you want (or just PR!)

return CategoricalIndex(self.values, name=self.name, dtype=dtype,
copy=copy)
raise TypeError('Cannot cast PeriodIndex to dtype %s' % dtype)

@Substitution(klass='PeriodIndex')
Expand Down
8 changes: 7 additions & 1 deletion pandas/core/indexes/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
is_object_dtype,
is_timedelta64_dtype,
is_timedelta64_ns_dtype,
is_categorical_dtype,
pandas_dtype,
_ensure_int64)
from pandas.core.dtypes.missing import isna
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.common import _maybe_box, _values_from_object

from pandas.core.indexes.base import Index
from pandas.core.indexes.category import CategoricalIndex
from pandas.core.indexes.numeric import Int64Index
import pandas.compat as compat
from pandas.compat import u
Expand Down Expand Up @@ -479,7 +482,7 @@ def to_pytimedelta(self):

@Appender(_index_shared_docs['astype'])
def astype(self, dtype, copy=True):
dtype = np.dtype(dtype)
dtype = pandas_dtype(dtype)

if is_object_dtype(dtype):
return self._box_values_as_index()
Expand All @@ -498,6 +501,9 @@ def astype(self, dtype, copy=True):
elif is_integer_dtype(dtype):
return Index(self.values.astype('i8', copy=copy), dtype='i8',
name=self.name)
elif is_categorical_dtype(dtype):
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
copy=copy)
raise TypeError('Cannot cast TimedeltaIndex to dtype %s' % dtype)

def union(self, other):
Expand Down
36 changes: 36 additions & 0 deletions pandas/tests/dtypes/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pandas import (
Series, Categorical, CategoricalIndex, IntervalIndex, date_range)

from pandas.compat import string_types
from pandas.core.dtypes.dtypes import (
DatetimeTZDtype, PeriodDtype,
IntervalDtype, CategoricalDtype)
Expand Down Expand Up @@ -123,6 +124,41 @@ def test_tuple_categories(self):
result = CategoricalDtype(categories)
assert all(result.categories == categories)

@pytest.mark.parametrize('dtype', [
CategoricalDtype(list('abc'), False),
CategoricalDtype(list('abc'), True)])
@pytest.mark.parametrize('new_dtype', [
'category',
CategoricalDtype(None, False),
CategoricalDtype(None, True),
CategoricalDtype(list('abc'), False),
CategoricalDtype(list('abc'), True),
CategoricalDtype(list('cba'), False),
CategoricalDtype(list('cba'), True),
CategoricalDtype(list('wxyz'), False),
CategoricalDtype(list('wxyz'), True)])
def test_update_dtype(self, dtype, new_dtype):
if isinstance(new_dtype, string_types) and new_dtype == 'category':
expected_categories = dtype.categories
expected_ordered = dtype.ordered
else:
expected_categories = new_dtype.categories
if expected_categories is None:
expected_categories = dtype.categories
expected_ordered = new_dtype.ordered

result = dtype._update_dtype(new_dtype)
tm.assert_index_equal(result.categories, expected_categories)
assert result.ordered is expected_ordered

@pytest.mark.parametrize('bad_dtype', [
'foo', object, np.int64, PeriodDtype('Q'), IntervalDtype(object)])
def test_update_dtype_errors(self, bad_dtype):
dtype = CategoricalDtype(list('abc'), False)
msg = 'a CategoricalDtype must be passed to perform an update, '
with tm.assert_raises_regex(ValueError, msg):
dtype._update_dtype(bad_dtype)


class TestDatetimeTZDtype(Base):

Expand Down
28 changes: 28 additions & 0 deletions pandas/tests/indexes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pandas.core.indexes.base import InvalidIndexError
from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin
from pandas.core.dtypes.common import needs_i8_conversion
from pandas.core.dtypes.dtypes import CategoricalDtype
from pandas._libs.tslib import iNaT

import pandas.util.testing as tm
Expand Down Expand Up @@ -1058,3 +1059,30 @@ def test_putmask_with_wrong_mask(self):

with pytest.raises(ValueError):
index.putmask('foo', 1)

@pytest.mark.parametrize('copy', [True, False])
@pytest.mark.parametrize('name', [None, 'foo'])
@pytest.mark.parametrize('ordered', [True, False])
def test_astype_category(self, copy, name, ordered):
# GH 18630
index = self.create_index()
if name:
index = index.rename(name)

# standard categories
dtype = CategoricalDtype(ordered=ordered)
result = index.astype(dtype, copy=copy)
expected = CategoricalIndex(index.values, name=name, ordered=ordered)
tm.assert_index_equal(result, expected)

# non-standard categories
dtype = CategoricalDtype(index.unique().tolist()[:-1], ordered)
result = index.astype(dtype, copy=copy)
expected = CategoricalIndex(index.values, name=name, dtype=dtype)
tm.assert_index_equal(result, expected)

if ordered is False:
# dtype='category' defaults to ordered=False, so only test once
result = index.astype('category', copy=copy)
expected = CategoricalIndex(index.values, name=name)
tm.assert_index_equal(result, expected)
34 changes: 31 additions & 3 deletions pandas/tests/indexes/test_category.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,6 @@ def test_delete(self):
def test_astype(self):

ci = self.create_index()
result = ci.astype('category')
tm.assert_index_equal(result, ci, exact=True)

result = ci.astype(object)
tm.assert_index_equal(result, Index(np.array(ci)))

Expand All @@ -414,6 +411,37 @@ def test_astype(self):
result = IntervalIndex.from_intervals(result.values)
tm.assert_index_equal(result, expected)

@pytest.mark.parametrize('copy', [True, False])
@pytest.mark.parametrize('name', [None, 'foo'])
@pytest.mark.parametrize('dtype_ordered', [True, False])
@pytest.mark.parametrize('index_ordered', [True, False])
def test_astype_category(self, copy, name, dtype_ordered, index_ordered):
# GH 18630
index = self.create_index(ordered=index_ordered)
if name:
index = index.rename(name)

# standard categories
dtype = CategoricalDtype(ordered=dtype_ordered)
result = index.astype(dtype, copy=copy)
expected = CategoricalIndex(index.tolist(),
name=name,
categories=index.categories,
ordered=dtype_ordered)
tm.assert_index_equal(result, expected)

# non-standard categories
dtype = CategoricalDtype(index.unique().tolist()[:-1], dtype_ordered)
result = index.astype(dtype, copy=copy)
expected = CategoricalIndex(index.tolist(), name=name, dtype=dtype)
tm.assert_index_equal(result, expected)

if dtype_ordered is False:
# dtype='category' can't specify ordered, so only test once
result = index.astype('category', copy=copy)
expected = index
tm.assert_index_equal(result, expected)

def test_reindex_base(self):
# Determined by cat ordering.
idx = CategoricalIndex(list("cab"), categories=list("cab"))
Expand Down
Loading