From eb80d52d28b5dd8e21e58cb5887ddf78ee4c9d5d Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Wed, 15 Mar 2023 13:58:14 +0100 Subject: [PATCH] BUG: Add numpy_nullable support to arrow csv parser --- pandas/io/parsers/arrow_parser_wrapper.py | 3 +++ pandas/tests/io/parser/dtypes/test_dtypes_basic.py | 9 ++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/pandas/io/parsers/arrow_parser_wrapper.py b/pandas/io/parsers/arrow_parser_wrapper.py index b98a31e3f940b..30fc65dca7ca1 100644 --- a/pandas/io/parsers/arrow_parser_wrapper.py +++ b/pandas/io/parsers/arrow_parser_wrapper.py @@ -9,6 +9,7 @@ import pandas as pd from pandas import DataFrame +from pandas.io._util import _arrow_dtype_mapping from pandas.io.parsers.base_parser import ParserBase if TYPE_CHECKING: @@ -151,6 +152,8 @@ def read(self) -> DataFrame: ) if self.kwds["dtype_backend"] == "pyarrow": frame = table.to_pandas(types_mapper=pd.ArrowDtype) + elif self.kwds["dtype_backend"] == "numpy_nullable": + frame = table.to_pandas(types_mapper=_arrow_dtype_mapping().get) else: frame = table.to_pandas() return self._finalize_pandas_output(frame) diff --git a/pandas/tests/io/parser/dtypes/test_dtypes_basic.py b/pandas/tests/io/parser/dtypes/test_dtypes_basic.py index d0e5cd02767bf..bb05b000c184f 100644 --- a/pandas/tests/io/parser/dtypes/test_dtypes_basic.py +++ b/pandas/tests/io/parser/dtypes/test_dtypes_basic.py @@ -402,7 +402,6 @@ def test_dtypes_defaultdict_invalid(all_parsers): parser.read_csv(StringIO(data), dtype=dtype) -@pytest.mark.usefixtures("pyarrow_xfail") def test_dtype_backend(all_parsers): # GH#36712 @@ -424,9 +423,13 @@ def test_dtype_backend(all_parsers): "e": pd.Series([pd.NA, 6], dtype="Int64"), "f": pd.Series([pd.NA, 7.5], dtype="Float64"), "g": pd.Series([pd.NA, True], dtype="boolean"), - "h": pd.Series([pd.NA, "a"], dtype="string"), + "h": pd.Series( + [pd.NA if parser.engine != "pyarrow" else "", "a"], dtype="string" + ), "i": pd.Series([Timestamp("2019-12-31")] * 2), - "j": pd.Series([pd.NA, pd.NA], dtype="Int64"), + "j": pd.Series( + [pd.NA, pd.NA], dtype="Int64" if parser.engine != "pyarrow" else object + ), } ) tm.assert_frame_equal(result, expected)