Skip to content

Commit 22f2fa1

Browse files
committed
WIP: Make the Python engine support extension array (EA) types when reading CSVs
The C engine is the real WIP.
1 parent bd0b9b7 commit 22f2fa1

File tree

10 files changed

+121
-15
lines changed

10 files changed

+121
-15
lines changed

pandas/_libs/parsers.pyx

+10-2
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,11 @@ cdef class TextReader:
12091209

12101210
if result is not None and dtype != 'int64':
12111211
if is_extension_array_dtype(dtype):
1212-
result = result.astype(dtype.numpy_dtype)
1212+
try:
1213+
result = dtype.construct_array_type()._from_sequence(
1214+
result, dtype=dtype)
1215+
except Exception as e:
1216+
raise
12131217
else:
12141218
result = result.astype(dtype)
12151219

@@ -1221,7 +1225,11 @@ cdef class TextReader:
12211225

12221226
if result is not None and dtype != 'float64':
12231227
if is_extension_array_dtype(dtype):
1224-
result = result.astype(dtype.numpy_dtype)
1228+
try:
1229+
result = dtype.construct_array_type()._from_sequence(
1230+
result)
1231+
except Exception as e:
1232+
raise
12251233
else:
12261234
result = result.astype(dtype)
12271235
return result, na_count

pandas/core/arrays/base.py

+21
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,27 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
117117
"""
118118
raise AbstractMethodError(cls)
119119

120+
@classmethod
121+
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
122+
"""Construct a new ExtensionArray from a sequence of strings.
123+
124+
Parameters
125+
----------
126+
strings : Sequence
127+
Each element is a string that will be parsed into the scalar type
128+
for this array, ``cls.dtype.type``.
129+
dtype : dtype, optional
130+
Construct for this particular dtype. This should be a Dtype
131+
compatible with the ExtensionArray.
132+
copy : boolean, default False
133+
If True, copy the underlying data.
134+
135+
Returns
136+
-------
137+
ExtensionArray
138+
"""
139+
raise AbstractMethodError(cls)
140+
120141
@classmethod
121142
def _from_factorized(cls, values, original):
122143
"""Reconstruct an ExtensionArray after factorization.

pandas/core/arrays/integer.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def coerce_to_array(values, dtype, mask=None, copy=False):
156156
dtype = dtype.lower()
157157
if not issubclass(type(dtype), _IntegerDtype):
158158
try:
159-
dtype = _dtypes[str(np.dtype(dtype))]
159+
dtype = _dtypes[str(np.dtype(dtype.name.lower()))]
160160
except KeyError:
161161
raise ValueError("invalid dtype specified {}".format(dtype))
162162

@@ -263,6 +263,10 @@ def __init__(self, values, mask, copy=False):
263263
def _from_sequence(cls, scalars, dtype=None, copy=False):
264264
return integer_array(scalars, dtype=dtype, copy=copy)
265265

266+
@classmethod
267+
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
268+
return cls._from_sequence([int(x) for x in strings], dtype, copy)
269+
266270
@classmethod
267271
def _from_factorized(cls, values, original):
268272
return integer_array(values, dtype=original.dtype)

pandas/core/dtypes/cast.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -661,8 +661,22 @@ def astype_nansafe(arr, dtype, copy=True, skipna=False):
661661

662662
# dispatch on extension dtype if needed
663663
if is_extension_array_dtype(dtype):
664-
return dtype.construct_array_type()._from_sequence(
665-
arr, dtype=dtype, copy=copy)
664+
if is_object_dtype(arr):
665+
try:
666+
return dtype.construct_array_type()._from_sequence_of_strings(
667+
arr, dtype=dtype, copy=copy)
668+
except AttributeError:
669+
dtype = pandas_dtype(dtype)
670+
return dtype.construct_array_type()._from_sequence_of_strings(
671+
arr, dtype=dtype, copy=copy)
672+
else:
673+
try:
674+
return dtype.construct_array_type()._from_sequence(
675+
arr, dtype=dtype, copy=copy)
676+
except AttributeError:
677+
dtype = pandas_dtype(dtype)
678+
return dtype.construct_array_type()._from_sequence(
679+
arr, dtype=dtype, copy=copy)
666680

667681
if not isinstance(dtype, np.dtype):
668682
dtype = pandas_dtype(dtype)

pandas/core/dtypes/common.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -1886,7 +1886,10 @@ def _get_dtype(arr_or_dtype):
18861886
if isinstance(arr_or_dtype, np.dtype):
18871887
return arr_or_dtype
18881888
elif isinstance(arr_or_dtype, type):
1889-
return np.dtype(arr_or_dtype)
1889+
try:
1890+
return pandas_dtype(arr_or_dtype)
1891+
except TypeError:
1892+
return np.dtype(arr_or_dtype)
18901893
elif isinstance(arr_or_dtype, ExtensionDtype):
18911894
return arr_or_dtype
18921895
elif isinstance(arr_or_dtype, DatetimeTZDtype):
@@ -1904,6 +1907,11 @@ def _get_dtype(arr_or_dtype):
19041907
return PeriodDtype.construct_from_string(arr_or_dtype)
19051908
elif is_interval_dtype(arr_or_dtype):
19061909
return IntervalDtype.construct_from_string(arr_or_dtype)
1910+
else:
1911+
try:
1912+
return pandas_dtype(arr_or_dtype)
1913+
except TypeError:
1914+
pass
19071915
elif isinstance(arr_or_dtype, (ABCCategorical, ABCCategoricalIndex,
19081916
ABCSparseArray, ABCSparseSeries)):
19091917
return arr_or_dtype.dtype
@@ -1934,7 +1942,15 @@ def _get_dtype_type(arr_or_dtype):
19341942
if isinstance(arr_or_dtype, np.dtype):
19351943
return arr_or_dtype.type
19361944
elif isinstance(arr_or_dtype, type):
1937-
return np.dtype(arr_or_dtype).type
1945+
try:
1946+
dtype = pandas_dtype(arr_or_dtype)
1947+
try:
1948+
return dtype.type
1949+
except AttributeError:
1950+
raise TypeError
1951+
except TypeError:
1952+
return np.dtype(arr_or_dtype).type
1953+
19381954
elif isinstance(arr_or_dtype, CategoricalDtype):
19391955
return CategoricalDtypeType
19401956
elif isinstance(arr_or_dtype, DatetimeTZDtype):

pandas/core/series.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4211,7 +4211,10 @@ def _try_cast(arr, take_fast_path):
42114211
# that we can convert the data to the requested dtype.
42124212
if is_integer_dtype(dtype):
42134213
subarr = maybe_cast_to_integer_array(arr, dtype)
4214-
4214+
if is_extension_array_dtype(dtype):
4215+
# create an extension array from its dtype
4216+
array_type = dtype.construct_array_type()._from_sequence
4217+
return array_type(arr, dtype=dtype, copy=copy)
42154218
subarr = maybe_cast_to_datetime(arr, dtype)
42164219
# Take care in creating object arrays (but iterators are not
42174220
# supported):

pandas/io/parsers.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from pandas.core.dtypes.common import (
2929
ensure_object, is_categorical_dtype, is_dtype_equal, is_float, is_integer,
3030
is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
31-
is_string_dtype)
31+
is_string_dtype, is_extension_array_dtype,
32+
)
3233
from pandas.core.dtypes.dtypes import CategoricalDtype
3334
from pandas.core.dtypes.missing import isna
3435

@@ -1590,15 +1591,17 @@ def _convert_to_ndarrays(self, dct, na_values, na_fvalues, verbose=False,
15901591
try_num_bool=False)
15911592
else:
15921593
# skip inference if specified dtype is object
1593-
try_num_bool = not (cast_type and is_string_dtype(cast_type))
1594+
try_num_bool = not (cast_type and (is_string_dtype(cast_type)
1595+
or is_extension_array_dtype(cast_type)))
15941596

15951597
# general type inference and conversion
15961598
cvals, na_count = self._infer_types(
15971599
values, set(col_na_values) | col_na_fvalues,
15981600
try_num_bool)
15991601

16001602
# type specified in dtype param
1601-
if cast_type and not is_dtype_equal(cvals, cast_type):
1603+
if cast_type and (not is_dtype_equal(cvals, cast_type)
1604+
or is_extension_array_dtype(cast_type)):
16021605
cvals = self._cast_types(cvals, cast_type, c)
16031606

16041607
result[c] = cvals

pandas/tests/extension/base/io.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pandas as pd
2+
from pandas.compat import StringIO
3+
from pandas.core.arrays.integer import Int64Dtype
4+
from .base import BaseExtensionTests
5+
6+
7+
class ExtensionParsingTests(BaseExtensionTests):
8+
def test_EA_types(self):
9+
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int64'),
10+
'A': [1, 2, 1]})
11+
data = df.to_csv(index=False)
12+
result = pd.read_csv(StringIO(data), dtype={'Int': Int64Dtype})
13+
assert result is not None
14+
15+
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int8'),
16+
'A': [1, 2, 1]})
17+
data = df.to_csv(index=False)
18+
result = pd.read_csv(StringIO(data), dtype={'Int': 'Int8'})
19+
assert result is not None

pandas/tests/extension/decimal/array.py

+5
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ def dtype(self):
7373
def _from_sequence(cls, scalars, dtype=None, copy=False):
7474
return cls(scalars)
7575

76+
@classmethod
77+
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
78+
return cls._from_sequence([decimal.Decimal(x) for x in strings],
79+
dtype, copy)
80+
7681
@classmethod
7782
def _from_factorized(cls, values, original):
7883
return cls(values)

pandas/tests/io/parser/common.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import platform
99
import re
1010
import sys
11+
import decimal
12+
from io import TextIOWrapper
1113

1214
import numpy as np
1315
import pytest
@@ -23,6 +25,7 @@
2325
from pandas.io.common import URLError
2426
from pandas.io.parsers import TextFileReader, TextParser
2527
from pandas.core.arrays.integer import Int64Dtype
28+
from pandas.tests.extension.decimal import DecimalDtype
2629

2730

2831
class ParserTests(object):
@@ -1630,13 +1633,23 @@ def test_buffer_rd_bytes_bad_unicode(self):
16301633

16311634
def test_EA_types(self):
16321635
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int64'),
1633-
'A': [1, 2, 1]})
1636+
'A': pd.Series([1, 2, 1], dtype=Int64Dtype)})
16341637
data = df.to_csv(index=False)
1635-
result = pd.read_csv(StringIO(data), dtype={'Int': Int64Dtype})
1636-
assert result is not None
1638+
result = pd.read_csv(StringIO(data), dtype={'Int': 'Int64',
1639+
'A': Int64Dtype})
1640+
tm.assert_frame_equal(df, result)
16371641

16381642
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int8'),
16391643
'A': [1, 2, 1]})
16401644
data = df.to_csv(index=False)
16411645
result = pd.read_csv(StringIO(data), dtype={'Int': 'Int8'})
1642-
assert result is not None
1646+
tm.assert_frame_equal(df, result)
1647+
1648+
df = pd.DataFrame({'Dec': pd.Series([decimal.Decimal('1.234'),
1649+
decimal.Decimal('2.123'),
1650+
decimal.Decimal('4.521')],
1651+
dtype=DecimalDtype),
1652+
'A': [1, 2, 1]})
1653+
data = df.to_csv(index=False)
1654+
result = pd.read_csv(StringIO(data), dtype={'Dec': DecimalDtype})
1655+
tm.assert_frame_equal(df, result)

0 commit comments

Comments
 (0)