Skip to content

Commit e8db3d2

Browse files
committed
Address merge comments
1 parent 2c3d27a commit e8db3d2

File tree

7 files changed

+44
-89
lines changed

7 files changed

+44
-89
lines changed

doc/source/whatsnew/v0.24.0.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ New features
2626
- :meth:`DataFrame.corr` and :meth:`Series.corr` now accept a callable for generic calculation methods of correlation, e.g. histogram intersection (:issue:`22684`)
2727
- :func:`DataFrame.to_string` now accepts ``decimal`` as an argument, allowing the user to specify which decimal separator should be used in the output. (:issue:`23614`)
2828
- :func:`pandas.read_feather` now accepts ``columns`` as an argument, allowing the user to specify which columns should be read. (:issue:`24025`)
29-
- :func:`pandas.read_csv` now supports ``EA`` types as an argument to ``dtype``,
30-
allowing the user to use ``EA`` types when reading CSVs. (:issue:`23228`)
29+
- :func:`pandas.read_csv` now supports pandas extension types as an argument to ``dtype``,
30+
allowing the user to use pandas extension types when reading CSVs. (:issue:`23228`)
3131

3232
.. _whatsnew_0240.values_api:
3333

pandas/_libs/parsers.pyx

+16-32
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,8 @@ from pandas.core.arrays import Categorical
5656
from pandas.core.dtypes.concat import union_categoricals
5757
import pandas.io.common as icom
5858

59-
from pandas.errors import (
60-
ParserError, DtypeWarning,
61-
EmptyDataError, ParserWarning, AbstractMethodError,
62-
)
59+
from pandas.errors import ( ParserError, DtypeWarning,
60+
EmptyDataError, ParserWarning )
6361

6462
# Import CParserError as alias of ParserError for backwards compatibility.
6563
# Ultimately, we want to remove this import. See gh-12665 and gh-14479.
@@ -1217,6 +1215,18 @@ cdef class TextReader:
12171215
cats, codes, dtype, true_values=true_values)
12181216
return cat, na_count
12191217

1218+
elif is_extension_array_dtype(dtype):
1219+
result, na_count = self._string_convert(i, start, end, na_filter,
1220+
na_hashset)
1221+
try:
1222+
# use _from_sequence_of_strings if the class defines it
1223+
result = dtype.construct_array_type() \
1224+
._from_sequence_of_strings(result, dtype=dtype)
1225+
except NotImplementedError:
1226+
result = dtype.construct_array_type() \
1227+
._from_sequence(result, dtype=dtype)
1228+
return result, na_count
1229+
12201230
elif is_integer_dtype(dtype):
12211231
try:
12221232
result, na_count = _try_int64(self.parser, i, start,
@@ -1231,20 +1241,7 @@ cdef class TextReader:
12311241
na_count = 0
12321242

12331243
if result is not None and dtype != 'int64':
1234-
if is_extension_array_dtype(dtype):
1235-
try:
1236-
array_type = dtype.construct_array_type()
1237-
except AttributeError:
1238-
dtype = pandas_dtype(dtype)
1239-
array_type = dtype.construct_array_type()
1240-
try:
1241-
# use _from_sequence_of_strings if the class defines it
1242-
result = array_type._from_sequence_of_strings(result,
1243-
dtype=dtype) # noqa
1244-
except AbstractMethodError:
1245-
result = array_type._from_sequence(result, dtype=dtype)
1246-
else:
1247-
result = result.astype(dtype)
1244+
result = result.astype(dtype)
12481245

12491246
return result, na_count
12501247

@@ -1253,20 +1250,7 @@ cdef class TextReader:
12531250
na_filter, na_hashset, na_flist)
12541251

12551252
if result is not None and dtype != 'float64':
1256-
if is_extension_array_dtype(dtype):
1257-
try:
1258-
array_type = dtype.construct_array_type()
1259-
except AttributeError:
1260-
dtype = pandas_dtype(dtype)
1261-
array_type = dtype.construct_array_type()
1262-
try:
1263-
# use _from_sequence_of_strings if the class defines it
1264-
result = array_type._from_sequence_of_strings(result,
1265-
dtype=dtype) # noqa
1266-
except AbstractMethodError:
1267-
result = array_type._from_sequence(result, dtype=dtype)
1268-
else:
1269-
result = result.astype(dtype)
1253+
result = result.astype(dtype)
12701254
return result, na_count
12711255
elif is_bool_dtype(dtype):
12721256
result, na_count = _try_bool_flex(self.parser, i, start, end,

pandas/core/arrays/base.py

+3
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
127127
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
128128
"""Construct a new ExtensionArray from a sequence of scalars.
129129
130+
.. versionadded:: 0.24.0
131+
130132
Parameters
131133
----------
132134
strings : Sequence
@@ -141,6 +143,7 @@ def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
141143
Returns
142144
-------
143145
ExtensionArray
146+
144147
"""
145148
raise AbstractMethodError(cls)
146149

pandas/core/arrays/integer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from pandas.core import nanops
2121
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
22+
from pandas.core.tools.numeric import to_numeric
2223

2324

2425
class _IntegerDtype(ExtensionDtype):
@@ -157,7 +158,7 @@ def coerce_to_array(values, dtype, mask=None, copy=False):
157158
try:
158159
dtype = _dtypes[str(np.dtype(dtype.name.lower()))]
159160
except AttributeError:
160-
dtype = _dtypes[str(np.dtype(dtype.lower()))]
161+
dtype = _dtypes[str(np.dtype(dtype))]
161162
except KeyError:
162163
raise ValueError("invalid dtype specified {}".format(dtype))
163164

@@ -266,7 +267,8 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
266267

267268
@classmethod
268269
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
269-
return cls._from_sequence([int(x) for x in strings], dtype, copy)
270+
scalars = to_numeric(strings, errors='raise')
271+
return cls._from_sequence(scalars, dtype, copy)
270272

271273
@classmethod
272274
def _from_factorized(cls, values, original):

pandas/core/dtypes/cast.py

+12-21
Original file line numberDiff line numberDiff line change
@@ -616,27 +616,18 @@ def astype_nansafe(arr, dtype, copy=True, skipna=False):
616616

617617
# dispatch on extension dtype if needed
618618
if is_extension_array_dtype(dtype):
619-
if is_object_dtype(arr):
620-
try:
621-
array_type = dtype.construct_array_type()
622-
except AttributeError:
623-
dtype = pandas_dtype(dtype)
624-
array_type = dtype.construct_array_type()
625-
try:
626-
# use _from_sequence_of_strings if the class defines it
627-
return array_type._from_sequence_of_strings(arr,
628-
dtype=dtype,
629-
copy=copy)
630-
except AbstractMethodError:
631-
return array_type._from_sequence(arr, dtype=dtype, copy=copy)
632-
else:
633-
try:
634-
return dtype.construct_array_type()._from_sequence(
635-
arr, dtype=dtype, copy=copy)
636-
except AttributeError:
637-
dtype = pandas_dtype(dtype)
638-
return dtype.construct_array_type()._from_sequence(
639-
arr, dtype=dtype, copy=copy)
619+
try:
620+
array_type = dtype.construct_array_type()
621+
except AttributeError:
622+
dtype = pandas_dtype(dtype)
623+
array_type = dtype.construct_array_type()
624+
try:
625+
# use _from_sequence_of_strings if the class defines it
626+
return array_type._from_sequence_of_strings(arr,
627+
dtype=dtype,
628+
copy=copy)
629+
except NotImplementedError:
630+
return array_type._from_sequence(arr, dtype=dtype, copy=copy)
640631

641632
if not isinstance(dtype, np.dtype):
642633
dtype = pandas_dtype(dtype)

pandas/core/dtypes/common.py

+3-22
Original file line numberDiff line numberDiff line change
@@ -1795,10 +1795,7 @@ def _get_dtype(arr_or_dtype):
17951795
if isinstance(arr_or_dtype, np.dtype):
17961796
return arr_or_dtype
17971797
elif isinstance(arr_or_dtype, type):
1798-
try:
1799-
return pandas_dtype(arr_or_dtype)
1800-
except TypeError:
1801-
return np.dtype(arr_or_dtype)
1798+
return np.dtype(arr_or_dtype)
18021799
elif isinstance(arr_or_dtype, ExtensionDtype):
18031800
return arr_or_dtype
18041801
elif isinstance(arr_or_dtype, DatetimeTZDtype):
@@ -1816,11 +1813,6 @@ def _get_dtype(arr_or_dtype):
18161813
return PeriodDtype.construct_from_string(arr_or_dtype)
18171814
elif is_interval_dtype(arr_or_dtype):
18181815
return IntervalDtype.construct_from_string(arr_or_dtype)
1819-
else:
1820-
try:
1821-
return pandas_dtype(arr_or_dtype)
1822-
except TypeError:
1823-
pass
18241816
elif isinstance(arr_or_dtype, (ABCCategorical, ABCCategoricalIndex,
18251817
ABCSparseArray, ABCSparseSeries)):
18261818
return arr_or_dtype.dtype
@@ -1851,15 +1843,7 @@ def _get_dtype_type(arr_or_dtype):
18511843
if isinstance(arr_or_dtype, np.dtype):
18521844
return arr_or_dtype.type
18531845
elif isinstance(arr_or_dtype, type):
1854-
try:
1855-
dtype = pandas_dtype(arr_or_dtype)
1856-
try:
1857-
return dtype.type
1858-
except AttributeError:
1859-
raise TypeError
1860-
except TypeError:
1861-
return np.dtype(arr_or_dtype).type
1862-
1846+
return np.dtype(arr_or_dtype).type
18631847
elif isinstance(arr_or_dtype, CategoricalDtype):
18641848
return CategoricalDtypeType
18651849
elif isinstance(arr_or_dtype, DatetimeTZDtype):
@@ -1888,10 +1872,7 @@ def _get_dtype_type(arr_or_dtype):
18881872
try:
18891873
return arr_or_dtype.dtype.type
18901874
except AttributeError:
1891-
try:
1892-
return arr_or_dtype.numpy_dtype.type
1893-
except AttributeError:
1894-
return type(None)
1875+
return type(None)
18951876

18961877

18971878
def _get_dtype_from_object(dtype):

pandas/tests/extension/base/io.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,11 @@ def data(dtype):
2828
class ExtensionParsingTests(BaseExtensionTests):
2929

3030
@pytest.mark.parametrize('engine', ['c', 'python'])
31-
def test_EA_types(self, engine):
32-
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int64'),
33-
'A': [1, 2, 1]})
31+
def test_EA_types(self, engine, data):
32+
df = pd.DataFrame({'Int': pd.Series(data, dtype=str(data.dtype)),
33+
'A': data})
3434
data = df.to_csv(index=False)
35-
result = pd.read_csv(StringIO(data), dtype={'Int': Int64Dtype},
35+
result = pd.read_csv(StringIO(data), dtype={'Int': str(data.dtype)},
3636
engine=engine)
3737
assert result is not None
3838

39-
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int8'),
40-
'A': [1, 2, 1]})
41-
data = df.to_csv(index=False)
42-
result = pd.read_csv(StringIO(data), dtype={'Int': 'Int8'},
43-
engine=engine)
44-
assert result is not None

0 commit comments

Comments
 (0)