Skip to content

Commit 416d1d7

Browse files
committed
Consistent CategoricalDtype use in Categorical init
Get a valid instance of `CategoricalDtype` as early as possible, and use that throughout.
1 parent ed5c814 commit 416d1d7

File tree

4 files changed

+112
-15
lines changed

4 files changed

+112
-15
lines changed

pandas/core/categorical.py

+31-15
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,21 @@ class Categorical(PandasObject):
234234
def __init__(self, values, categories=None, ordered=None, dtype=None,
235235
fastpath=False):
236236

237+
# Ways of specifying the dtype (in priority order)
238+
# 1. dtype is a CategoricalDtype
239+
# a.) with known categories, use dtype.categories
240+
# b.) else with Categorical values, use values.dtype
241+
# c.) else, infer from values
242+
# d.) specifying dtype=CategoricalDtype and categories is an error
243+
# 2. dtype is a string 'category'
244+
# a.) use categories, ordered
245+
# b.) use values.dtype
246+
# c.) infer from values
247+
# 3. dtype is None
248+
# a.) use categories, ordered
249+
# b.) use values.dtype
250+
# c.) infer from values
251+
237252
if dtype is not None:
238253
if isinstance(dtype, compat.string_types):
239254
if dtype == 'category':
@@ -247,20 +262,24 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
247262
categories = dtype.categories
248263
ordered = dtype.ordered
249264

250-
if ordered is None:
251-
ordered = False
265+
elif is_categorical(values):
266+
dtype = values.dtype._from_categorical_dtype(values.dtype,
267+
categories, ordered)
268+
else:
269+
dtype = CategoricalDtype(categories, ordered)
270+
271+
# At this point, dtype is always a CategoricalDtype
272+
# if dtype.categories is None, we are inferring
252273

253274
if fastpath:
254-
if dtype is None:
255-
dtype = CategoricalDtype(categories, ordered)
256275
self._codes = coerce_indexer_dtype(values, categories)
257276
self._dtype = dtype
258277
return
259278

260279
# sanitize input
261280
if is_categorical_dtype(values):
262281

263-
# we are either a Series, CategoricalIndex
282+
# we are either a Series or a CategoricalIndex
264283
if isinstance(values, (ABCSeries, ABCCategoricalIndex)):
265284
values = values._values
266285

@@ -271,6 +290,7 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
271290
values = values.get_values()
272291

273292
elif isinstance(values, (ABCIndexClass, ABCSeries)):
293+
# we'll do inference later
274294
pass
275295

276296
else:
@@ -288,12 +308,12 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
288308
# "object" dtype to prevent this. In the end objects will be
289309
# cast to int/... in the category assignment step.
290310
if len(values) == 0 or isna(values).any():
291-
dtype = 'object'
311+
sanitize_dtype = 'object'
292312
else:
293-
dtype = None
294-
values = _sanitize_array(values, None, dtype=dtype)
313+
sanitize_dtype = None
314+
values = _sanitize_array(values, None, dtype=sanitize_dtype)
295315

296-
if categories is None:
316+
if dtype.categories is None:
297317
try:
298318
codes, categories = factorize(values, sort=True)
299319
except TypeError:
@@ -310,7 +330,8 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
310330
raise NotImplementedError("> 1 ndim Categorical are not "
311331
"supported at this time")
312332

313-
if dtype is None or isinstance(dtype, str):
333+
if dtype.categories is None:
334+
# we're inferring from values
314335
dtype = CategoricalDtype(categories, ordered)
315336

316337
else:
@@ -321,11 +342,6 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
321342
# - the new one, where each value is also in the categories array
322343
# (or np.nan)
323344

324-
# make sure that we always have the same type here, no matter what
325-
# we get passed in
326-
if dtype is None or isinstance(dtype, str):
327-
dtype = CategoricalDtype(categories, ordered)
328-
329345
codes = _get_codes_for_values(values, dtype.categories)
330346

331347
# TODO: check for old style usage. These warnings should be removed

pandas/core/dtypes/dtypes.py

+13
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,22 @@ def _from_fastpath(cls, categories=None, ordered=False):
160160
self._finalize(categories, ordered, fastpath=True)
161161
return self
162162

163+
@classmethod
164+
def _from_categorical_dtype(cls, dtype, categories=None, ordered=None):
165+
if categories is ordered is None:
166+
return dtype
167+
if categories is None:
168+
categories = dtype.categories
169+
if ordered is None:
170+
ordered = dtype.ordered
171+
return cls(categories, ordered)
172+
163173
def _finalize(self, categories, ordered, fastpath=False):
164174
from pandas.core.indexes.base import Index
165175

176+
if ordered is None:
177+
ordered = False
178+
166179
if categories is not None:
167180
categories = Index(categories, tupleize_cols=False)
168181
# validation

pandas/tests/dtypes/test_dtypes.py

+27
Original file line numberDiff line numberDiff line change
@@ -622,3 +622,30 @@ def test_mixed(self):
622622
a = CategoricalDtype(['a', 'b', 1, 2])
623623
b = CategoricalDtype(['a', 'b', '1', '2'])
624624
assert hash(a) != hash(b)
625+
626+
def test_from_categorical_dtype_identity(self):
627+
c1 = Categorical([1, 2], categories=[1, 2, 3], ordered=True)
628+
# Identity test for no changes
629+
c2 = CategoricalDtype._from_categorical_dtype(c1)
630+
assert c2 is c1
631+
632+
def test_from_categorical_dtype_categories(self):
633+
c1 = Categorical([1, 2], categories=[1, 2, 3], ordered=True)
634+
# override categories
635+
result = CategoricalDtype._from_categorical_dtype(
636+
c1, categories=[2, 3])
637+
assert result == CategoricalDtype([2, 3], ordered=True)
638+
639+
def test_from_categorical_dtype_ordered(self):
640+
c1 = Categorical([1, 2], categories=[1, 2, 3], ordered=True)
641+
# override ordered
642+
result = CategoricalDtype._from_categorical_dtype(
643+
c1, ordered=False)
644+
assert result == CategoricalDtype([1, 2, 3], ordered=False)
645+
646+
def test_from_categorical_dtype_both(self):
647+
c1 = Categorical([1, 2], categories=[1, 2, 3], ordered=True)
648+
# override both categories and ordered
649+
result = CategoricalDtype._from_categorical_dtype(
650+
c1, categories=[1, 2], ordered=False)
651+
assert result == CategoricalDtype([1, 2], ordered=False)

pandas/tests/test_categorical.py

+41
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,37 @@ def test_constructor_str_unknown(self):
488488
with tm.assert_raises_regex(ValueError, "Unknown `dtype`"):
489489
Categorical([1, 2], dtype="foo")
490490

491+
def test_constructor_from_categorical_with_dtype(self):
492+
dtype = CategoricalDtype(['a', 'b', 'c'], ordered=True)
493+
values = Categorical(['a', 'b', 'd'])
494+
result = Categorical(values, dtype=dtype)
495+
# We use dtype.categories, not values.categories
496+
expected = Categorical(['a', 'b', 'd'], categories=['a', 'b', 'c'],
497+
ordered=True)
498+
tm.assert_categorical_equal(result, expected)
499+
500+
def test_constructor_from_categorical_with_unknown_dtype(self):
501+
dtype = CategoricalDtype(None, ordered=True)
502+
values = Categorical(['a', 'b', 'd'])
503+
result = Categorical(values, dtype=dtype)
504+
# We use values.categories, not dtype.categories
505+
expected = Categorical(['a', 'b', 'd'], categories=['a', 'b', 'd'],
506+
ordered=True)
507+
tm.assert_categorical_equal(result, expected)
508+
509+
def test_contructor_from_categorical_string(self):
510+
values = Categorical(['a', 'b', 'd'])
511+
# use categories, ordered
512+
result = Categorical(values, categories=['a', 'b', 'c'], ordered=True,
513+
dtype='category')
514+
expected = Categorical(['a', 'b', 'd'], categories=['a', 'b', 'c'],
515+
ordered=True)
516+
tm.assert_categorical_equal(result, expected)
517+
518+
# Same result without passing dtype='category'
519+
result = Categorical(values, categories=['a', 'b', 'c'], ordered=True)
520+
tm.assert_categorical_equal(result, expected)
521+
491522
def test_from_codes(self):
492523

493524
# too few categories
@@ -932,6 +963,16 @@ def test_set_dtype_nans(self):
932963
tm.assert_numpy_array_equal(result.codes, np.array([0, -1, -1],
933964
dtype='int8'))
934965

966+
def test_set_categories(self):
967+
cat = Categorical(['a', 'b', 'c'], categories=['a', 'b', 'c', 'd'])
968+
result = cat._set_categories(['a', 'b', 'c', 'd', 'e'])
969+
expected = Categorical(['a', 'b', 'c'], categories=list('abcde'))
970+
tm.assert_categorical_equal(result, expected)
971+
972+
# fastpath
973+
result = cat._set_categories(['a', 'b', 'c', 'd', 'e'], fastpath=True)
974+
tm.assert_categorical_equal(result, expected)
975+
935976
@pytest.mark.parametrize('values, categories, new_categories', [
936977
# No NaNs, same cats, same order
937978
(['a', 'b', 'a'], ['a', 'b'], ['a', 'b'],),

0 commit comments

Comments
 (0)