Skip to content

Commit 2711623

Browse files
committed
ENH: Added parametrized CategoricalDtype
We extended the CategoricalDtype to accept optional `categories` and `ordered` arguments. CategoricalDtype is now part of the public API. This allows users to specify the desired categories and orderedness of an operation ahead of time. The current behavior, which is still possible with `categories=None`, the default, is to infer the categories from whatever is present. This change will make it easy to implement support for specifying categories that are known ahead of time in other places, e.g. `.astype`, `.read_csv`, and the `Series` constructor.
1 parent 3119e90 commit 2711623

File tree

6 files changed

+182
-16
lines changed

6 files changed

+182
-16
lines changed

doc/source/categorical.rst

+22-2
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,14 @@ By passing a :class:`pandas.Categorical` object to a `Series` or assigning it to
9696
df["B"] = raw_cat
9797
df
9898
99-
You can also specify differently ordered categories or make the resulting data ordered, by passing these arguments to ``astype()``:
99+
You can also specify differently ordered categories or make the resulting data
100+
ordered by passing a :class:`CategoricalDtype`:
100101

101102
.. ipython:: python
102103
103104
s = pd.Series(["a","b","c","a"])
104-
s_cat = s.astype("category", categories=["b","c","d"], ordered=False)
105+
cat_type = pd.CategoricalDtype(categories=["b", "c", "d"], ordered=False)
106+
s_cat = s.astype(cat_type)
105107
s_cat
106108
107109
Categorical data has a specific ``category`` :ref:`dtype <basics.dtypes>`:
@@ -140,6 +142,24 @@ constructor to save the factorize step during normal constructor mode:
140142
splitter = np.random.choice([0,1], 5, p=[0.5,0.5])
141143
s = pd.Series(pd.Categorical.from_codes(splitter, categories=["train", "test"]))
142144
145+
146+
CategoricalDtype
147+
----------------
148+
149+
A categorical's type is fully described by 1.) its categories (an iterable with
150+
unique values and no missing values), and 2.) its orderedness (a boolean).
151+
This information can be stored in a :class:`~pandas.CategoricalDtype`.
152+
The ``categories`` argument is optional, which implies that the actual categories
153+
should be inferred from whatever is present in the data.
154+
155+
A :class:`~pandas.CategoricalDtype` can be used in any place pandas expects a
156+
``dtype``. For example, :func:`pandas.read_csv`, :func:`pandas.DataFrame.astype`,
157+
the Series constructor, etc.
158+
159+
As a convenience, you can use the string ``'category'`` in place of a
160+
:class:`pandas.CategoricalDtype` when you want the default behavior of
161+
the categories being unordered, and equal to the set of values present in the array.
162+
143163
Description
144164
-----------
145165

pandas/core/api.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from pandas.core.algorithms import factorize, unique, value_counts
88
from pandas.core.dtypes.missing import isnull, notnull
9+
from pandas.core.dtypes.dtypes import CategoricalDtype
910
from pandas.core.categorical import Categorical
1011
from pandas.core.groupby import Grouper
1112
from pandas.formats.format import set_eng_float_format

pandas/core/dtypes/common.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,9 @@ def _coerce_to_dtype(dtype):
501501
"""
502502

503503
if is_categorical_dtype(dtype):
504-
dtype = CategoricalDtype()
504+
categories = getattr(dtype, 'categories', None)
505+
ordered = getattr(dtype, 'ordered', False)
506+
dtype = CategoricalDtype(categories=categories, ordered=ordered)
505507
elif is_datetime64tz_dtype(dtype):
506508
dtype = DatetimeTZDtype(dtype)
507509
elif is_period_dtype(dtype):

pandas/core/dtypes/dtypes.py

+90-12
Original file line numberDiff line numberDiff line change
@@ -100,26 +100,66 @@ class CategoricalDtypeType(type):
100100
class CategoricalDtype(ExtensionDtype):
101101

102102
"""
103-
A np.dtype duck-typed class, suitable for holding a custom categorical
104-
dtype.
105-
106-
THIS IS NOT A REAL NUMPY DTYPE, but essentially a sub-class of np.object
103+
Type for categorical data with the categories and orderedness,
104+
but not the values.
105+
106+
.. versionadded:: 0.20.0
107+
108+
Parameters
109+
----------
110+
categories : list or None
111+
ordered : bool, default False
112+
113+
Notes
114+
-----
115+
An instance of ``CategoricalDtype`` compares equal with any other
116+
instance of ``CategoricalDtype``, regardless of categories or ordered.
117+
In addition they compare equal to the string ``'category'``.
118+
119+
To check whether two instances of a ``CategoricalDtype`` match exactly,
120+
use the ``is`` operator.
121+
122+
>>> t1 = CategoricalDtype(['a', 'b'], ordered=True)
123+
>>> t2 = CategoricalDtype(['a', 'c'], ordered=False)
124+
>>> t1 == t2
125+
True
126+
>>> t1 == 'category'
127+
True
128+
>>> t1 is t2
129+
False
130+
>>> t1 is CategoricalDtype(['a', 'b'], ordered=True)
131+
True
132+
133+
Examples
134+
--------
135+
>>> t = CategoricalDtype(categories=['b', 'a'], ordered=True)
136+
>>> s = Series(['a', 'a', 'b', 'b', 'a'])
137+
>>> s.astype(t)
138+
0 a
139+
1 a
140+
2 b
141+
3 b
142+
4 a
143+
dtype: category
144+
Categories (2, object): [b < a]
107145
"""
146+
# TODO: Document public vs. private API
108147
name = 'category'
109148
type = CategoricalDtypeType
110149
kind = 'O'
111150
str = '|O08'
112151
base = np.dtype('O')
113152
_cache = {}
114153

115-
def __new__(cls):
116-
117-
try:
118-
return cls._cache[cls.name]
119-
except KeyError:
120-
c = object.__new__(cls)
121-
cls._cache[cls.name] = c
122-
return c
154+
def __new__(cls, categories=None, ordered=False):
155+
from pandas.indexes.base import Index
156+
if categories is not None:
157+
categories = Index(categories)
158+
cls._validate_categories(categories)
159+
hashed = cls._hash_categories(categories)
160+
else:
161+
hashed = None
162+
return cls._get_or_create(categories, ordered, hashed)
123163

124164
def __hash__(self):
125165
# make myself hashable
@@ -131,6 +171,33 @@ def __eq__(self, other):
131171

132172
return isinstance(other, CategoricalDtype)
133173

174+
@staticmethod
175+
def _hash_categories(categories):
176+
from pandas.tools.hashing import hash_array, _combine_hash_arrays
177+
cat_array = np.asarray(categories)
178+
hashed = _combine_hash_arrays(
179+
iter([hash_array(cat_array)]),
180+
num_items=1
181+
)
182+
hashed = np.bitwise_xor.reduce(hashed)
183+
return hashed
184+
185+
@classmethod
186+
def _get_or_create(cls, categories, ordered, hashed):
187+
188+
try:
189+
return cls._cache[(hashed, ordered)]
190+
except KeyError:
191+
categorical = object.__new__(cls)
192+
categorical.categories = categories
193+
categorical.ordered = ordered
194+
cls._cache[(hashed, ordered)] = categorical
195+
return categorical
196+
197+
def __unicode__(self):
198+
tpl = 'CategoricalDtype({!r}, ordered={})'
199+
return tpl.format(self.categories, self.ordered)
200+
134201
@classmethod
135202
def construct_from_string(cls, string):
136203
""" attempt to construct this type from a string, raise a TypeError if
@@ -143,6 +210,17 @@ def construct_from_string(cls, string):
143210

144211
raise TypeError("cannot construct a CategoricalDtype")
145212

213+
@staticmethod
214+
def _validate_categories(categories):
215+
from pandas import isnull
216+
217+
if not len(categories) == len(set(categories)):
218+
raise ValueError("`categories` must be unique.")
219+
if isnull(categories).any():
220+
raise ValueError("`categories` can not contain any nulls")
221+
222+
return True
223+
146224

147225
class DatetimeTZDtypeType(type):
148226
"""

pandas/indexes/category.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _format_attrs(self):
229229
('ordered', self.ordered)]
230230
if self.name is not None:
231231
attrs.append(('name', ibase.default_pprint(self.name)))
232-
attrs.append(('dtype', "'%s'" % self.dtype))
232+
attrs.append(('dtype', "'%s'" % self.dtype.name))
233233
max_seq_items = get_option('display.max_seq_items') or len(self)
234234
if len(self) > max_seq_items:
235235
attrs.append(('length', len(self)))

pandas/tests/core/dtypes/test_dtypes.py

+65
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
from itertools import product
33

4+
import pytest
45
import numpy as np
56
import pandas as pd
67
from pandas import Series, Categorical, IntervalIndex, date_range
@@ -356,6 +357,70 @@ def test_not_string(self):
356357
self.assertFalse(is_string_dtype(PeriodDtype('D')))
357358

358359

360+
class TestCategoricalDtypeParametrized(object):
361+
362+
@pytest.mark.parametrize('categories, ordered', [
363+
(['a', 'b', 'c', 'd'], False),
364+
(['a', 'b', 'c', 'd'], True),
365+
(np.arange(1000), False),
366+
(np.arange(1000), True),
367+
(['a', 'b', 10, 2, 1.3, True], False),
368+
([True, False], True),
369+
([True, False], False),
370+
(pd.date_range('2017', periods=4), True),
371+
(pd.date_range('2017', periods=4), False),
372+
])
373+
def test_basic(self, categories, ordered):
374+
c1 = CategoricalDtype(categories, ordered=ordered)
375+
tm.assert_index_equal(c1.categories, pd.Index(categories))
376+
assert c1.ordered is ordered
377+
378+
@pytest.mark.parametrize('ordered', [True, False])
379+
def test_is_singleton(self, ordered):
380+
c1 = CategoricalDtype(['a', 'b', 'c'], ordered=ordered)
381+
c2 = CategoricalDtype(['a', 'b', 'c'], ordered=ordered)
382+
assert c1 is c2
383+
384+
def test_order_matters(self):
385+
categories = ['a', 'b']
386+
c1 = CategoricalDtype(categories, ordered=False)
387+
c2 = CategoricalDtype(categories, ordered=True)
388+
assert c1 is not c2
389+
390+
def test_unordered_same(self):
391+
c1 = CategoricalDtype(['a', 'b'])
392+
c2 = CategoricalDtype(['b', 'a'])
393+
assert c1 is c2
394+
tm.assert_index_equal(c1.categories, c2.categories)
395+
396+
def test_categories(self):
397+
result = CategoricalDtype(['a', 'b', 'c'])
398+
tm.assert_index_equal(result.categories, pd.Index(['a', 'b', 'c']))
399+
assert result.ordered is False
400+
401+
def test_equal_but_different(self):
402+
c1 = CategoricalDtype([1, 2, 3])
403+
c2 = CategoricalDtype([1., 2., 3.])
404+
assert c1 is not c2
405+
406+
@pytest.mark.parametrize('v1, v2', [
407+
([1, 2, 3], [1, 2, 3]),
408+
([1, 2, 3], [3, 2, 1]),
409+
])
410+
def test_order_hashes_different(self, v1, v2):
411+
c1 = CategoricalDtype(v1)
412+
c2 = CategoricalDtype(v2, ordered=True)
413+
assert c1 is not c2
414+
415+
def test_nan_invalid(self):
416+
with pytest.raises(ValueError):
417+
pd.CategoricalDtype([1, 2, np.nan])
418+
419+
def test_non_unique_invalid(self):
420+
with pytest.raises(ValueError):
421+
pd.CategoricalDtype([1, 2, 1])
422+
423+
359424
class TestIntervalDtype(Base, tm.TestCase):
360425

361426
# TODO: placeholder

0 commit comments

Comments
 (0)