
Commit d9a2ad1

ENH: Add on_bad_lines for pyarrow (#54643)
* ENH: Add on_bad_lines for pyarrow (SQUASHED)
* Update to appropriate version in docstring
* Address review comments
* Refine whatsnew
* Add "error" value
* Condense What's New
* Move to "Other Enhancements"
* Refactor tests in "test_read_errors" to work with added capabilities
* Conditionally import pyarrow error types
* Revert changes in v2.2.0.rst > enhancements
* Address review comments
* Address review comments
* Wrap ArrowInvalid with ParserError
* Change ArrowInvalid to optional import
1 parent f19f81f commit d9a2ad1

File tree: 5 files changed (+93, -16)


doc/source/whatsnew/v2.2.0.rst (+2)
@@ -73,6 +73,8 @@ enhancement2
 
 Other enhancements
 ^^^^^^^^^^^^^^^^^^
+
+- :func:`read_csv` now supports ``on_bad_lines`` parameter with ``engine="pyarrow"``. (:issue:`54480`)
 - :meth:`ExtensionArray._explode` interface method added to allow extension type implementations of the ``explode`` method (:issue:`54833`)
 - DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`)
 -
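
For orientation, a minimal usage sketch of the enhancement described in the whatsnew entry above; the CSV content is invented for illustration, and ``engine="pyarrow"`` requires the optional pyarrow dependency:

    from io import StringIO

    import pandas as pd

    # A malformed CSV: the third line has three fields instead of two.
    data = "a,b\n1,2\n3,4,5\n6,7"

    # "skip" drops the bad line, "warn" emits a ParserWarning and skips it,
    # and "error" (the default) raises pandas.errors.ParserError.
    df = pd.read_csv(StringIO(data), engine="pyarrow", on_bad_lines="skip")
    print(df)  # only the well-formed rows [1, 2] and [6, 7] remain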

pandas/io/parsers/arrow_parser_wrapper.py (+39, -6)
@@ -1,11 +1,17 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+import warnings
 
 from pandas._config import using_pyarrow_string_dtype
 
 from pandas._libs import lib
 from pandas.compat._optional import import_optional_dependency
+from pandas.errors import (
+    ParserError,
+    ParserWarning,
+)
+from pandas.util._exceptions import find_stack_level
 
 from pandas.core.dtypes.inference import is_integer
 
@@ -85,6 +91,30 @@ def _get_pyarrow_options(self) -> None:
             and option_name
             in ("delimiter", "quote_char", "escape_char", "ignore_empty_lines")
         }
+
+        on_bad_lines = self.kwds.get("on_bad_lines")
+        if on_bad_lines is not None:
+            if callable(on_bad_lines):
+                self.parse_options["invalid_row_handler"] = on_bad_lines
+            elif on_bad_lines == ParserBase.BadLineHandleMethod.ERROR:
+                self.parse_options[
+                    "invalid_row_handler"
+                ] = None  # PyArrow raises an exception by default
+            elif on_bad_lines == ParserBase.BadLineHandleMethod.WARN:
+
+                def handle_warning(invalid_row):
+                    warnings.warn(
+                        f"Expected {invalid_row.expected_columns} columns, but found "
+                        f"{invalid_row.actual_columns}: {invalid_row.text}",
+                        ParserWarning,
+                        stacklevel=find_stack_level(),
+                    )
+                    return "skip"
+
+                self.parse_options["invalid_row_handler"] = handle_warning
+            elif on_bad_lines == ParserBase.BadLineHandleMethod.SKIP:
+                self.parse_options["invalid_row_handler"] = lambda _: "skip"
+
         self.convert_options = {
             option_name: option_value
             for option_name, option_value in self.kwds.items()
@@ -190,12 +220,15 @@ def read(self) -> DataFrame:
         pyarrow_csv = import_optional_dependency("pyarrow.csv")
         self._get_pyarrow_options()
 
-        table = pyarrow_csv.read_csv(
-            self.src,
-            read_options=pyarrow_csv.ReadOptions(**self.read_options),
-            parse_options=pyarrow_csv.ParseOptions(**self.parse_options),
-            convert_options=pyarrow_csv.ConvertOptions(**self.convert_options),
-        )
+        try:
+            table = pyarrow_csv.read_csv(
+                self.src,
+                read_options=pyarrow_csv.ReadOptions(**self.read_options),
+                parse_options=pyarrow_csv.ParseOptions(**self.parse_options),
+                convert_options=pyarrow_csv.ConvertOptions(**self.convert_options),
+            )
+        except pa.ArrowInvalid as e:
+            raise ParserError(e) from e
 
         dtype_backend = self.kwds["dtype_backend"]
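
The wiring above maps pandas' ``on_bad_lines`` values onto ``pyarrow.csv.ParseOptions(invalid_row_handler=...)``, and the new try/except re-raises pyarrow's ``ArrowInvalid`` as ``pandas.errors.ParserError``. Below is a minimal sketch of the underlying pyarrow contract that ``handle_warning`` and the "skip" lambda rely on; the data and the handler name ``log_and_skip`` are invented for illustration:

    import io

    from pyarrow import csv

    # The handler receives an InvalidRow (expected_columns, actual_columns,
    # number, text) and must return "skip" or "error".
    def log_and_skip(invalid_row):
        print(
            f"Expected {invalid_row.expected_columns} columns, but found "
            f"{invalid_row.actual_columns}: {invalid_row.text}"
        )
        return "skip"

    buf = io.BytesIO(b"a,b\n1,2\n3,4,5\n6,7\n")  # third line is malformed
    table = csv.read_csv(
        buf,
        parse_options=csv.ParseOptions(invalid_row_handler=log_and_skip),
    )
    print(table.to_pandas())  # the bad row is dropped

Returning "error" instead makes pyarrow raise ``ArrowInvalid``, which the wrapper above surfaces as ``ParserError``.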

pandas/io/parsers/readers.py (+10, -3)
@@ -401,6 +401,13 @@
       expected, a ``ParserWarning`` will be emitted while dropping extra elements.
       Only supported when ``engine='python'``
 
+    .. versionchanged:: 2.2.0
+
+        - Callable, function with signature
+          as described in `pyarrow documentation
+          <https://arrow.apache.org/docs/python/generated/pyarrow.csv.ParseOptions.html
+          #pyarrow.csv.ParseOptions.invalid_row_handler>`_ when ``engine='pyarrow'``
+
 delim_whitespace : bool, default False
     Specifies whether or not whitespace (e.g. ``' '`` or ``'\\t'``) will be
     used as the ``sep`` delimiter. Equivalent to setting ``sep='\\s+'``. If this option
@@ -494,7 +501,6 @@ class _Fwf_Defaults(TypedDict):
     "thousands",
     "memory_map",
     "dialect",
-    "on_bad_lines",
     "delim_whitespace",
     "quoting",
     "lineterminator",
@@ -2142,9 +2148,10 @@ def _refine_defaults_read(
     elif on_bad_lines == "skip":
         kwds["on_bad_lines"] = ParserBase.BadLineHandleMethod.SKIP
     elif callable(on_bad_lines):
-        if engine != "python":
+        if engine not in ["python", "pyarrow"]:
             raise ValueError(
-                "on_bad_line can only be a callable function if engine='python'"
+                "on_bad_line can only be a callable function "
+                "if engine='python' or 'pyarrow'"
             )
         kwds["on_bad_lines"] = on_bad_lines
     else:
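
Building on the docstring change above, here is a sketch of passing a callable with the pyarrow signature (note that the python engine expects a different callable, one that receives the list of fields from the bad line); ``skip_extra_columns`` is a hypothetical name used for illustration:

    from io import StringIO

    import pandas as pd

    # Hypothetical handler: skip rows with too many fields, error otherwise.
    def skip_extra_columns(invalid_row):
        if invalid_row.actual_columns > invalid_row.expected_columns:
            return "skip"
        return "error"

    data = "a,b\n1,2\n3,4,5\n6,7"
    df = pd.read_csv(
        StringIO(data), engine="pyarrow", on_bad_lines=skip_extra_columns
    )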

pandas/tests/io/parser/common/test_read_errors.py (+35, -4)
@@ -1,5 +1,5 @@
 """
-Tests that work on both the Python and C engines but do not have a
+Tests that work on the Python, C and PyArrow engines but do not have a
 specific classification into the other test modules.
 """
 import codecs
@@ -21,7 +21,8 @@
 from pandas import DataFrame
 import pandas._testing as tm
 
-pytestmark = pytest.mark.usefixtures("pyarrow_skip")
+xfail_pyarrow = pytest.mark.usefixtures("pyarrow_xfail")
+skip_pyarrow = pytest.mark.usefixtures("pyarrow_skip")
 
 
 def test_empty_decimal_marker(all_parsers):
@@ -33,10 +34,17 @@ def test_empty_decimal_marker(all_parsers):
     msg = "Only length-1 decimal markers supported"
     parser = all_parsers
 
+    if parser.engine == "pyarrow":
+        msg = (
+            "only single character unicode strings can be "
+            "converted to Py_UCS4, got length 0"
+        )
+
     with pytest.raises(ValueError, match=msg):
         parser.read_csv(StringIO(data), decimal="")
 
 
+@skip_pyarrow
 def test_bad_stream_exception(all_parsers, csv_dir_path):
     # see gh-13652
     #
@@ -57,6 +65,7 @@ def test_bad_stream_exception(all_parsers, csv_dir_path):
             parser.read_csv(stream)
 
 
+@skip_pyarrow
 def test_malformed(all_parsers):
     # see gh-6607
     parser = all_parsers
@@ -71,6 +80,7 @@ def test_malformed(all_parsers):
         parser.read_csv(StringIO(data), header=1, comment="#")
 
 
+@skip_pyarrow
 @pytest.mark.parametrize("nrows", [5, 3, None])
 def test_malformed_chunks(all_parsers, nrows):
     data = """ignore
@@ -90,6 +100,7 @@ def test_malformed_chunks(all_parsers, nrows):
             reader.read(nrows)
 
 
+@skip_pyarrow
 def test_catch_too_many_names(all_parsers):
     # see gh-5156
     data = """\
@@ -109,6 +120,7 @@ def test_catch_too_many_names(all_parsers):
         parser.read_csv(StringIO(data), header=0, names=["a", "b", "c", "d"])
 
 
+@skip_pyarrow
 @pytest.mark.parametrize("nrows", [0, 1, 2, 3, 4, 5])
 def test_raise_on_no_columns(all_parsers, nrows):
     parser = all_parsers
@@ -147,6 +159,10 @@ def test_error_bad_lines(all_parsers):
     data = "a\n1\n1,2,3\n4\n5,6,7"
 
     msg = "Expected 1 fields in line 3, saw 3"
+
+    if parser.engine == "pyarrow":
+        msg = "CSV parse error: Expected 1 columns, got 3: 1,2,3"
+
     with pytest.raises(ParserError, match=msg):
         parser.read_csv(StringIO(data), on_bad_lines="error")
 
@@ -156,9 +172,13 @@ def test_warn_bad_lines(all_parsers):
     parser = all_parsers
     data = "a\n1\n1,2,3\n4\n5,6,7"
     expected = DataFrame({"a": [1, 4]})
+    match_msg = "Skipping line"
+
+    if parser.engine == "pyarrow":
+        match_msg = "Expected 1 columns, but found 3: 1,2,3"
 
     with tm.assert_produces_warning(
-        ParserWarning, match="Skipping line", check_stacklevel=False
+        ParserWarning, match=match_msg, check_stacklevel=False
     ):
         result = parser.read_csv(StringIO(data), on_bad_lines="warn")
     tm.assert_frame_equal(result, expected)
@@ -174,10 +194,14 @@ def test_read_csv_wrong_num_columns(all_parsers):
     parser = all_parsers
     msg = "Expected 6 fields in line 3, saw 7"
 
+    if parser.engine == "pyarrow":
+        msg = "Expected 6 columns, got 7: 6,7,8,9,10,11,12"
+
     with pytest.raises(ParserError, match=msg):
         parser.read_csv(StringIO(data))
 
 
+@skip_pyarrow
 def test_null_byte_char(request, all_parsers):
     # see gh-2741
     data = "\x00,foo"
@@ -200,6 +224,7 @@ def test_null_byte_char(request, all_parsers):
             parser.read_csv(StringIO(data), names=names)
 
 
+@skip_pyarrow
 @pytest.mark.filterwarnings("always::ResourceWarning")
 def test_open_file(request, all_parsers):
     # GH 39024
@@ -238,6 +263,8 @@ def test_bad_header_uniform_error(all_parsers):
             "Could not construct index. Requested to use 1 "
             "number of columns, but 3 left to parse."
         )
+    elif parser.engine == "pyarrow":
+        msg = "CSV parse error: Expected 1 columns, got 4: col1,col2,col3,col4"
 
     with pytest.raises(ParserError, match=msg):
         parser.read_csv(StringIO(data), index_col=0, on_bad_lines="error")
@@ -253,9 +280,13 @@ def test_on_bad_lines_warn_correct_formatting(all_parsers):
 a,b
 """
     expected = DataFrame({"1": "a", "2": ["b"] * 2})
+    match_msg = "Skipping line"
+
+    if parser.engine == "pyarrow":
+        match_msg = "Expected 2 columns, but found 3: a,b,c"
 
     with tm.assert_produces_warning(
-        ParserWarning, match="Skipping line", check_stacklevel=False
+        ParserWarning, match=match_msg, check_stacklevel=False
     ):
         result = parser.read_csv(StringIO(data), on_bad_lines="warn")
     tm.assert_frame_equal(result, expected)

pandas/tests/io/parser/test_unsupported.py (+7, -3)
@@ -151,13 +151,17 @@ def test_pyarrow_engine(self):
         with pytest.raises(ValueError, match=msg):
             read_csv(StringIO(data), engine="pyarrow", **kwargs)
 
-    def test_on_bad_lines_callable_python_only(self, all_parsers):
+    def test_on_bad_lines_callable_python_or_pyarrow(self, all_parsers):
         # GH 5686
+        # GH 54643
         sio = StringIO("a,b\n1,2")
         bad_lines_func = lambda x: x
         parser = all_parsers
-        if all_parsers.engine != "python":
-            msg = "on_bad_line can only be a callable function if engine='python'"
+        if all_parsers.engine not in ["python", "pyarrow"]:
+            msg = (
+                "on_bad_line can only be a callable "
+                "function if engine='python' or 'pyarrow'"
+            )
             with pytest.raises(ValueError, match=msg):
                 parser.read_csv(sio, on_bad_lines=bad_lines_func)
         else:
