Skip to content

Commit f332143

Browse files
authored
ENH: Implement io.nullable_backend config for read_csv(engine="pyarrow") (#49366)
1 parent fa41c52 commit f332143

File tree

4 files changed

+101
-20
lines changed

4 files changed

+101
-20
lines changed

doc/source/whatsnew/v2.0.0.rst

+17-4
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,24 @@ Available optional dependencies (listed in order of appearance at `install guide
2828
``[all, performance, computation, timezone, fss, aws, gcp, excel, parquet, feather, hdf5, spss, postgresql, mysql,
2929
sql-other, html, xml, plot, output_formatting, clipboard, compression, test]`` (:issue:`39164`).
3030

31-
.. _whatsnew_200.enhancements.enhancement2:
31+
.. _whatsnew_200.enhancements.io_readers_nullable_pyarrow:
3232

33-
enhancement2
34-
^^^^^^^^^^^^
33+
Configuration option, ``io.nullable_backend``, to return pyarrow-backed dtypes from IO functions
34+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
35+
36+
A new global configuration, ``io.nullable_backend``, can now be used in conjunction with the parameter ``use_nullable_dtypes=True`` in :func:`read_parquet` and :func:`read_csv` (with ``engine="pyarrow"``)
37+
to return pyarrow-backed dtypes when set to ``"pyarrow"`` (:issue:`48957`).
38+
39+
.. ipython:: python
40+
41+
import io
42+
data = io.StringIO("""a,b,c,d,e,f,g,h,i
43+
1,2.5,True,a,,,,,
44+
3,4.5,False,b,6,7.5,True,a,
45+
""")
46+
with pd.option_context("io.nullable_backend", "pyarrow"):
47+
df = pd.read_csv(data, use_nullable_dtypes=True, engine="pyarrow")
48+
df
3549
3650
.. _whatsnew_200.enhancements.other:
3751

@@ -42,7 +56,6 @@ Other enhancements
4256
- :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overridden (:issue:`47819`)
4357
- :func:`assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
4458
- Added new argument ``use_nullable_dtypes`` to :func:`read_csv` and :func:`read_excel` to enable automatic conversion to nullable dtypes (:issue:`36712`)
45-
- Added new global configuration, ``io.nullable_backend`` to allow ``use_nullable_dtypes=True`` to return pyarrow-backed dtypes when set to ``"pyarrow"`` in :func:`read_parquet` (:issue:`48957`)
4659
- Added ``index`` parameter to :meth:`DataFrame.to_dict` (:issue:`46398`)
4760
- Added metadata propagation for binary operators on :class:`DataFrame` (:issue:`28283`)
4861
- :class:`.CategoricalConversionWarning`, :class:`.InvalidComparison`, :class:`.InvalidVersion`, :class:`.LossySetitemError`, and :class:`.NoBufferPresent` are now exposed in ``pandas.errors`` (:issue:`27656`)

pandas/io/parsers/arrow_parser_wrapper.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
4-
53
from pandas._typing import ReadBuffer
64
from pandas.compat._optional import import_optional_dependency
75

86
from pandas.core.dtypes.inference import is_integer
97

10-
from pandas.io.parsers.base_parser import ParserBase
8+
from pandas import (
9+
DataFrame,
10+
arrays,
11+
get_option,
12+
)
1113

12-
if TYPE_CHECKING:
13-
from pandas import DataFrame
14+
from pandas.io.parsers.base_parser import ParserBase
1415

1516

1617
class ArrowParserWrapper(ParserBase):
@@ -77,7 +78,7 @@ def _get_pyarrow_options(self) -> None:
7778
else self.kwds["skiprows"],
7879
}
7980

80-
def _finalize_output(self, frame: DataFrame) -> DataFrame:
81+
def _finalize_pandas_output(self, frame: DataFrame) -> DataFrame:
8182
"""
8283
Processes data read in based on kwargs.
8384
@@ -148,6 +149,16 @@ def read(self) -> DataFrame:
148149
parse_options=pyarrow_csv.ParseOptions(**self.parse_options),
149150
convert_options=pyarrow_csv.ConvertOptions(**self.convert_options),
150151
)
151-
152-
frame = table.to_pandas()
153-
return self._finalize_output(frame)
152+
if (
153+
self.kwds["use_nullable_dtypes"]
154+
and get_option("io.nullable_backend") == "pyarrow"
155+
):
156+
frame = DataFrame(
157+
{
158+
col_name: arrays.ArrowExtensionArray(pa_col)
159+
for col_name, pa_col in zip(table.column_names, table.itercolumns())
160+
}
161+
)
162+
else:
163+
frame = table.to_pandas()
164+
return self._finalize_pandas_output(frame)

pandas/io/parsers/readers.py

+10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
import numpy as np
2626

27+
from pandas._config import get_option
28+
2729
from pandas._libs import lib
2830
from pandas._libs.parsers import STR_NA_VALUES
2931
from pandas._typing import (
@@ -560,6 +562,14 @@ def _read(
560562
raise ValueError(
561563
"The 'chunksize' option is not supported with the 'pyarrow' engine"
562564
)
565+
elif (
566+
kwds.get("use_nullable_dtypes", False)
567+
and get_option("io.nullable_backend") == "pyarrow"
568+
):
569+
raise NotImplementedError(
570+
f"use_nullable_dtypes=True and engine={kwds['engine']} with "
571+
"io.nullable_backend set to 'pyarrow' is not implemented."
572+
)
563573
else:
564574
chunksize = validate_integer("chunksize", chunksize, 1)
565575

pandas/tests/io/parser/dtypes/test_dtypes_basic.py

+54-7
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import pytest
1010

1111
from pandas.errors import ParserWarning
12-
import pandas.util._test_decorators as td
1312

1413
import pandas as pd
1514
from pandas import (
@@ -22,13 +21,10 @@
2221
StringArray,
2322
)
2423

25-
# TODO(1.4): Change me into xfail at release time
26-
# and xfail individual tests
27-
pytestmark = pytest.mark.usefixtures("pyarrow_skip")
28-
2924

3025
@pytest.mark.parametrize("dtype", [str, object])
3126
@pytest.mark.parametrize("check_orig", [True, False])
27+
@pytest.mark.usefixtures("pyarrow_xfail")
3228
def test_dtype_all_columns(all_parsers, dtype, check_orig):
3329
# see gh-3795, gh-6607
3430
parser = all_parsers
@@ -53,6 +49,7 @@ def test_dtype_all_columns(all_parsers, dtype, check_orig):
5349
tm.assert_frame_equal(result, expected)
5450

5551

52+
@pytest.mark.usefixtures("pyarrow_xfail")
5653
def test_dtype_per_column(all_parsers):
5754
parser = all_parsers
5855
data = """\
@@ -71,6 +68,7 @@ def test_dtype_per_column(all_parsers):
7168
tm.assert_frame_equal(result, expected)
7269

7370

71+
@pytest.mark.usefixtures("pyarrow_xfail")
7472
def test_invalid_dtype_per_column(all_parsers):
7573
parser = all_parsers
7674
data = """\
@@ -84,6 +82,7 @@ def test_invalid_dtype_per_column(all_parsers):
8482
parser.read_csv(StringIO(data), dtype={"one": "foo", 1: "int"})
8583

8684

85+
@pytest.mark.usefixtures("pyarrow_xfail")
8786
def test_raise_on_passed_int_dtype_with_nas(all_parsers):
8887
# see gh-2631
8988
parser = all_parsers
@@ -101,6 +100,7 @@ def test_raise_on_passed_int_dtype_with_nas(all_parsers):
101100
parser.read_csv(StringIO(data), dtype={"DOY": np.int64}, skipinitialspace=True)
102101

103102

103+
@pytest.mark.usefixtures("pyarrow_xfail")
104104
def test_dtype_with_converters(all_parsers):
105105
parser = all_parsers
106106
data = """a,b
@@ -132,6 +132,7 @@ def test_numeric_dtype(all_parsers, dtype):
132132
tm.assert_frame_equal(expected, result)
133133

134134

135+
@pytest.mark.usefixtures("pyarrow_xfail")
135136
def test_boolean_dtype(all_parsers):
136137
parser = all_parsers
137138
data = "\n".join(
@@ -184,6 +185,7 @@ def test_boolean_dtype(all_parsers):
184185
tm.assert_frame_equal(result, expected)
185186

186187

188+
@pytest.mark.usefixtures("pyarrow_xfail")
187189
def test_delimiter_with_usecols_and_parse_dates(all_parsers):
188190
# GH#35873
189191
result = all_parsers.read_csv(
@@ -264,6 +266,7 @@ def test_skip_whitespace(c_parser_only, float_precision):
264266
tm.assert_series_equal(df.iloc[:, 1], pd.Series([1.2, 2.1, 1.0, 1.2], name="num"))
265267

266268

269+
@pytest.mark.usefixtures("pyarrow_xfail")
267270
def test_true_values_cast_to_bool(all_parsers):
268271
# GH#34655
269272
text = """a,b
@@ -286,6 +289,7 @@ def test_true_values_cast_to_bool(all_parsers):
286289
tm.assert_frame_equal(result, expected)
287290

288291

292+
@pytest.mark.usefixtures("pyarrow_xfail")
289293
@pytest.mark.parametrize("dtypes, exp_value", [({}, "1"), ({"a.1": "int64"}, 1)])
290294
def test_dtype_mangle_dup_cols(all_parsers, dtypes, exp_value):
291295
# GH#35211
@@ -300,6 +304,7 @@ def test_dtype_mangle_dup_cols(all_parsers, dtypes, exp_value):
300304
tm.assert_frame_equal(result, expected)
301305

302306

307+
@pytest.mark.usefixtures("pyarrow_xfail")
303308
def test_dtype_mangle_dup_cols_single_dtype(all_parsers):
304309
# GH#42022
305310
parser = all_parsers
@@ -309,6 +314,7 @@ def test_dtype_mangle_dup_cols_single_dtype(all_parsers):
309314
tm.assert_frame_equal(result, expected)
310315

311316

317+
@pytest.mark.usefixtures("pyarrow_xfail")
312318
def test_dtype_multi_index(all_parsers):
313319
# GH 42446
314320
parser = all_parsers
@@ -355,6 +361,7 @@ def test_nullable_int_dtype(all_parsers, any_int_ea_dtype):
355361
tm.assert_frame_equal(actual, expected)
356362

357363

364+
@pytest.mark.usefixtures("pyarrow_xfail")
358365
@pytest.mark.parametrize("default", ["float", "float64"])
359366
def test_dtypes_defaultdict(all_parsers, default):
360367
# GH#41574
@@ -368,6 +375,7 @@ def test_dtypes_defaultdict(all_parsers, default):
368375
tm.assert_frame_equal(result, expected)
369376

370377

378+
@pytest.mark.usefixtures("pyarrow_xfail")
371379
def test_dtypes_defaultdict_mangle_dup_cols(all_parsers):
372380
# GH#41574
373381
data = """a,b,a,b,b.1
@@ -381,6 +389,7 @@ def test_dtypes_defaultdict_mangle_dup_cols(all_parsers):
381389
tm.assert_frame_equal(result, expected)
382390

383391

392+
@pytest.mark.usefixtures("pyarrow_xfail")
384393
def test_dtypes_defaultdict_invalid(all_parsers):
385394
# GH#41574
386395
data = """a,b
@@ -392,6 +401,7 @@ def test_dtypes_defaultdict_invalid(all_parsers):
392401
parser.read_csv(StringIO(data), dtype=dtype)
393402

394403

404+
@pytest.mark.usefixtures("pyarrow_xfail")
395405
def test_use_nullable_dtypes(all_parsers):
396406
# GH#36712
397407

@@ -435,11 +445,11 @@ def test_use_nullabla_dtypes_and_dtype(all_parsers):
435445
tm.assert_frame_equal(result, expected)
436446

437447

438-
@td.skip_if_no("pyarrow")
448+
@pytest.mark.usefixtures("pyarrow_xfail")
439449
@pytest.mark.parametrize("storage", ["pyarrow", "python"])
440450
def test_use_nullable_dtypes_string(all_parsers, storage):
441451
# GH#36712
442-
import pyarrow as pa
452+
pa = pytest.importorskip("pyarrow")
443453

444454
with pd.option_context("mode.string_storage", storage):
445455

@@ -477,3 +487,40 @@ def test_use_nullable_dtypes_ea_dtype_specified(all_parsers):
477487
result = parser.read_csv(StringIO(data), dtype="Int64", use_nullable_dtypes=True)
478488
expected = DataFrame({"a": [1], "b": 2}, dtype="Int64")
479489
tm.assert_frame_equal(result, expected)
490+
491+
492+
def test_use_nullable_dtypes_pyarrow_backend(all_parsers, request):
493+
# GH#36712
494+
pa = pytest.importorskip("pyarrow")
495+
parser = all_parsers
496+
497+
data = """a,b,c,d,e,f,g,h,i,j
498+
1,2.5,True,a,,,,,12-31-2019,
499+
3,4.5,False,b,6,7.5,True,a,12-31-2019,
500+
"""
501+
with pd.option_context("io.nullable_backend", "pyarrow"):
502+
if parser.engine != "pyarrow":
503+
request.node.add_marker(
504+
pytest.mark.xfail(
505+
raises=NotImplementedError,
506+
reason=f"Not implemented with engine={parser.engine}",
507+
)
508+
)
509+
result = parser.read_csv(
510+
StringIO(data), use_nullable_dtypes=True, parse_dates=["i"]
511+
)
512+
expected = DataFrame(
513+
{
514+
"a": pd.Series([1, 3], dtype="int64[pyarrow]"),
515+
"b": pd.Series([2.5, 4.5], dtype="float64[pyarrow]"),
516+
"c": pd.Series([True, False], dtype="bool[pyarrow]"),
517+
"d": pd.Series(["a", "b"], dtype=pd.ArrowDtype(pa.string())),
518+
"e": pd.Series([pd.NA, 6], dtype="int64[pyarrow]"),
519+
"f": pd.Series([pd.NA, 7.5], dtype="float64[pyarrow]"),
520+
"g": pd.Series([pd.NA, True], dtype="bool[pyarrow]"),
521+
"h": pd.Series(["", "a"], dtype=pd.ArrowDtype(pa.string())),
522+
"i": pd.Series([Timestamp("2019-12-31")] * 2),
523+
"j": pd.Series([pd.NA, pd.NA], dtype="null[pyarrow]"),
524+
}
525+
)
526+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)