Skip to content

Commit afb801f

Browse files
committed
ENH: Add EA types to read CSV
Closes GH23228
1 parent db2066b commit afb801f

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

pandas/_libs/parsers.pyx

+8-2
Original file line numberDiff line numberDiff line change
@@ -1207,7 +1207,10 @@ cdef class TextReader:
12071207
na_count = 0
12081208

12091209
if result is not None and dtype != 'int64':
1210-
result = result.astype(dtype)
1210+
try:
1211+
result = result.astype(dtype)
1212+
except TypeError:
1213+
result = result.astype(dtype.numpy_dtype)
12111214

12121215
return result, na_count
12131216

@@ -1216,7 +1219,10 @@ cdef class TextReader:
12161219
na_filter, na_hashset, na_flist)
12171220

12181221
if result is not None and dtype != 'float64':
1219-
result = result.astype(dtype)
1222+
try:
1223+
result = result.astype(dtype)
1224+
except TypeError:
1225+
result = result.astype(dtype.numpy_dtype)
12201226
return result, na_count
12211227

12221228
elif is_bool_dtype(dtype):

pandas/core/dtypes/common.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1963,7 +1963,10 @@ def _get_dtype_type(arr_or_dtype):
19631963
try:
19641964
return arr_or_dtype.dtype.type
19651965
except AttributeError:
1966-
return type(None)
1966+
try:
1967+
return arr_or_dtype.numpy_dtype.type
1968+
except AttributeError:
1969+
return type(None)
19671970

19681971

19691972
def _get_dtype_from_object(dtype):

pandas/tests/io/parser/common.py

+22
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from pandas.io.common import URLError
2424
from pandas.io.parsers import TextFileReader, TextParser
25+
from pandas.core.arrays.integer import Int64Dtype
2526

2627

2728
class ParserTests(object):
@@ -1618,3 +1619,24 @@ def test_skip_bad_lines(self):
16181619
val = sys.stderr.getvalue()
16191620
assert 'Skipping line 3' in val
16201621
assert 'Skipping line 5' in val
1622+
1623+
def test_buffer_rd_bytes_bad_unicode(self):
1624+
# Regression test for #22748
1625+
t = BytesIO(b"\xB0")
1626+
if PY3:
1627+
t = TextIOWrapper(t, encoding='ascii', errors='surrogateescape')
1628+
with pytest.raises(UnicodeError):
1629+
pd.read_csv(t, encoding='UTF-8')
1630+
1631+
def test_EA_types(self):
1632+
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int64'),
1633+
'A': [1, 2, 1]})
1634+
data = df.to_csv(index=False)
1635+
result = pd.read_csv(StringIO(data), dtype={'Int': Int64Dtype})
1636+
assert result is not None
1637+
1638+
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int8'),
1639+
'A': [1, 2, 1]})
1640+
data = df.to_csv(index=False)
1641+
result = pd.read_csv(StringIO(data), dtype={'Int': 'Int8'})
1642+
assert result is not None

0 commit comments

Comments
 (0)