Skip to content

Commit d6b9157

Browse files
lithomas1mroeschke
authored and
im-vinicius
committed
BUG: Fix Arrow CSV Parser erroring when specifying a dtype for index … (pandas-dev#53360)
* BUG: Fix Arrow CSV Parser erroring when specifying a dtype for index cols * BUG: Fix Arrow CSV Parser erroring when specifying a dtype for index cols * Update doc/source/whatsnew/v2.1.0.rst Co-authored-by: Matthew Roeschke <[email protected]> --------- Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 7c9fa51 commit d6b9157

File tree

4 files changed

+64
-6
lines changed

4 files changed

+64
-6
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,7 @@ I/O
396396
- :meth:`DataFrame.to_sql` now raising ``ValueError`` when the name param is left empty while using SQLAlchemy to connect (:issue:`52675`)
397397
- Bug in :func:`json_normalize`, fix json_normalize cannot parse metadata fields list type (:issue:`37782`)
398398
- Bug in :func:`read_csv` where it would error when ``parse_dates`` was set to a list or dictionary with ``engine="pyarrow"`` (:issue:`47961`)
399+
- Bug in :func:`read_csv`, with ``engine="pyarrow"`` erroring when specifying a ``dtype`` with ``index_col`` (:issue:`53229`)
399400
- Bug in :func:`read_hdf` not properly closing store after a ``IndexError`` is raised (:issue:`52781`)
400401
- Bug in :func:`read_html`, style elements were read into DataFrames (:issue:`52197`)
401402
- Bug in :func:`read_html`, tail texts were removed together with elements containing ``display:none`` style (:issue:`51629`)

pandas/io/parsers/arrow_parser_wrapper.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,29 @@ def _finalize_pandas_output(self, frame: DataFrame) -> DataFrame:
142142
elif item not in frame.columns:
143143
raise ValueError(f"Index {item} invalid")
144144

145+
# Process dtype for index_col and drop from dtypes
146+
if self.dtype is not None:
147+
key, new_dtype = (
148+
(item, self.dtype.get(item))
149+
if self.dtype.get(item) is not None
150+
else (frame.columns[item], self.dtype.get(frame.columns[item]))
151+
)
152+
if new_dtype is not None:
153+
frame[key] = frame[key].astype(new_dtype)
154+
del self.dtype[key]
155+
145156
frame.set_index(index_to_set, drop=True, inplace=True)
146157
# Clear names if headerless and no name given
147158
if self.header is None and not multi_index_named:
148159
frame.index.names = [None] * len(frame.index.names)
149160

150-
if self.kwds.get("dtype") is not None:
161+
if self.dtype is not None:
162+
# Ignore non-existent columns from dtype mapping
163+
# like other parsers do
164+
if isinstance(self.dtype, dict):
165+
self.dtype = {k: v for k, v in self.dtype.items() if k in frame.columns}
151166
try:
152-
frame = frame.astype(self.kwds.get("dtype"))
167+
frame = frame.astype(self.dtype)
153168
except TypeError as e:
154169
# GH#44901 reraise to keep api consistent
155170
raise ValueError(e)

pandas/tests/io/parser/test_index_col.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,15 @@ def test_infer_types_boolean_sum(all_parsers):
324324
tm.assert_frame_equal(result, expected, check_index_type=False)
325325

326326

327-
@skip_pyarrow
328327
@pytest.mark.parametrize("dtype, val", [(object, "01"), ("int64", 1)])
329-
def test_specify_dtype_for_index_col(all_parsers, dtype, val):
328+
def test_specify_dtype_for_index_col(all_parsers, dtype, val, request):
330329
# GH#9435
331330
data = "a,b\n01,2"
332331
parser = all_parsers
332+
if dtype == object and parser.engine == "pyarrow":
333+
request.node.add_marker(
334+
pytest.mark.xfail(reason="Cannot disable type-inference for pyarrow engine")
335+
)
333336
result = parser.read_csv(StringIO(data), index_col="a", dtype={"a": dtype})
334337
expected = DataFrame({"b": [2]}, index=Index([val], name="a"))
335338
tm.assert_frame_equal(result, expected)

pandas/tests/io/parser/usecols/test_usecols_basic.py

+41-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pandas import (
1313
DataFrame,
1414
Index,
15+
array,
1516
)
1617
import pandas._testing as tm
1718

@@ -24,8 +25,8 @@
2425
"Usecols do not match columns, columns expected but not found: {0}"
2526
)
2627

27-
# TODO(1.4): Change to xfails at release time
28-
pytestmark = pytest.mark.usefixtures("pyarrow_skip")
28+
# TODO: Switch to xfails
29+
skip_pyarrow = pytest.mark.usefixtures("pyarrow_skip")
2930

3031

3132
def test_raise_on_mixed_dtype_usecols(all_parsers):
@@ -41,6 +42,7 @@ def test_raise_on_mixed_dtype_usecols(all_parsers):
4142
parser.read_csv(StringIO(data), usecols=usecols)
4243

4344

45+
@skip_pyarrow
4446
@pytest.mark.parametrize("usecols", [(1, 2), ("b", "c")])
4547
def test_usecols(all_parsers, usecols):
4648
data = """\
@@ -56,6 +58,7 @@ def test_usecols(all_parsers, usecols):
5658
tm.assert_frame_equal(result, expected)
5759

5860

61+
@skip_pyarrow
5962
def test_usecols_with_names(all_parsers):
6063
data = """\
6164
a,b,c
@@ -71,6 +74,7 @@ def test_usecols_with_names(all_parsers):
7174
tm.assert_frame_equal(result, expected)
7275

7376

77+
@skip_pyarrow
7478
@pytest.mark.parametrize(
7579
"names,usecols", [(["b", "c"], [1, 2]), (["a", "b", "c"], ["b", "c"])]
7680
)
@@ -87,6 +91,7 @@ def test_usecols_relative_to_names(all_parsers, names, usecols):
8791
tm.assert_frame_equal(result, expected)
8892

8993

94+
@skip_pyarrow
9095
def test_usecols_relative_to_names2(all_parsers):
9196
# see gh-5766
9297
data = """\
@@ -103,6 +108,7 @@ def test_usecols_relative_to_names2(all_parsers):
103108
tm.assert_frame_equal(result, expected)
104109

105110

111+
@skip_pyarrow
106112
def test_usecols_name_length_conflict(all_parsers):
107113
data = """\
108114
1,2,3
@@ -127,6 +133,7 @@ def test_usecols_single_string(all_parsers):
127133
parser.read_csv(StringIO(data), usecols="foo")
128134

129135

136+
@skip_pyarrow
130137
@pytest.mark.parametrize(
131138
"data", ["a,b,c,d\n1,2,3,4\n5,6,7,8", "a,b,c,d\n1,2,3,4,\n5,6,7,8,"]
132139
)
@@ -140,6 +147,7 @@ def test_usecols_index_col_false(all_parsers, data):
140147
tm.assert_frame_equal(result, expected)
141148

142149

150+
@skip_pyarrow
143151
@pytest.mark.parametrize("index_col", ["b", 0])
144152
@pytest.mark.parametrize("usecols", [["b", "c"], [1, 2]])
145153
def test_usecols_index_col_conflict(all_parsers, usecols, index_col):
@@ -166,6 +174,7 @@ def test_usecols_index_col_conflict2(all_parsers):
166174
tm.assert_frame_equal(result, expected)
167175

168176

177+
@skip_pyarrow
169178
def test_usecols_implicit_index_col(all_parsers):
170179
# see gh-2654
171180
parser = all_parsers
@@ -198,6 +207,7 @@ def test_usecols_index_col_end(all_parsers):
198207
tm.assert_frame_equal(result, expected)
199208

200209

210+
@skip_pyarrow
201211
def test_usecols_regex_sep(all_parsers):
202212
# see gh-2733
203213
parser = all_parsers
@@ -208,6 +218,7 @@ def test_usecols_regex_sep(all_parsers):
208218
tm.assert_frame_equal(result, expected)
209219

210220

221+
@skip_pyarrow
211222
def test_usecols_with_whitespace(all_parsers):
212223
parser = all_parsers
213224
data = "a b c\n4 apple bat 5.7\n8 orange cow 10"
@@ -217,6 +228,7 @@ def test_usecols_with_whitespace(all_parsers):
217228
tm.assert_frame_equal(result, expected)
218229

219230

231+
@skip_pyarrow
220232
@pytest.mark.parametrize(
221233
"usecols,expected",
222234
[
@@ -239,6 +251,7 @@ def test_usecols_with_integer_like_header(all_parsers, usecols, expected):
239251
tm.assert_frame_equal(result, expected)
240252

241253

254+
@skip_pyarrow
242255
def test_empty_usecols(all_parsers):
243256
data = "a,b,c\n1,2,3\n4,5,6"
244257
expected = DataFrame(columns=Index([]))
@@ -259,6 +272,7 @@ def test_np_array_usecols(all_parsers):
259272
tm.assert_frame_equal(result, expected)
260273

261274

275+
@skip_pyarrow
262276
@pytest.mark.parametrize(
263277
"usecols,expected",
264278
[
@@ -291,6 +305,7 @@ def test_callable_usecols(all_parsers, usecols, expected):
291305
tm.assert_frame_equal(result, expected)
292306

293307

308+
@skip_pyarrow
294309
@pytest.mark.parametrize("usecols", [["a", "c"], lambda x: x in ["a", "c"]])
295310
def test_incomplete_first_row(all_parsers, usecols):
296311
# see gh-6710
@@ -303,6 +318,7 @@ def test_incomplete_first_row(all_parsers, usecols):
303318
tm.assert_frame_equal(result, expected)
304319

305320

321+
@skip_pyarrow
306322
@pytest.mark.parametrize(
307323
"data,usecols,kwargs,expected",
308324
[
@@ -335,6 +351,7 @@ def test_uneven_length_cols(all_parsers, data, usecols, kwargs, expected):
335351
tm.assert_frame_equal(result, expected)
336352

337353

354+
@skip_pyarrow
338355
@pytest.mark.parametrize(
339356
"usecols,kwargs,expected,msg",
340357
[
@@ -391,6 +408,7 @@ def test_raises_on_usecols_names_mismatch(all_parsers, usecols, kwargs, expected
391408
tm.assert_frame_equal(result, expected)
392409

393410

411+
@skip_pyarrow
394412
@pytest.mark.parametrize("usecols", [["A", "C"], [0, 2]])
395413
def test_usecols_subset_names_mismatch_orig_columns(all_parsers, usecols):
396414
data = "a,b,c,d\n1,2,3,4\n5,6,7,8"
@@ -402,6 +420,7 @@ def test_usecols_subset_names_mismatch_orig_columns(all_parsers, usecols):
402420
tm.assert_frame_equal(result, expected)
403421

404422

423+
@skip_pyarrow
405424
@pytest.mark.parametrize("names", [None, ["a", "b"]])
406425
def test_usecols_indices_out_of_bounds(all_parsers, names):
407426
# GH#25623 & GH 41130; enforced in 2.0
@@ -414,6 +433,7 @@ def test_usecols_indices_out_of_bounds(all_parsers, names):
414433
parser.read_csv(StringIO(data), usecols=[0, 2], names=names, header=0)
415434

416435

436+
@skip_pyarrow
417437
def test_usecols_additional_columns(all_parsers):
418438
# GH#46997
419439
parser = all_parsers
@@ -423,10 +443,29 @@ def test_usecols_additional_columns(all_parsers):
423443
tm.assert_frame_equal(result, expected)
424444

425445

446+
@skip_pyarrow
426447
def test_usecols_additional_columns_integer_columns(all_parsers):
427448
# GH#46997
428449
parser = all_parsers
429450
usecols = lambda header: header.strip() in ["0", "1"]
430451
result = parser.read_csv(StringIO("0,1\nx,y,z"), index_col=False, usecols=usecols)
431452
expected = DataFrame({"0": ["x"], "1": "y"})
432453
tm.assert_frame_equal(result, expected)
454+
455+
456+
def test_usecols_dtype(all_parsers):
457+
parser = all_parsers
458+
data = """
459+
col1,col2,col3
460+
a,1,x
461+
b,2,y
462+
"""
463+
result = parser.read_csv(
464+
StringIO(data),
465+
usecols=["col1", "col2"],
466+
dtype={"col1": "string", "col2": "uint8", "col3": "string"},
467+
)
468+
expected = DataFrame(
469+
{"col1": array(["a", "b"]), "col2": np.array([1, 2], dtype="uint8")}
470+
)
471+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)