Skip to content

Commit f67aa13

Browse files
kprestel, jreback
authored and committed
ENH:Add EA types to read CSV (#23255)
1 parent 48dad14 commit f67aa13

File tree

16 files changed

+158
-17
lines changed

16 files changed

+158
-17
lines changed

doc/source/io.rst

+6-5
Original file line numberDiff line numberDiff line change
@@ -362,16 +362,17 @@ columns:
362362

363363
.. ipython:: python
364364
365-
data = ('a,b,c\n'
366-
'1,2,3\n'
367-
'4,5,6\n'
368-
'7,8,9')
365+
data = ('a,b,c,d\n'
366+
'1,2,3,4\n'
367+
'5,6,7,8\n'
368+
'9,10,11')
369369
print(data)
370370
371371
df = pd.read_csv(StringIO(data), dtype=object)
372372
df
373373
df['a'][0]
374-
df = pd.read_csv(StringIO(data), dtype={'b': object, 'c': np.float64})
374+
df = pd.read_csv(StringIO(data),
375+
dtype={'b': object, 'c': np.float64, 'd': 'Int64'})
375376
df.dtypes
376377
377378
Fortunately, pandas offers more than one way to ensure that your column(s)

doc/source/whatsnew/v0.24.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ New features
3131
- :func:`read_feather` now accepts ``columns`` as an argument, allowing the user to specify which columns should be read. (:issue:`24025`)
3232
- :func:`DataFrame.to_html` now accepts ``render_links`` as an argument, allowing the user to generate HTML with links to any URLs that appear in the DataFrame.
3333
See the :ref:`section on writing HTML <io.html>` in the IO docs for example usage. (:issue:`2679`)
34+
- :func:`pandas.read_csv` now supports pandas extension types as an argument to ``dtype``, allowing the user to use pandas extension types when reading CSVs. (:issue:`23228`)
3435
- :meth:`DataFrame.shift`, :meth:`Series.shift`, :meth:`ExtensionArray.shift`, :meth:`SparseArray.shift`, :meth:`Period.shift`, :meth:`GroupBy.shift`, :meth:`Categorical.shift`, :meth:`NDFrame.shift` and :meth:`Block.shift` now accept ``fill_value`` as an argument, allowing the user to specify a value which will be used instead of NA/NaT in the empty periods. (:issue:`15486`)
3536

3637
.. _whatsnew_0240.values_api:

pandas/_libs/parsers.pyx

+25-5
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ from pandas.core.dtypes.common import (
5050
is_integer_dtype, is_float_dtype,
5151
is_bool_dtype, is_object_dtype,
5252
is_datetime64_dtype,
53-
pandas_dtype)
53+
pandas_dtype, is_extension_array_dtype)
5454
from pandas.core.arrays import Categorical
5555
from pandas.core.dtypes.concat import union_categoricals
5656
import pandas.io.common as icom
@@ -983,7 +983,6 @@ cdef class TextReader:
983983
footer=footer,
984984
upcast_na=True)
985985
self._end_clock('Type conversion')
986-
987986
self._start_clock()
988987
if len(columns) > 0:
989988
rows_read = len(list(columns.values())[0])
@@ -1123,7 +1122,9 @@ cdef class TextReader:
11231122
if na_filter:
11241123
self._free_na_set(na_hashset)
11251124

1126-
if upcast_na and na_count > 0:
1125+
# don't try to upcast EAs
1126+
try_upcast = upcast_na and na_count > 0
1127+
if try_upcast and not is_extension_array_dtype(col_dtype):
11271128
col_res = _maybe_upcast(col_res)
11281129

11291130
if col_res is None:
@@ -1215,6 +1216,22 @@ cdef class TextReader:
12151216
cats, codes, dtype, true_values=true_values)
12161217
return cat, na_count
12171218

1219+
elif is_extension_array_dtype(dtype):
1220+
result, na_count = self._string_convert(i, start, end, na_filter,
1221+
na_hashset)
1222+
array_type = dtype.construct_array_type()
1223+
try:
1224+
# use _from_sequence_of_strings if the class defines it
1225+
result = array_type._from_sequence_of_strings(result,
1226+
dtype=dtype)
1227+
except NotImplementedError:
1228+
raise NotImplementedError(
1229+
"Extension Array: {ea} must implement "
1230+
"_from_sequence_of_strings in order "
1231+
"to be used in parser methods".format(ea=array_type))
1232+
1233+
return result, na_count
1234+
12181235
elif is_integer_dtype(dtype):
12191236
try:
12201237
result, na_count = _try_int64(self.parser, i, start,
@@ -1240,7 +1257,6 @@ cdef class TextReader:
12401257
if result is not None and dtype != 'float64':
12411258
result = result.astype(dtype)
12421259
return result, na_count
1243-
12441260
elif is_bool_dtype(dtype):
12451261
result, na_count = _try_bool_flex(self.parser, i, start, end,
12461262
na_filter, na_hashset,
@@ -2173,7 +2189,11 @@ def _concatenate_chunks(list chunks):
21732189
result[name] = union_categoricals(arrs,
21742190
sort_categories=sort_categories)
21752191
else:
2176-
result[name] = np.concatenate(arrs)
2192+
if is_extension_array_dtype(dtype):
2193+
array_type = dtype.construct_array_type()
2194+
result[name] = array_type._concat_same_type(arrs)
2195+
else:
2196+
result[name] = np.concatenate(arrs)
21772197

21782198
if warning_columns:
21792199
warning_names = ','.join(warning_columns)

pandas/core/arrays/base.py

+29
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ class ExtensionArray(object):
7878
7979
* _reduce
8080
81+
One can implement methods to handle parsing from strings; these are used
82+
by functions such as ``pandas.io.parsers.read_csv``.
83+
84+
* _from_sequence_of_strings
85+
8186
This class does not inherit from 'abc.ABCMeta' for performance reasons.
8287
Methods and properties required by the interface raise
8388
``pandas.errors.AbstractMethodError`` and no ``register`` method is
@@ -128,6 +133,30 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
128133
"""
129134
raise AbstractMethodError(cls)
130135

136+
@classmethod
137+
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
138+
"""Construct a new ExtensionArray from a sequence of strings.
139+
140+
.. versionadded:: 0.24.0
141+
142+
Parameters
143+
----------
144+
strings : Sequence
145+
Each element will be a string that is parsed into the scalar type
146+
for this array, ``cls.dtype.type``.
147+
dtype : dtype, optional
148+
Construct for this particular dtype. This should be a Dtype
149+
compatible with the ExtensionArray.
150+
copy : boolean, default False
151+
If True, copy the underlying data.
152+
153+
Returns
154+
-------
155+
ExtensionArray
156+
157+
"""
158+
raise AbstractMethodError(cls)
159+
131160
@classmethod
132161
def _from_factorized(cls, values, original):
133162
"""

pandas/core/arrays/integer.py

+6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from pandas.core import nanops
2121
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
22+
from pandas.core.tools.numeric import to_numeric
2223

2324

2425
class _IntegerDtype(ExtensionDtype):
@@ -261,6 +262,11 @@ def __init__(self, values, mask, copy=False):
261262
def _from_sequence(cls, scalars, dtype=None, copy=False):
262263
return integer_array(scalars, dtype=dtype, copy=copy)
263264

265+
@classmethod
266+
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
267+
scalars = to_numeric(strings, errors="raise")
268+
return cls._from_sequence(scalars, dtype, copy)
269+
264270
@classmethod
265271
def _from_factorized(cls, values, original):
266272
return integer_array(values, dtype=original.dtype)

pandas/io/parsers.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from pandas.core.dtypes.cast import astype_nansafe
2929
from pandas.core.dtypes.common import (
3030
ensure_object, is_bool_dtype, is_categorical_dtype, is_dtype_equal,
31-
is_float, is_integer, is_integer_dtype, is_list_like, is_object_dtype,
32-
is_scalar, is_string_dtype)
31+
is_extension_array_dtype, is_float, is_integer, is_integer_dtype,
32+
is_list_like, is_object_dtype, is_scalar, is_string_dtype, pandas_dtype)
3333
from pandas.core.dtypes.dtypes import CategoricalDtype
3434
from pandas.core.dtypes.missing import isna
3535

@@ -134,7 +134,8 @@
134134
'X'...'X'. Passing in False will cause data to be overwritten if there
135135
are duplicate names in the columns.
136136
dtype : Type name or dict of column -> type, optional
137-
Data type for data or columns. E.g. {{'a': np.float64, 'b': np.int32}}
137+
Data type for data or columns. E.g. {{'a': np.float64, 'b': np.int32,
138+
'c': 'Int64'}}
138139
Use `str` or `object` together with suitable `na_values` settings
139140
to preserve and not interpret dtype.
140141
If converters are specified, they will be applied INSTEAD
@@ -1659,16 +1660,20 @@ def _convert_to_ndarrays(self, dct, na_values, na_fvalues, verbose=False,
16591660
values, set(col_na_values) | col_na_fvalues,
16601661
try_num_bool=False)
16611662
else:
1663+
is_str_or_ea_dtype = (is_string_dtype(cast_type)
1664+
or is_extension_array_dtype(cast_type))
16621665
# skip inference if specified dtype is object
1663-
try_num_bool = not (cast_type and is_string_dtype(cast_type))
1666+
# or casting to an EA
1667+
try_num_bool = not (cast_type and is_str_or_ea_dtype)
16641668

16651669
# general type inference and conversion
16661670
cvals, na_count = self._infer_types(
16671671
values, set(col_na_values) | col_na_fvalues,
16681672
try_num_bool)
16691673

1670-
# type specified in dtype param
1671-
if cast_type and not is_dtype_equal(cvals, cast_type):
1674+
# type specified in dtype param or cast_type is an EA
1675+
if cast_type and (not is_dtype_equal(cvals, cast_type)
1676+
or is_extension_array_dtype(cast_type)):
16721677
try:
16731678
if (is_bool_dtype(cast_type) and
16741679
not is_categorical_dtype(cast_type)
@@ -1765,6 +1770,20 @@ def _cast_types(self, values, cast_type, column):
17651770
cats, cats.get_indexer(values), cast_type,
17661771
true_values=self.true_values)
17671772

1773+
# use the EA's implementation of casting
1774+
elif is_extension_array_dtype(cast_type):
1775+
# ensure cast_type is an actual dtype and not a string
1776+
cast_type = pandas_dtype(cast_type)
1777+
array_type = cast_type.construct_array_type()
1778+
try:
1779+
return array_type._from_sequence_of_strings(values,
1780+
dtype=cast_type)
1781+
except NotImplementedError:
1782+
raise NotImplementedError(
1783+
"Extension Array: {ea} must implement "
1784+
"_from_sequence_of_strings in order "
1785+
"to be used in parser methods".format(ea=array_type))
1786+
17681787
else:
17691788
try:
17701789
values = astype_nansafe(values, cast_type,
@@ -2174,8 +2193,8 @@ def __init__(self, f, **kwds):
21742193

21752194
self.verbose = kwds['verbose']
21762195
self.converters = kwds['converters']
2177-
self.dtype = kwds['dtype']
21782196

2197+
self.dtype = kwds['dtype']
21792198
self.thousands = kwds['thousands']
21802199
self.decimal = kwds['decimal']
21812200

pandas/tests/extension/base/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,4 @@ class TestMyDtype(BaseDtypeTests):
5353
from .missing import BaseMissingTests # noqa
5454
from .reshaping import BaseReshapingTests # noqa
5555
from .setitem import BaseSetitemTests # noqa
56+
from .io import BaseParsingTests # noqa

pandas/tests/extension/base/io.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pandas.compat import StringIO
5+
6+
import pandas as pd
7+
8+
from .base import BaseExtensionTests
9+
10+
11+
class BaseParsingTests(BaseExtensionTests):
12+
13+
@pytest.mark.parametrize('engine', ['c', 'python'])
14+
def test_EA_types(self, engine, data):
15+
df = pd.DataFrame({
16+
'with_dtype': pd.Series(data, dtype=str(data.dtype))
17+
})
18+
csv_output = df.to_csv(index=False, na_rep=np.nan)
19+
result = pd.read_csv(StringIO(csv_output), dtype={
20+
'with_dtype': str(data.dtype)
21+
}, engine=engine)
22+
expected = df
23+
self.assert_frame_equal(result, expected)

pandas/tests/extension/decimal/array.py

+5
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ def dtype(self):
7575
def _from_sequence(cls, scalars, dtype=None, copy=False):
7676
return cls(scalars)
7777

78+
@classmethod
79+
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
80+
return cls._from_sequence([decimal.Decimal(x) for x in strings],
81+
dtype, copy)
82+
7883
@classmethod
7984
def _from_factorized(cls, values, original):
8085
return cls(values)

pandas/tests/extension/test_categorical.py

+4
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,7 @@ def _compare_other(self, s, data, op_name, other):
237237
else:
238238
with pytest.raises(TypeError):
239239
op(data, other)
240+
241+
242+
class TestParsing(base.BaseParsingTests):
243+
pass

pandas/tests/extension/test_integer.py

+4
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,7 @@ class TestBooleanReduce(base.BaseBooleanReduceTests):
218218

219219
class TestPrinting(base.BasePrintingTests):
220220
pass
221+
222+
223+
class TestParsing(base.BaseParsingTests):
224+
pass

pandas/tests/extension/test_interval.py

+8
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,11 @@ class TestPrinting(BaseInterval, base.BasePrintingTests):
152152
@pytest.mark.skip(reason="custom repr")
153153
def test_array_repr(self, data, size):
154154
pass
155+
156+
157+
class TestParsing(BaseInterval, base.BaseParsingTests):
158+
@pytest.mark.parametrize('engine', ['c', 'python'])
159+
def test_EA_types(self, engine, data):
160+
expected_msg = r'.*must implement _from_sequence_of_strings.*'
161+
with pytest.raises(NotImplementedError, match=expected_msg):
162+
super(TestParsing, self).test_EA_types(engine, data)

pandas/tests/extension/test_numpy.py

+4
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,7 @@ def test_concat_mixed_dtypes(self, data):
210210

211211
class TestSetitem(BaseNumPyTests, base.BaseSetitemTests):
212212
pass
213+
214+
215+
class TestParsing(BaseNumPyTests, base.BaseParsingTests):
216+
pass

pandas/tests/extension/test_period.py

+8
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,11 @@ class TestGroupby(BasePeriodTests, base.BaseGroupbyTests):
156156

157157
class TestPrinting(BasePeriodTests, base.BasePrintingTests):
158158
pass
159+
160+
161+
class TestParsing(BasePeriodTests, base.BaseParsingTests):
162+
@pytest.mark.parametrize('engine', ['c', 'python'])
163+
def test_EA_types(self, engine, data):
164+
expected_msg = r'.*must implement _from_sequence_of_strings.*'
165+
with pytest.raises(NotImplementedError, match=expected_msg):
166+
super(TestParsing, self).test_EA_types(engine, data)

pandas/tests/extension/test_sparse.py

+8
Original file line numberDiff line numberDiff line change
@@ -359,3 +359,11 @@ class TestPrinting(BaseSparseTests, base.BasePrintingTests):
359359
@pytest.mark.xfail(reason='Different repr', strict=True)
360360
def test_array_repr(self, data, size):
361361
super(TestPrinting, self).test_array_repr(data, size)
362+
363+
364+
class TestParsing(BaseSparseTests, base.BaseParsingTests):
365+
@pytest.mark.parametrize('engine', ['c', 'python'])
366+
def test_EA_types(self, engine, data):
367+
expected_msg = r'.*must implement _from_sequence_of_strings.*'
368+
with pytest.raises(NotImplementedError, match=expected_msg):
369+
super(TestParsing, self).test_EA_types(engine, data)

pandas/tests/io/parser/common.py

Whitespace-only changes.

0 commit comments

Comments
 (0)