Skip to content

Commit 5a7c498

Browse files
committed
ENH: Accept CategoricalDtype in astype and Series
1 parent 2711623 commit 5a7c498

File tree

4 files changed

+69
-2
lines changed

4 files changed

+69
-2
lines changed

pandas/core/internals.py

+10
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,16 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
487487
# may need to convert to categorical
488488
# this is only called for non-categoricals
489489
if self.is_categorical_astype(dtype):
490+
if (('categories' in kwargs or 'ordered' in kwargs) and
491+
isinstance(dtype, CategoricalDtype)):
492+
raise TypeError("Cannot specify a CategoricalDtype and also "
493+
"`categories` or `ordered`")
494+
kwargs = kwargs.copy()
495+
categories = getattr(dtype, 'categories', None)
496+
ordered = getattr(dtype, 'ordered', False)
497+
498+
kwargs.setdefault('categories', categories)
499+
kwargs.setdefault('ordered', ordered)
490500
return self.make_block(Categorical(self.values, **kwargs))
491501

492502
# astype processing

pandas/core/series.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2908,7 +2908,8 @@ def _try_cast(arr, take_fast_path):
29082908
subarr = np.array(subarr, dtype=dtype, copy=copy)
29092909
except (ValueError, TypeError):
29102910
if is_categorical_dtype(dtype):
2911-
subarr = Categorical(arr)
2911+
subarr = Categorical(arr, dtype.categories,
2912+
ordered=dtype.ordered)
29122913
elif dtype is not None and raise_cast_failure:
29132914
raise
29142915
else:

pandas/tests/series/test_constructors.py

+25
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from datetime import datetime, timedelta
55

6+
import pytest
67
from numpy import nan
78
import numpy as np
89
import numpy.ma as ma
@@ -148,6 +149,30 @@ def test_constructor_categorical(self):
148149
self.assertTrue(is_categorical_dtype(s))
149150
self.assertTrue(is_categorical_dtype(s.dtype))
150151

152+
def test_constructor_categorical_dtype(self):
153+
result = pd.Series(['a', 'b'],
154+
dtype=pd.CategoricalDtype(['a', 'b', 'c'],
155+
ordered=True))
156+
assert is_categorical_dtype(result) is True
157+
tm.assert_index_equal(result.cat.categories, pd.Index(['a', 'b', 'c']))
158+
assert result.cat.ordered
159+
160+
result = pd.Series(['a', 'b'], dtype=pd.CategoricalDtype(['b', 'a']))
161+
assert is_categorical_dtype(result)
162+
tm.assert_index_equal(result.cat.categories, pd.Index(['b', 'a']))
163+
assert result.cat.ordered is False
164+
165+
@pytest.mark.xfail
166+
def test_unordered_compare_equal(self):
167+
result = pd.Series(['a', 'b', 'c'],
168+
dtype=pd.CategoricalDtype(['a', 'b']))
169+
# TODO: is this a bug? Shouldn't unorderd categories not care about
170+
# order in the comparison?
171+
# https://github.com/pandas-dev/pandas/issues/16014
172+
expected = pd.Series(pd.Categorical(['a', 'b', np.nan],
173+
categories=['a', 'b']))
174+
tm.assert_series_equal(result, expected)
175+
151176
def test_constructor_maskedarray(self):
152177
data = ma.masked_all((3, ), dtype=float)
153178
result = Series(data)

pandas/tests/series/test_dtypes.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from numpy import nan
1313
import numpy as np
1414

15-
from pandas import Series, Timestamp, Timedelta, DataFrame, date_range
15+
from pandas import (
16+
Series, Timestamp, Timedelta, DataFrame, date_range,
17+
Categorical, CategoricalDtype, Index
18+
)
1619

1720
from pandas.compat import lrange, range, u
1821
from pandas import compat
@@ -171,6 +174,34 @@ def test_astype_dict(self):
171174
with pytest.raises(KeyError):
172175
s.astype({0: str})
173176

177+
def test_astype_categoricaldtype(self):
178+
s = Series(['a', 'b', 'a'])
179+
result = s.astype(CategoricalDtype(['a', 'b'], ordered=True))
180+
expected = Series(Categorical(['a', 'b', 'a'], ordered=True))
181+
tm.assert_series_equal(result, expected)
182+
183+
result = s.astype(CategoricalDtype(['a', 'b'], ordered=False))
184+
expected = Series(Categorical(['a', 'b', 'a'], ordered=False))
185+
tm.assert_series_equal(result, expected)
186+
187+
result = s.astype(CategoricalDtype(['a', 'b', 'c'], ordered=False))
188+
expected = Series(Categorical(['a', 'b', 'a'],
189+
categories=['a', 'b', 'c'],
190+
ordered=False))
191+
tm.assert_series_equal(result, expected)
192+
tm.assert_index_equal(result.cat.categories, Index(['a', 'b', 'c']))
193+
194+
def test_astype_categoricaldtype_with_args(self):
195+
s = Series(['a', 'b'])
196+
type_ = CategoricalDtype(['a', 'b'])
197+
198+
with pytest.raises(TypeError):
199+
s.astype(type_, ordered=True)
200+
with pytest.raises(TypeError):
201+
s.astype(type_, categories=['a', 'b'])
202+
with pytest.raises(TypeError):
203+
s.astype(type_, categories=['a', 'b'], ordered=False)
204+
174205
def test_astype_generic_timestamp_deprecated(self):
175206
# see gh-15524
176207
data = [1]

0 commit comments

Comments
 (0)