Skip to content

Commit f29ef30

Browse files
authored
BUG: Fix some more arrow CSV tests (#52087)
1 parent 8534e13 commit f29ef30

File tree

6 files changed

+45
-31
lines changed

6 files changed

+45
-31
lines changed

doc/source/whatsnew/v2.1.0.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,8 @@ Period
334334
- :meth:`PeriodIndex.map` with ``na_action="ignore"`` now works as expected (:issue:`51644`)
335335
- Bug in :class:`PeriodDtype` constructor raising ``ValueError`` instead of ``TypeError`` when an invalid type is passed (:issue:`51790`)
336336
- Bug in :meth:`arrays.PeriodArray.map` and :meth:`PeriodIndex.map`, where the supplied callable operated array-wise instead of element-wise (:issue:`51977`)
337-
-
337+
- Bug in :func:`read_csv` not processing empty strings as a null value, with ``engine="pyarrow"`` (:issue:`52087`)
338+
- Bug in :func:`read_csv` returning ``object`` dtype columns instead of ``float64`` dtype columns with ``engine="pyarrow"`` for columns that are all null (:issue:`52087`)
338339

339340
Plotting
340341
^^^^^^^^

pandas/io/parsers/arrow_parser_wrapper.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import TYPE_CHECKING
44

5+
from pandas._libs import lib
56
from pandas.compat._optional import import_optional_dependency
67

78
from pandas.core.dtypes.inference import is_integer
@@ -80,6 +81,7 @@ def _get_pyarrow_options(self) -> None:
8081
"decimal_point",
8182
)
8283
}
84+
self.convert_options["strings_can_be_null"] = "" in self.kwds["null_values"]
8385
self.read_options = {
8486
"autogenerate_column_names": self.header is None,
8587
"skip_rows": self.header
@@ -149,6 +151,7 @@ def read(self) -> DataFrame:
149151
DataFrame
150152
The DataFrame created from the CSV file.
151153
"""
154+
pa = import_optional_dependency("pyarrow")
152155
pyarrow_csv = import_optional_dependency("pyarrow.csv")
153156
self._get_pyarrow_options()
154157

@@ -158,10 +161,30 @@ def read(self) -> DataFrame:
158161
parse_options=pyarrow_csv.ParseOptions(**self.parse_options),
159162
convert_options=pyarrow_csv.ConvertOptions(**self.convert_options),
160163
)
161-
if self.kwds["dtype_backend"] == "pyarrow":
164+
165+
dtype_backend = self.kwds["dtype_backend"]
166+
167+
# Convert all pa.null() cols -> float64 (non nullable)
168+
# else Int64 (nullable case, see below)
169+
if dtype_backend is lib.no_default:
170+
new_schema = table.schema
171+
new_type = pa.float64()
172+
for i, arrow_type in enumerate(table.schema.types):
173+
if pa.types.is_null(arrow_type):
174+
new_schema = new_schema.set(
175+
i, new_schema.field(i).with_type(new_type)
176+
)
177+
178+
table = table.cast(new_schema)
179+
180+
if dtype_backend == "pyarrow":
162181
frame = table.to_pandas(types_mapper=pd.ArrowDtype)
163-
elif self.kwds["dtype_backend"] == "numpy_nullable":
164-
frame = table.to_pandas(types_mapper=_arrow_dtype_mapping().get)
182+
elif dtype_backend == "numpy_nullable":
183+
# Modify the default mapping to also
184+
# map null to Int64 (to match other engines)
185+
dtype_mapping = _arrow_dtype_mapping()
186+
dtype_mapping[pa.null()] = pd.Int64Dtype()
187+
frame = table.to_pandas(types_mapper=dtype_mapping.get)
165188
else:
166189
frame = table.to_pandas()
167190
return self._finalize_pandas_output(frame)

pandas/io/parsers/readers.py

+3
Original file line numberDiff line numberDiff line change
@@ -1460,8 +1460,11 @@ def _get_options_with_defaults(self, engine: CSVEngine) -> dict[str, Any]:
14601460
value = kwds[argname]
14611461

14621462
if engine != "c" and value != default:
1463+
# TODO: Refactor this logic, it's pretty convoluted
14631464
if "python" in engine and argname not in _python_unsupported:
14641465
pass
1466+
elif "pyarrow" in engine and argname not in _pyarrow_unsupported:
1467+
pass
14651468
else:
14661469
raise ValueError(
14671470
f"The {repr(argname)} option is not supported with the "

pandas/tests/io/parser/dtypes/test_dtypes_basic.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -423,13 +423,9 @@ def test_dtype_backend(all_parsers):
423423
"e": pd.Series([pd.NA, 6], dtype="Int64"),
424424
"f": pd.Series([pd.NA, 7.5], dtype="Float64"),
425425
"g": pd.Series([pd.NA, True], dtype="boolean"),
426-
"h": pd.Series(
427-
[pd.NA if parser.engine != "pyarrow" else "", "a"], dtype="string"
428-
),
426+
"h": pd.Series([pd.NA, "a"], dtype="string"),
429427
"i": pd.Series([Timestamp("2019-12-31")] * 2),
430-
"j": pd.Series(
431-
[pd.NA, pd.NA], dtype="Int64" if parser.engine != "pyarrow" else object
432-
),
428+
"j": pd.Series([pd.NA, pd.NA], dtype="Int64"),
433429
}
434430
)
435431
tm.assert_frame_equal(result, expected)
@@ -451,7 +447,6 @@ def test_dtype_backend_and_dtype(all_parsers):
451447
tm.assert_frame_equal(result, expected)
452448

453449

454-
@pytest.mark.usefixtures("pyarrow_xfail")
455450
def test_dtype_backend_string(all_parsers, string_storage):
456451
# GH#36712
457452
pa = pytest.importorskip("pyarrow")
@@ -499,7 +494,6 @@ def test_dtype_backend_pyarrow(all_parsers, request):
499494
# GH#36712
500495
pa = pytest.importorskip("pyarrow")
501496
parser = all_parsers
502-
engine = parser.engine
503497

504498
data = """a,b,c,d,e,f,g,h,i,j
505499
1,2.5,True,a,,,,,12-31-2019,
@@ -516,7 +510,7 @@ def test_dtype_backend_pyarrow(all_parsers, request):
516510
"f": pd.Series([pd.NA, 7.5], dtype="float64[pyarrow]"),
517511
"g": pd.Series([pd.NA, True], dtype="bool[pyarrow]"),
518512
"h": pd.Series(
519-
[pd.NA if engine != "pyarrow" else "", "a"],
513+
[pd.NA, "a"],
520514
dtype=pd.ArrowDtype(pa.string()),
521515
),
522516
"i": pd.Series([Timestamp("2019-12-31")] * 2),

pandas/tests/io/parser/test_na_values.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
xfail_pyarrow = pytest.mark.usefixtures("pyarrow_xfail")
2121

2222

23-
@skip_pyarrow
2423
def test_string_nas(all_parsers):
2524
parser = all_parsers
2625
data = """A,B,C
@@ -36,7 +35,6 @@ def test_string_nas(all_parsers):
3635
tm.assert_frame_equal(result, expected)
3736

3837

39-
@skip_pyarrow
4038
def test_detect_string_na(all_parsers):
4139
parser = all_parsers
4240
data = """A,B
@@ -89,7 +87,6 @@ def test_non_string_na_values(all_parsers, data, na_values):
8987
tm.assert_frame_equal(result, expected)
9088

9189

92-
@skip_pyarrow
9390
def test_default_na_values(all_parsers):
9491
_NA_VALUES = {
9592
"-1.#IND",
@@ -138,6 +135,7 @@ def f(i, v):
138135
tm.assert_frame_equal(result, expected)
139136

140137

138+
# TODO: needs skiprows list support in pyarrow
141139
@skip_pyarrow
142140
@pytest.mark.parametrize("na_values", ["baz", ["baz"]])
143141
def test_custom_na_values(all_parsers, na_values):
@@ -172,6 +170,7 @@ def test_bool_na_values(all_parsers):
172170
tm.assert_frame_equal(result, expected)
173171

174172

173+
# TODO: Needs pyarrow support for dictionary in na_values
175174
@skip_pyarrow
176175
def test_na_value_dict(all_parsers):
177176
data = """A,B,C
@@ -191,7 +190,6 @@ def test_na_value_dict(all_parsers):
191190
tm.assert_frame_equal(df, expected)
192191

193192

194-
@skip_pyarrow
195193
@pytest.mark.parametrize(
196194
"index_col,expected",
197195
[
@@ -225,6 +223,7 @@ def test_na_value_dict_multi_index(all_parsers, index_col, expected):
225223
tm.assert_frame_equal(result, expected)
226224

227225

226+
# TODO: xfail components of this test, the first one passes
228227
@skip_pyarrow
229228
@pytest.mark.parametrize(
230229
"kwargs,expected",
@@ -287,7 +286,6 @@ def test_na_values_keep_default(all_parsers, kwargs, expected):
287286
tm.assert_frame_equal(result, expected)
288287

289288

290-
@skip_pyarrow
291289
def test_no_na_values_no_keep_default(all_parsers):
292290
# see gh-4318: passing na_values=None and
293291
# keep_default_na=False yields 'None" as a na_value
@@ -314,6 +312,7 @@ def test_no_na_values_no_keep_default(all_parsers):
314312
tm.assert_frame_equal(result, expected)
315313

316314

315+
# TODO: Blocked on na_values dict support in pyarrow
317316
@skip_pyarrow
318317
def test_no_keep_default_na_dict_na_values(all_parsers):
319318
# see gh-19227
@@ -326,6 +325,7 @@ def test_no_keep_default_na_dict_na_values(all_parsers):
326325
tm.assert_frame_equal(result, expected)
327326

328327

328+
# TODO: Blocked on na_values dict support in pyarrow
329329
@skip_pyarrow
330330
def test_no_keep_default_na_dict_na_scalar_values(all_parsers):
331331
# see gh-19227
@@ -338,6 +338,7 @@ def test_no_keep_default_na_dict_na_scalar_values(all_parsers):
338338
tm.assert_frame_equal(df, expected)
339339

340340

341+
# TODO: Blocked on na_values dict support in pyarrow
341342
@skip_pyarrow
342343
@pytest.mark.parametrize("col_zero_na_values", [113125, "113125"])
343344
def test_no_keep_default_na_dict_na_values_diff_reprs(all_parsers, col_zero_na_values):
@@ -368,6 +369,7 @@ def test_no_keep_default_na_dict_na_values_diff_reprs(all_parsers, col_zero_na_v
368369
tm.assert_frame_equal(result, expected)
369370

370371

372+
# TODO: Empty null_values doesn't work properly on pyarrow
371373
@skip_pyarrow
372374
@pytest.mark.parametrize(
373375
"na_filter,row_data",
@@ -390,6 +392,7 @@ def test_na_values_na_filter_override(all_parsers, na_filter, row_data):
390392
tm.assert_frame_equal(result, expected)
391393

392394

395+
# TODO: Arrow parse error
393396
@skip_pyarrow
394397
def test_na_trailing_columns(all_parsers):
395398
parser = all_parsers
@@ -418,6 +421,7 @@ def test_na_trailing_columns(all_parsers):
418421
tm.assert_frame_equal(result, expected)
419422

420423

424+
# TODO: xfail the na_values dict case
421425
@skip_pyarrow
422426
@pytest.mark.parametrize(
423427
"na_values,row_data",
@@ -495,6 +499,7 @@ def test_empty_na_values_no_default_with_index(all_parsers):
495499
tm.assert_frame_equal(result, expected)
496500

497501

502+
# TODO: Missing support for na_filter keyword
498503
@skip_pyarrow
499504
@pytest.mark.parametrize(
500505
"na_filter,index_data", [(False, ["", "5"]), (True, [np.nan, 5.0])]

pandas/tests/io/parser/test_parse_dates.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -1252,19 +1252,7 @@ def test_bad_date_parse(all_parsers, cache_dates, value):
12521252
parser = all_parsers
12531253
s = StringIO((f"{value},\n") * 50000)
12541254

1255-
if parser.engine == "pyarrow" and not cache_dates:
1256-
# None in input gets converted to 'None', for which
1257-
# pandas tries to guess the datetime format, triggering
1258-
# the warning. TODO: parse dates directly in pyarrow, see
1259-
# https://github.com/pandas-dev/pandas/issues/48017
1260-
warn = UserWarning
1261-
else:
1262-
# Note: warning is not raised if 'cache_dates', because here there is only a
1263-
# single unique date and hence no risk of inconsistent parsing.
1264-
warn = None
1265-
parser.read_csv_check_warnings(
1266-
warn,
1267-
"Could not infer format",
1255+
parser.read_csv(
12681256
s,
12691257
header=None,
12701258
names=["foo", "bar"],

0 commit comments

Comments
 (0)