Skip to content

Commit ccbaa04

Browse files
committed
ENH: Accept CategoricalDtype in CSV reader
1 parent 507467f commit ccbaa04

File tree

4 files changed

+84
-6
lines changed

4 files changed

+84
-6
lines changed

doc/source/io.rst

+14-1
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,8 @@ Specifying Categorical dtype
452452

453453
.. versionadded:: 0.19.0
454454

455-
``Categorical`` columns can be parsed directly by specifying ``dtype='category'``
455+
``Categorical`` columns can be parsed directly by specifying ``dtype='category'`` or
456+
``dtype=CategoricalDtype(categories, ordered)``.
456457

457458
.. ipython:: python
458459
@@ -468,6 +469,18 @@ Individual columns can be parsed as a ``Categorical`` using a dict specification
468469
469470
pd.read_csv(StringIO(data), dtype={'col1': 'category'}).dtypes
470471
472+
Specifying ``dtype='cateogry'`` will result in a ``Categorical`` that is
473+
unordered, and whose ``categories`` are the unique values observed in the data.
474+
For more control on the categories and order, create a
475+
:class:`~pandas.api.types.CategoricalDtype` ahead of time.
476+
477+
.. ipython:: python
478+
479+
from pandas.api.types import CategoricalDtype
480+
481+
dtype = CategoricalDtype(['d', 'c', 'b', 'a'], ordered=True)
482+
pd.read_csv(StringIO(data), dtype={'col1': dtype}).dtypes
483+
471484
.. note::
472485

473486
The resulting categories will always be parsed as strings (object dtype).

pandas/_libs/parsers.pyx

+19-4
Original file line numberDiff line numberDiff line change
@@ -1267,6 +1267,8 @@ cdef class TextReader:
12671267
return self._string_convert(i, start, end, na_filter,
12681268
na_hashset)
12691269
elif is_categorical_dtype(dtype):
1270+
# TODO: I suspect that this could be optimized when dtype
1271+
# is an instance of CategoricalDtype
12701272
codes, cats, na_count = _categorical_convert(
12711273
self.parser, i, start, end, na_filter,
12721274
na_hashset, self.c_encoding)
@@ -1278,8 +1280,18 @@ cdef class TextReader:
12781280
indexer = cats.get_indexer(unsorted)
12791281
codes = take_1d(indexer, codes, fill_value=-1)
12801282

1281-
return Categorical(codes, categories=cats, ordered=False,
1282-
fastpath=True), na_count
1283+
cat = Categorical(codes, categories=cats, ordered=False,
1284+
fastpath=True)
1285+
1286+
if isinstance(dtype, CategoricalDtype):
1287+
if dtype.categories is None:
1288+
# skip recoding
1289+
if dtype.ordered:
1290+
cat = cat.set_ordered(ordered=dtype.ordered)
1291+
else:
1292+
cat = cat.set_categories(dtype.categories,
1293+
ordered=dtype.ordered)
1294+
return cat, na_count
12831295
elif is_object_dtype(dtype):
12841296
return self._string_convert(i, start, end, na_filter,
12851297
na_hashset)
@@ -2230,8 +2242,11 @@ def _concatenate_chunks(list chunks):
22302242
if common_type == np.object:
22312243
warning_columns.append(str(name))
22322244

2233-
if is_categorical_dtype(dtypes.pop()):
2234-
result[name] = union_categoricals(arrs, sort_categories=True)
2245+
dtype = dtypes.pop()
2246+
if is_categorical_dtype(dtype):
2247+
sort_categories = isinstance(dtype, str)
2248+
result[name] = union_categoricals(arrs,
2249+
sort_categories=sort_categories)
22352250
else:
22362251
result[name] = np.concatenate(arrs)
22372252

pandas/io/parsers.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
is_float, is_dtype_equal,
2222
is_object_dtype, is_string_dtype,
2323
is_scalar, is_categorical_dtype)
24+
from pandas.core.dtypes.dtypes import CategoricalDtype
2425
from pandas.core.dtypes.missing import isna
2526
from pandas.core.dtypes.cast import astype_nansafe
2627
from pandas.core.index import (Index, MultiIndex, RangeIndex,
@@ -1578,7 +1579,11 @@ def _cast_types(self, values, cast_type, column):
15781579
# as strings
15791580
if not is_object_dtype(values):
15801581
values = astype_nansafe(values, str)
1581-
values = Categorical(values)
1582+
if isinstance(cast_type, CategoricalDtype):
1583+
values = Categorical(values, categories=cast_type.categories,
1584+
ordered=cast_type.ordered)
1585+
else:
1586+
values = Categorical(values)
15821587
else:
15831588
try:
15841589
values = astype_nansafe(values, cast_type, copy=True)

pandas/tests/io/parser/dtypes.py

+45
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,51 @@ def test_categorical_dtype_chunksize(self):
149149
for actual, expected in zip(actuals, expecteds):
150150
tm.assert_frame_equal(actual, expected)
151151

152+
@pytest.mark.parametrize('ordered', [False, True])
153+
@pytest.mark.parametrize('categories', [
154+
['a', 'b', 'c'],
155+
['a', 'c', 'b'],
156+
['a', 'b', 'c', 'd'],
157+
])
158+
def test_categorical_categoricaldtype(self, categories, ordered):
159+
data = """a,b
160+
1,a
161+
1,b
162+
1,b
163+
2,c"""
164+
expected = pd.DataFrame({
165+
"a": [1, 1, 1, 2],
166+
"b": Categorical(['a', 'b', 'b', 'c'],
167+
categories=categories,
168+
ordered=ordered)
169+
})
170+
dtype = {"b": CategoricalDtype(categories=categories,
171+
ordered=ordered)}
172+
result = self.read_csv(StringIO(data), dtype=dtype)
173+
tm.assert_frame_equal(result, expected)
174+
175+
def test_categorical_categoricaldtype_chunksize(self):
176+
# GH 10153
177+
data = """a,b
178+
1,a
179+
1,b
180+
1,b
181+
2,c"""
182+
cats = ['a', 'b', 'c']
183+
expecteds = [pd.DataFrame({'a': [1, 1],
184+
'b': Categorical(['a', 'b'],
185+
categories=cats)}),
186+
pd.DataFrame({'a': [1, 2],
187+
'b': Categorical(['b', 'c'],
188+
categories=cats)},
189+
index=[2, 3])]
190+
dtype = CategoricalDtype(cats)
191+
actuals = self.read_csv(StringIO(data), dtype={'b': dtype},
192+
chunksize=2)
193+
194+
for actual, expected in zip(actuals, expecteds):
195+
tm.assert_frame_equal(actual, expected)
196+
152197
def test_empty_pass_dtype(self):
153198
data = 'one,two'
154199
result = self.read_csv(StringIO(data), dtype={'one': 'u1'})

0 commit comments

Comments
 (0)