Skip to content

Commit 5b53438

Browse files
committed
WIP:Make python engine support EA types when reading CSVs
The C engine is the real WIP.
1 parent 06d9c01 commit 5b53438

File tree

9 files changed

+104
-11
lines changed

9 files changed

+104
-11
lines changed

pandas/_libs/parsers.pyx

+10-2
Original file line numberDiff line numberDiff line change
@@ -1231,7 +1231,11 @@ cdef class TextReader:
12311231

12321232
if result is not None and dtype != 'int64':
12331233
if is_extension_array_dtype(dtype):
1234-
result = result.astype(dtype.numpy_dtype)
1234+
try:
1235+
result = dtype.construct_array_type()._from_sequence(
1236+
result, dtype=dtype)
1237+
except Exception as e:
1238+
raise
12351239
else:
12361240
result = result.astype(dtype)
12371241

@@ -1243,7 +1247,11 @@ cdef class TextReader:
12431247

12441248
if result is not None and dtype != 'float64':
12451249
if is_extension_array_dtype(dtype):
1246-
result = result.astype(dtype.numpy_dtype)
1250+
try:
1251+
result = dtype.construct_array_type()._from_sequence(
1252+
result)
1253+
except Exception as e:
1254+
raise
12471255
else:
12481256
result = result.astype(dtype)
12491257
return result, na_count

pandas/core/arrays/base.py

+21
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,27 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
119119
"""
120120
raise AbstractMethodError(cls)
121121

122+
@classmethod
123+
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
124+
"""Construct a new ExtensionArray from a sequence of scalars.
125+
126+
Parameters
127+
----------
128+
strings : Sequence
129+
Each element will be an instance of the scalar type for this
130+
array, ``cls.dtype.type``.
131+
dtype : dtype, optional
132+
Construct for this particular dtype. This should be a Dtype
133+
compatible with the ExtensionArray.
134+
copy : boolean, default False
135+
If True, copy the underlying data.
136+
137+
Returns
138+
-------
139+
ExtensionArray
140+
"""
141+
raise AbstractMethodError(cls)
142+
122143
@classmethod
123144
def _from_factorized(cls, values, original):
124145
"""Reconstruct an ExtensionArray after factorization.

pandas/core/arrays/integer.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def coerce_to_array(values, dtype, mask=None, copy=False):
152152
dtype = dtype.lower()
153153
if not issubclass(type(dtype), _IntegerDtype):
154154
try:
155-
dtype = _dtypes[str(np.dtype(dtype))]
155+
dtype = _dtypes[str(np.dtype(dtype.name.lower()))]
156156
except KeyError:
157157
raise ValueError("invalid dtype specified {}".format(dtype))
158158

@@ -259,6 +259,10 @@ def __init__(self, values, mask, copy=False):
259259
def _from_sequence(cls, scalars, dtype=None, copy=False):
260260
return integer_array(scalars, dtype=dtype, copy=copy)
261261

262+
@classmethod
263+
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
264+
return cls._from_sequence([int(x) for x in strings], dtype, copy)
265+
262266
@classmethod
263267
def _from_factorized(cls, values, original):
264268
return integer_array(values, dtype=original.dtype)

pandas/core/dtypes/cast.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -615,8 +615,22 @@ def astype_nansafe(arr, dtype, copy=True, skipna=False):
615615

616616
# dispatch on extension dtype if needed
617617
if is_extension_array_dtype(dtype):
618-
return dtype.construct_array_type()._from_sequence(
619-
arr, dtype=dtype, copy=copy)
618+
if is_object_dtype(arr):
619+
try:
620+
return dtype.construct_array_type()._from_sequence_of_strings(
621+
arr, dtype=dtype, copy=copy)
622+
except AttributeError:
623+
dtype = pandas_dtype(dtype)
624+
return dtype.construct_array_type()._from_sequence_of_strings(
625+
arr, dtype=dtype, copy=copy)
626+
else:
627+
try:
628+
return dtype.construct_array_type()._from_sequence(
629+
arr, dtype=dtype, copy=copy)
630+
except AttributeError:
631+
dtype = pandas_dtype(dtype)
632+
return dtype.construct_array_type()._from_sequence(
633+
arr, dtype=dtype, copy=copy)
620634

621635
if not isinstance(dtype, np.dtype):
622636
dtype = pandas_dtype(dtype)

pandas/core/dtypes/common.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -1827,7 +1827,10 @@ def _get_dtype(arr_or_dtype):
18271827
if isinstance(arr_or_dtype, np.dtype):
18281828
return arr_or_dtype
18291829
elif isinstance(arr_or_dtype, type):
1830-
return np.dtype(arr_or_dtype)
1830+
try:
1831+
return pandas_dtype(arr_or_dtype)
1832+
except TypeError:
1833+
return np.dtype(arr_or_dtype)
18311834
elif isinstance(arr_or_dtype, ExtensionDtype):
18321835
return arr_or_dtype
18331836
elif isinstance(arr_or_dtype, DatetimeTZDtype):
@@ -1845,6 +1848,11 @@ def _get_dtype(arr_or_dtype):
18451848
return PeriodDtype.construct_from_string(arr_or_dtype)
18461849
elif is_interval_dtype(arr_or_dtype):
18471850
return IntervalDtype.construct_from_string(arr_or_dtype)
1851+
else:
1852+
try:
1853+
return pandas_dtype(arr_or_dtype)
1854+
except TypeError:
1855+
pass
18481856
elif isinstance(arr_or_dtype, (ABCCategorical, ABCCategoricalIndex,
18491857
ABCSparseArray, ABCSparseSeries)):
18501858
return arr_or_dtype.dtype
@@ -1875,7 +1883,15 @@ def _get_dtype_type(arr_or_dtype):
18751883
if isinstance(arr_or_dtype, np.dtype):
18761884
return arr_or_dtype.type
18771885
elif isinstance(arr_or_dtype, type):
1878-
return np.dtype(arr_or_dtype).type
1886+
try:
1887+
dtype = pandas_dtype(arr_or_dtype)
1888+
try:
1889+
return dtype.type
1890+
except AttributeError:
1891+
raise TypeError
1892+
except TypeError:
1893+
return np.dtype(arr_or_dtype).type
1894+
18791895
elif isinstance(arr_or_dtype, CategoricalDtype):
18801896
return CategoricalDtypeType
18811897
elif isinstance(arr_or_dtype, DatetimeTZDtype):

pandas/core/series.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4317,7 +4317,10 @@ def _try_cast(arr, take_fast_path):
43174317
# that we can convert the data to the requested dtype.
43184318
if is_integer_dtype(dtype):
43194319
subarr = maybe_cast_to_integer_array(arr, dtype)
4320-
4320+
if is_extension_array_dtype(dtype):
4321+
# create an extension array from its dtype
4322+
array_type = dtype.construct_array_type()._from_sequence
4323+
return array_type(arr, dtype=dtype, copy=copy)
43214324
subarr = maybe_cast_to_datetime(arr, dtype)
43224325
# Take care in creating object arrays (but iterators are not
43234326
# supported):

pandas/io/parsers.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
from pandas.core.dtypes.common import (
3030
ensure_object, is_categorical_dtype, is_dtype_equal, is_float, is_integer,
3131
is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
32-
is_string_dtype)
32+
is_string_dtype, is_extension_array_dtype,
33+
)
3334
from pandas.core.dtypes.dtypes import CategoricalDtype
3435
from pandas.core.dtypes.missing import isna
3536

@@ -1660,15 +1661,17 @@ def _convert_to_ndarrays(self, dct, na_values, na_fvalues, verbose=False,
16601661
try_num_bool=False)
16611662
else:
16621663
# skip inference if specified dtype is object
1663-
try_num_bool = not (cast_type and is_string_dtype(cast_type))
1664+
try_num_bool = not (cast_type and (is_string_dtype(cast_type)
1665+
or is_extension_array_dtype(cast_type)))
16641666

16651667
# general type inference and conversion
16661668
cvals, na_count = self._infer_types(
16671669
values, set(col_na_values) | col_na_fvalues,
16681670
try_num_bool)
16691671

16701672
# type specified in dtype param
1671-
if cast_type and not is_dtype_equal(cvals, cast_type):
1673+
if cast_type and (not is_dtype_equal(cvals, cast_type)
1674+
or is_extension_array_dtype(cast_type)):
16721675
cvals = self._cast_types(cvals, cast_type, c)
16731676

16741677
result[c] = cvals

pandas/tests/extension/base/io.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pandas as pd
2+
from pandas.compat import StringIO
3+
from pandas.core.arrays.integer import Int64Dtype
4+
from .base import BaseExtensionTests
5+
6+
7+
class ExtensionParsingTests(BaseExtensionTests):
8+
def test_EA_types(self):
9+
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int64'),
10+
'A': [1, 2, 1]})
11+
data = df.to_csv(index=False)
12+
result = pd.read_csv(StringIO(data), dtype={'Int': Int64Dtype})
13+
assert result is not None
14+
15+
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int8'),
16+
'A': [1, 2, 1]})
17+
data = df.to_csv(index=False)
18+
result = pd.read_csv(StringIO(data), dtype={'Int': 'Int8'})
19+
assert result is not None

pandas/tests/extension/decimal/array.py

+5
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ def dtype(self):
7373
def _from_sequence(cls, scalars, dtype=None, copy=False):
7474
return cls(scalars)
7575

76+
@classmethod
77+
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
78+
return cls._from_sequence([decimal.Decimal(x) for x in strings],
79+
dtype, copy)
80+
7681
@classmethod
7782
def _from_factorized(cls, values, original):
7883
return cls(values)

0 commit comments

Comments
 (0)