Skip to content

Commit 0151eb0

Browse files
committed
ENH: Add EA types to read CSV
Closes GH23228
1 parent 145c227 commit 0151eb0

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

pandas/_libs/parsers.pyx

+8-2
Original file line numberDiff line numberDiff line change
@@ -1195,7 +1195,10 @@ cdef class TextReader:
11951195
na_count = 0
11961196

11971197
if result is not None and dtype != 'int64':
1198-
result = result.astype(dtype)
1198+
try:
1199+
result = result.astype(dtype)
1200+
except TypeError:
1201+
result = result.astype(dtype.numpy_dtype)
11991202

12001203
return result, na_count
12011204

@@ -1204,7 +1207,10 @@ cdef class TextReader:
12041207
na_filter, na_hashset, na_flist)
12051208

12061209
if result is not None and dtype != 'float64':
1207-
result = result.astype(dtype)
1210+
try:
1211+
result = result.astype(dtype)
1212+
except TypeError:
1213+
result = result.astype(dtype.numpy_dtype)
12081214
return result, na_count
12091215

12101216
elif is_bool_dtype(dtype):

pandas/core/dtypes/common.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1935,7 +1935,10 @@ def _get_dtype_type(arr_or_dtype):
19351935
try:
19361936
return arr_or_dtype.dtype.type
19371937
except AttributeError:
1938-
return type(None)
1938+
try:
1939+
return arr_or_dtype.numpy_dtype.type
1940+
except AttributeError:
1941+
return type(None)
19391942

19401943

19411944
def _get_dtype_from_object(dtype):

pandas/tests/io/parser/common.py

+14
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from pandas.errors import DtypeWarning, EmptyDataError, ParserError
2525
from pandas.io.common import URLError
2626
from pandas.io.parsers import TextFileReader, TextParser
27+
from pandas.core.arrays.integer import Int64Dtype
2728

2829

2930
class ParserTests(object):
@@ -1604,3 +1605,16 @@ def test_buffer_rd_bytes_bad_unicode(self):
16041605
t = TextIOWrapper(t, encoding='ascii', errors='surrogateescape')
16051606
with pytest.raises(UnicodeError):
16061607
pd.read_csv(t, encoding='UTF-8')
1608+
1609+
def test_EA_types(self):
1610+
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int64'),
1611+
'A': [1, 2, 1]})
1612+
data = df.to_csv(index=False)
1613+
result = pd.read_csv(StringIO(data), dtype={'Int': Int64Dtype})
1614+
assert result is not None
1615+
1616+
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int8'),
1617+
'A': [1, 2, 1]})
1618+
data = df.to_csv(index=False)
1619+
result = pd.read_csv(StringIO(data), dtype={'Int': 'Int8'})
1620+
assert result is not None

0 commit comments

Comments
 (0)