Skip to content

Commit efbb63b

Browse files
committed
Backport PR pandas-dev#52087: BUG: Fix some more arrow CSV tests
1 parent 2b012f0 commit efbb63b

File tree

6 files changed

+45
-30
lines changed

6 files changed

+45
-30
lines changed

doc/source/whatsnew/v2.0.2.rst

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ Bug fixes
3030
- Bug in :func:`api.interchange.from_dataframe` was unnecessarily raising on bitmasks (:issue:`49888`)
3131
- Bug in :func:`merge` when merging on datetime columns on different resolutions (:issue:`53200`)
3232
- Bug in :func:`read_csv` raising ``OverflowError`` for ``engine="pyarrow"`` and ``parse_dates`` set (:issue:`53295`)
33+
- Bug in :func:`read_csv` not processing empty strings as a null value, with ``engine="pyarrow"`` (:issue:`52087`)
34+
- Bug in :func:`read_csv` returning ``object`` dtype columns instead of ``float64`` dtype columns with ``engine="pyarrow"`` for columns that are all null (:issue:`52087`)
3335
- Bug in :func:`to_datetime` was inferring format to contain ``"%H"`` instead of ``"%I"`` if date contained "AM" / "PM" tokens (:issue:`53147`)
3436
- Bug in :meth:`DataFrame.convert_dtypes` ignores ``convert_*`` keywords when set to False and ``dtype_backend="pyarrow"`` (:issue:`52872`)
3537
- Bug in :meth:`DataFrame.sort_values` raising for PyArrow ``dictionary`` dtype (:issue:`53232`)

pandas/io/parsers/arrow_parser_wrapper.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from pandas._libs import lib
34
from pandas._typing import ReadBuffer
45
from pandas.compat._optional import import_optional_dependency
56

@@ -76,6 +77,7 @@ def _get_pyarrow_options(self) -> None:
7677
"decimal_point",
7778
)
7879
}
80+
self.convert_options["strings_can_be_null"] = "" in self.kwds["null_values"]
7981
self.read_options = {
8082
"autogenerate_column_names": self.header is None,
8183
"skip_rows": self.header
@@ -146,6 +148,7 @@ def read(self) -> DataFrame:
146148
DataFrame
147149
The DataFrame created from the CSV file.
148150
"""
151+
pa = import_optional_dependency("pyarrow")
149152
pyarrow_csv = import_optional_dependency("pyarrow.csv")
150153
self._get_pyarrow_options()
151154

@@ -155,10 +158,30 @@ def read(self) -> DataFrame:
155158
parse_options=pyarrow_csv.ParseOptions(**self.parse_options),
156159
convert_options=pyarrow_csv.ConvertOptions(**self.convert_options),
157160
)
158-
if self.kwds["dtype_backend"] == "pyarrow":
161+
162+
dtype_backend = self.kwds["dtype_backend"]
163+
164+
# Convert all pa.null() cols -> float64 (non nullable)
165+
# else Int64 (nullable case, see below)
166+
if dtype_backend is lib.no_default:
167+
new_schema = table.schema
168+
new_type = pa.float64()
169+
for i, arrow_type in enumerate(table.schema.types):
170+
if pa.types.is_null(arrow_type):
171+
new_schema = new_schema.set(
172+
i, new_schema.field(i).with_type(new_type)
173+
)
174+
175+
table = table.cast(new_schema)
176+
177+
if dtype_backend == "pyarrow":
159178
frame = table.to_pandas(types_mapper=pd.ArrowDtype)
160-
elif self.kwds["dtype_backend"] == "numpy_nullable":
161-
frame = table.to_pandas(types_mapper=_arrow_dtype_mapping().get)
179+
elif dtype_backend == "numpy_nullable":
180+
# Modify the default mapping to also
181+
# map null to Int64 (to match other engines)
182+
dtype_mapping = _arrow_dtype_mapping()
183+
dtype_mapping[pa.null()] = pd.Int64Dtype()
184+
frame = table.to_pandas(types_mapper=dtype_mapping.get)
162185
else:
163186
frame = table.to_pandas()
164187
return self._finalize_pandas_output(frame)

pandas/io/parsers/readers.py

+3
Original file line numberDiff line numberDiff line change
@@ -1438,8 +1438,11 @@ def _get_options_with_defaults(self, engine: CSVEngine) -> dict[str, Any]:
14381438
value = kwds[argname]
14391439

14401440
if engine != "c" and value != default:
1441+
# TODO: Refactor this logic, it's pretty convoluted
14411442
if "python" in engine and argname not in _python_unsupported:
14421443
pass
1444+
elif "pyarrow" in engine and argname not in _pyarrow_unsupported:
1445+
pass
14431446
else:
14441447
raise ValueError(
14451448
f"The {repr(argname)} option is not supported with the "

pandas/tests/io/parser/dtypes/test_dtypes_basic.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -423,13 +423,9 @@ def test_dtype_backend(all_parsers):
423423
"e": pd.Series([pd.NA, 6], dtype="Int64"),
424424
"f": pd.Series([pd.NA, 7.5], dtype="Float64"),
425425
"g": pd.Series([pd.NA, True], dtype="boolean"),
426-
"h": pd.Series(
427-
[pd.NA if parser.engine != "pyarrow" else "", "a"], dtype="string"
428-
),
426+
"h": pd.Series([pd.NA, "a"], dtype="string"),
429427
"i": pd.Series([Timestamp("2019-12-31")] * 2),
430-
"j": pd.Series(
431-
[pd.NA, pd.NA], dtype="Int64" if parser.engine != "pyarrow" else object
432-
),
428+
"j": pd.Series([pd.NA, pd.NA], dtype="Int64"),
433429
}
434430
)
435431
tm.assert_frame_equal(result, expected)
@@ -451,7 +447,6 @@ def test_dtype_backend_and_dtype(all_parsers):
451447
tm.assert_frame_equal(result, expected)
452448

453449

454-
@pytest.mark.usefixtures("pyarrow_xfail")
455450
def test_dtype_backend_string(all_parsers, string_storage):
456451
# GH#36712
457452
pa = pytest.importorskip("pyarrow")
@@ -499,7 +494,6 @@ def test_dtype_backend_pyarrow(all_parsers, request):
499494
# GH#36712
500495
pa = pytest.importorskip("pyarrow")
501496
parser = all_parsers
502-
engine = parser.engine
503497

504498
data = """a,b,c,d,e,f,g,h,i,j
505499
1,2.5,True,a,,,,,12-31-2019,
@@ -516,7 +510,7 @@ def test_dtype_backend_pyarrow(all_parsers, request):
516510
"f": pd.Series([pd.NA, 7.5], dtype="float64[pyarrow]"),
517511
"g": pd.Series([pd.NA, True], dtype="bool[pyarrow]"),
518512
"h": pd.Series(
519-
[pd.NA if engine != "pyarrow" else "", "a"],
513+
[pd.NA, "a"],
520514
dtype=pd.ArrowDtype(pa.string()),
521515
),
522516
"i": pd.Series([Timestamp("2019-12-31")] * 2),

pandas/tests/io/parser/test_na_values.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
xfail_pyarrow = pytest.mark.usefixtures("pyarrow_xfail")
2121

2222

23-
@skip_pyarrow
2423
def test_string_nas(all_parsers):
2524
parser = all_parsers
2625
data = """A,B,C
@@ -36,7 +35,6 @@ def test_string_nas(all_parsers):
3635
tm.assert_frame_equal(result, expected)
3736

3837

39-
@skip_pyarrow
4038
def test_detect_string_na(all_parsers):
4139
parser = all_parsers
4240
data = """A,B
@@ -89,7 +87,6 @@ def test_non_string_na_values(all_parsers, data, na_values):
8987
tm.assert_frame_equal(result, expected)
9088

9189

92-
@skip_pyarrow
9390
def test_default_na_values(all_parsers):
9491
_NA_VALUES = {
9592
"-1.#IND",
@@ -138,6 +135,7 @@ def f(i, v):
138135
tm.assert_frame_equal(result, expected)
139136

140137

138+
# TODO: needs skiprows list support in pyarrow
141139
@skip_pyarrow
142140
@pytest.mark.parametrize("na_values", ["baz", ["baz"]])
143141
def test_custom_na_values(all_parsers, na_values):
@@ -172,6 +170,7 @@ def test_bool_na_values(all_parsers):
172170
tm.assert_frame_equal(result, expected)
173171

174172

173+
# TODO: Needs pyarrow support for dictionary in na_values
175174
@skip_pyarrow
176175
def test_na_value_dict(all_parsers):
177176
data = """A,B,C
@@ -191,7 +190,6 @@ def test_na_value_dict(all_parsers):
191190
tm.assert_frame_equal(df, expected)
192191

193192

194-
@skip_pyarrow
195193
@pytest.mark.parametrize(
196194
"index_col,expected",
197195
[
@@ -225,6 +223,7 @@ def test_na_value_dict_multi_index(all_parsers, index_col, expected):
225223
tm.assert_frame_equal(result, expected)
226224

227225

226+
# TODO: xfail components of this test, the first one passes
228227
@skip_pyarrow
229228
@pytest.mark.parametrize(
230229
"kwargs,expected",
@@ -287,7 +286,6 @@ def test_na_values_keep_default(all_parsers, kwargs, expected):
287286
tm.assert_frame_equal(result, expected)
288287

289288

290-
@skip_pyarrow
291289
def test_no_na_values_no_keep_default(all_parsers):
292290
# see gh-4318: passing na_values=None and
293291
# keep_default_na=False yields 'None" as a na_value
@@ -314,6 +312,7 @@ def test_no_na_values_no_keep_default(all_parsers):
314312
tm.assert_frame_equal(result, expected)
315313

316314

315+
# TODO: Blocked on na_values dict support in pyarrow
317316
@skip_pyarrow
318317
def test_no_keep_default_na_dict_na_values(all_parsers):
319318
# see gh-19227
@@ -326,6 +325,7 @@ def test_no_keep_default_na_dict_na_values(all_parsers):
326325
tm.assert_frame_equal(result, expected)
327326

328327

328+
# TODO: Blocked on na_values dict support in pyarrow
329329
@skip_pyarrow
330330
def test_no_keep_default_na_dict_na_scalar_values(all_parsers):
331331
# see gh-19227
@@ -338,6 +338,7 @@ def test_no_keep_default_na_dict_na_scalar_values(all_parsers):
338338
tm.assert_frame_equal(df, expected)
339339

340340

341+
# TODO: Blocked on na_values dict support in pyarrow
341342
@skip_pyarrow
342343
@pytest.mark.parametrize("col_zero_na_values", [113125, "113125"])
343344
def test_no_keep_default_na_dict_na_values_diff_reprs(all_parsers, col_zero_na_values):
@@ -368,6 +369,7 @@ def test_no_keep_default_na_dict_na_values_diff_reprs(all_parsers, col_zero_na_v
368369
tm.assert_frame_equal(result, expected)
369370

370371

372+
# TODO: Empty null_values doesn't work properly on pyarrow
371373
@skip_pyarrow
372374
@pytest.mark.parametrize(
373375
"na_filter,row_data",
@@ -390,6 +392,7 @@ def test_na_values_na_filter_override(all_parsers, na_filter, row_data):
390392
tm.assert_frame_equal(result, expected)
391393

392394

395+
# TODO: Arrow parse error
393396
@skip_pyarrow
394397
def test_na_trailing_columns(all_parsers):
395398
parser = all_parsers
@@ -418,6 +421,7 @@ def test_na_trailing_columns(all_parsers):
418421
tm.assert_frame_equal(result, expected)
419422

420423

424+
# TODO: xfail the na_values dict case
421425
@skip_pyarrow
422426
@pytest.mark.parametrize(
423427
"na_values,row_data",
@@ -495,6 +499,7 @@ def test_empty_na_values_no_default_with_index(all_parsers):
495499
tm.assert_frame_equal(result, expected)
496500

497501

502+
# TODO: Missing support for na_filter keyword
498503
@skip_pyarrow
499504
@pytest.mark.parametrize(
500505
"na_filter,index_data", [(False, ["", "5"]), (True, [np.nan, 5.0])]

pandas/tests/io/parser/test_parse_dates.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -1252,19 +1252,7 @@ def test_bad_date_parse(all_parsers, cache_dates, value):
12521252
parser = all_parsers
12531253
s = StringIO((f"{value},\n") * 50000)
12541254

1255-
if parser.engine == "pyarrow" and not cache_dates:
1256-
# None in input gets converted to 'None', for which
1257-
# pandas tries to guess the datetime format, triggering
1258-
# the warning. TODO: parse dates directly in pyarrow, see
1259-
# https://github.com/pandas-dev/pandas/issues/48017
1260-
warn = UserWarning
1261-
else:
1262-
# Note: warning is not raised if 'cache_dates', because here there is only a
1263-
# single unique date and hence no risk of inconsistent parsing.
1264-
warn = None
1265-
parser.read_csv_check_warnings(
1266-
warn,
1267-
"Could not infer format",
1255+
parser.read_csv(
12681256
s,
12691257
header=None,
12701258
names=["foo", "bar"],

0 commit comments

Comments
 (0)