Skip to content

Commit 12d6f60

Browse files
REF: centralize pyarrow Table to pandas conversions and types_mapper handling (#60324)
1 parent ee3c18f commit 12d6f60

File tree

8 files changed

+63
-122
lines changed

8 files changed

+63
-122
lines changed

pandas/io/_util.py

+40-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import (
4+
TYPE_CHECKING,
5+
Literal,
6+
)
47

58
import numpy as np
69

10+
from pandas._config import using_string_dtype
11+
12+
from pandas._libs import lib
713
from pandas.compat import pa_version_under18p0
814
from pandas.compat._optional import import_optional_dependency
915

@@ -12,6 +18,10 @@
1218
if TYPE_CHECKING:
1319
from collections.abc import Callable
1420

21+
import pyarrow
22+
23+
from pandas._typing import DtypeBackend
24+
1525

1626
def _arrow_dtype_mapping() -> dict:
1727
pa = import_optional_dependency("pyarrow")
@@ -33,7 +43,7 @@ def _arrow_dtype_mapping() -> dict:
3343
}
3444

3545

36-
def arrow_string_types_mapper() -> Callable:
46+
def _arrow_string_types_mapper() -> Callable:
3747
pa = import_optional_dependency("pyarrow")
3848

3949
mapping = {
@@ -44,3 +54,31 @@ def arrow_string_types_mapper() -> Callable:
4454
mapping[pa.string_view()] = pd.StringDtype(na_value=np.nan)
4555

4656
return mapping.get
57+
58+
59+
def arrow_table_to_pandas(
    table: pyarrow.Table,
    dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault = lib.no_default,
    null_to_int64: bool = False,
) -> pd.DataFrame:
    """
    Convert a pyarrow Table to a pandas DataFrame, honoring ``dtype_backend``.

    Centralizes the ``types_mapper`` selection that the pyarrow-based IO
    readers previously duplicated.

    Parameters
    ----------
    table : pyarrow.Table
        The Arrow table to convert.
    dtype_backend : DtypeBackend, "numpy", or lib.no_default
        Which pandas dtypes to map Arrow types to.  With ``lib.no_default``
        or ``"numpy"``, no explicit mapper is passed unless the string
        dtype option is enabled.
    null_to_int64 : bool, default False
        Only honored for the ``"numpy_nullable"`` backend: additionally map
        ``pa.null()`` to ``Int64Dtype`` (to match other engines — used by
        the CSV parser).

    Returns
    -------
    DataFrame

    Raises
    ------
    NotImplementedError
        If ``dtype_backend`` is none of the recognized values.
    """
    pa = import_optional_dependency("pyarrow")

    types_mapper: type[pd.ArrowDtype] | None | Callable
    if dtype_backend == "pyarrow":
        types_mapper = pd.ArrowDtype
    elif dtype_backend == "numpy_nullable":
        dtype_map = _arrow_dtype_mapping()
        if null_to_int64:
            # Modify the default mapping to also map null to Int64
            # (to match other engines - only for CSV parser)
            dtype_map[pa.null()] = pd.Int64Dtype()
        types_mapper = dtype_map.get
    elif using_string_dtype():
        # NOTE: checked before the no_default/"numpy" case on purpose —
        # with the string option enabled, the default conversion uses the
        # string-dtype mapper.
        types_mapper = _arrow_string_types_mapper()
    elif dtype_backend is lib.no_default or dtype_backend == "numpy":
        types_mapper = None
    else:
        raise NotImplementedError

    return table.to_pandas(types_mapper=types_mapper)

pandas/io/feather_format.py

+2-15
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515
from pandas.util._decorators import doc
1616
from pandas.util._validators import check_dtype_backend
1717

18-
import pandas as pd
1918
from pandas.core.api import DataFrame
2019
from pandas.core.shared_docs import _shared_docs
2120

22-
from pandas.io._util import arrow_string_types_mapper
21+
from pandas.io._util import arrow_table_to_pandas
2322
from pandas.io.common import get_handle
2423

2524
if TYPE_CHECKING:
@@ -147,16 +146,4 @@ def read_feather(
147146
pa_table = feather.read_table(
148147
handles.handle, columns=columns, use_threads=bool(use_threads)
149148
)
150-
151-
if dtype_backend == "numpy_nullable":
152-
from pandas.io._util import _arrow_dtype_mapping
153-
154-
return pa_table.to_pandas(types_mapper=_arrow_dtype_mapping().get)
155-
156-
elif dtype_backend == "pyarrow":
157-
return pa_table.to_pandas(types_mapper=pd.ArrowDtype)
158-
159-
elif using_string_dtype():
160-
return pa_table.to_pandas(types_mapper=arrow_string_types_mapper())
161-
else:
162-
raise NotImplementedError
149+
return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)

pandas/io/json/_json.py

+2-13
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from pandas.core.dtypes.dtypes import PeriodDtype
3737

3838
from pandas import (
39-
ArrowDtype,
4039
DataFrame,
4140
Index,
4241
MultiIndex,
@@ -48,6 +47,7 @@
4847
from pandas.core.reshape.concat import concat
4948
from pandas.core.shared_docs import _shared_docs
5049

50+
from pandas.io._util import arrow_table_to_pandas
5151
from pandas.io.common import (
5252
IOHandles,
5353
dedup_names,
@@ -940,18 +940,7 @@ def read(self) -> DataFrame | Series:
940940
if self.engine == "pyarrow":
941941
pyarrow_json = import_optional_dependency("pyarrow.json")
942942
pa_table = pyarrow_json.read_json(self.data)
943-
944-
mapping: type[ArrowDtype] | None | Callable
945-
if self.dtype_backend == "pyarrow":
946-
mapping = ArrowDtype
947-
elif self.dtype_backend == "numpy_nullable":
948-
from pandas.io._util import _arrow_dtype_mapping
949-
950-
mapping = _arrow_dtype_mapping().get
951-
else:
952-
mapping = None
953-
954-
return pa_table.to_pandas(types_mapper=mapping)
943+
return arrow_table_to_pandas(pa_table, dtype_backend=self.dtype_backend)
955944
elif self.engine == "ujson":
956945
if self.lines:
957946
if self.chunksize:

pandas/io/orc.py

+2-19
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,13 @@
99
Literal,
1010
)
1111

12-
from pandas._config import using_string_dtype
13-
1412
from pandas._libs import lib
1513
from pandas.compat._optional import import_optional_dependency
1614
from pandas.util._validators import check_dtype_backend
1715

18-
import pandas as pd
1916
from pandas.core.indexes.api import default_index
2017

21-
from pandas.io._util import arrow_string_types_mapper
18+
from pandas.io._util import arrow_table_to_pandas
2219
from pandas.io.common import (
2320
get_handle,
2421
is_fsspec_url,
@@ -127,21 +124,7 @@ def read_orc(
127124
pa_table = orc.read_table(
128125
source=source, columns=columns, filesystem=filesystem, **kwargs
129126
)
130-
if dtype_backend is not lib.no_default:
131-
if dtype_backend == "pyarrow":
132-
df = pa_table.to_pandas(types_mapper=pd.ArrowDtype)
133-
else:
134-
from pandas.io._util import _arrow_dtype_mapping
135-
136-
mapping = _arrow_dtype_mapping()
137-
df = pa_table.to_pandas(types_mapper=mapping.get)
138-
return df
139-
else:
140-
if using_string_dtype():
141-
types_mapper = arrow_string_types_mapper()
142-
else:
143-
types_mapper = None
144-
return pa_table.to_pandas(types_mapper=types_mapper)
127+
return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)
145128

146129

147130
def to_orc(

pandas/io/parquet.py

+2-16
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,19 @@
1515
filterwarnings,
1616
)
1717

18-
from pandas._config import using_string_dtype
19-
2018
from pandas._libs import lib
2119
from pandas.compat._optional import import_optional_dependency
2220
from pandas.errors import AbstractMethodError
2321
from pandas.util._decorators import doc
2422
from pandas.util._validators import check_dtype_backend
2523

26-
import pandas as pd
2724
from pandas import (
2825
DataFrame,
2926
get_option,
3027
)
3128
from pandas.core.shared_docs import _shared_docs
3229

33-
from pandas.io._util import arrow_string_types_mapper
30+
from pandas.io._util import arrow_table_to_pandas
3431
from pandas.io.common import (
3532
IOHandles,
3633
get_handle,
@@ -249,17 +246,6 @@ def read(
249246
) -> DataFrame:
250247
kwargs["use_pandas_metadata"] = True
251248

252-
to_pandas_kwargs = {}
253-
if dtype_backend == "numpy_nullable":
254-
from pandas.io._util import _arrow_dtype_mapping
255-
256-
mapping = _arrow_dtype_mapping()
257-
to_pandas_kwargs["types_mapper"] = mapping.get
258-
elif dtype_backend == "pyarrow":
259-
to_pandas_kwargs["types_mapper"] = pd.ArrowDtype # type: ignore[assignment]
260-
elif using_string_dtype():
261-
to_pandas_kwargs["types_mapper"] = arrow_string_types_mapper()
262-
263249
path_or_handle, handles, filesystem = _get_path_or_handle(
264250
path,
265251
filesystem,
@@ -280,7 +266,7 @@ def read(
280266
"make_block is deprecated",
281267
DeprecationWarning,
282268
)
283-
result = pa_table.to_pandas(**to_pandas_kwargs)
269+
result = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)
284270

285271
if pa_table.schema.metadata:
286272
if b"PANDAS_ATTRS" in pa_table.schema.metadata:

pandas/io/parsers/arrow_parser_wrapper.py

+6-21
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from typing import TYPE_CHECKING
44
import warnings
55

6-
from pandas._config import using_string_dtype
7-
86
from pandas._libs import lib
97
from pandas.compat._optional import import_optional_dependency
108
from pandas.errors import (
@@ -16,18 +14,14 @@
1614
from pandas.core.dtypes.common import pandas_dtype
1715
from pandas.core.dtypes.inference import is_integer
1816

19-
import pandas as pd
20-
from pandas import DataFrame
21-
22-
from pandas.io._util import (
23-
_arrow_dtype_mapping,
24-
arrow_string_types_mapper,
25-
)
17+
from pandas.io._util import arrow_table_to_pandas
2618
from pandas.io.parsers.base_parser import ParserBase
2719

2820
if TYPE_CHECKING:
2921
from pandas._typing import ReadBuffer
3022

23+
from pandas import DataFrame
24+
3125

3226
class ArrowParserWrapper(ParserBase):
3327
"""
@@ -293,17 +287,8 @@ def read(self) -> DataFrame:
293287
"make_block is deprecated",
294288
DeprecationWarning,
295289
)
296-
if dtype_backend == "pyarrow":
297-
frame = table.to_pandas(types_mapper=pd.ArrowDtype)
298-
elif dtype_backend == "numpy_nullable":
299-
# Modify the default mapping to also
300-
# map null to Int64 (to match other engines)
301-
dtype_mapping = _arrow_dtype_mapping()
302-
dtype_mapping[pa.null()] = pd.Int64Dtype()
303-
frame = table.to_pandas(types_mapper=dtype_mapping.get)
304-
elif using_string_dtype():
305-
frame = table.to_pandas(types_mapper=arrow_string_types_mapper())
290+
frame = arrow_table_to_pandas(
291+
table, dtype_backend=dtype_backend, null_to_int64=True
292+
)
306293

307-
else:
308-
frame = table.to_pandas()
309294
return self._finalize_pandas_output(frame)

pandas/io/sql.py

+7-34
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,7 @@
4848
is_object_dtype,
4949
is_string_dtype,
5050
)
51-
from pandas.core.dtypes.dtypes import (
52-
ArrowDtype,
53-
DatetimeTZDtype,
54-
)
51+
from pandas.core.dtypes.dtypes import DatetimeTZDtype
5552
from pandas.core.dtypes.missing import isna
5653

5754
from pandas import get_option
@@ -67,6 +64,8 @@
6764
from pandas.core.internals.construction import convert_object_array
6865
from pandas.core.tools.datetimes import to_datetime
6966

67+
from pandas.io._util import arrow_table_to_pandas
68+
7069
if TYPE_CHECKING:
7170
from collections.abc import (
7271
Callable,
@@ -2208,23 +2207,10 @@ def read_table(
22082207
else:
22092208
stmt = f"SELECT {select_list} FROM {table_name}"
22102209

2211-
mapping: type[ArrowDtype] | None | Callable
2212-
if dtype_backend == "pyarrow":
2213-
mapping = ArrowDtype
2214-
elif dtype_backend == "numpy_nullable":
2215-
from pandas.io._util import _arrow_dtype_mapping
2216-
2217-
mapping = _arrow_dtype_mapping().get
2218-
elif using_string_dtype():
2219-
from pandas.io._util import arrow_string_types_mapper
2220-
2221-
mapping = arrow_string_types_mapper()
2222-
else:
2223-
mapping = None
2224-
22252210
with self.con.cursor() as cur:
22262211
cur.execute(stmt)
2227-
df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping)
2212+
pa_table = cur.fetch_arrow_table()
2213+
df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)
22282214

22292215
return _wrap_result_adbc(
22302216
df,
@@ -2292,23 +2278,10 @@ def read_query(
22922278
if chunksize:
22932279
raise NotImplementedError("'chunksize' is not implemented for ADBC drivers")
22942280

2295-
mapping: type[ArrowDtype] | None | Callable
2296-
if dtype_backend == "pyarrow":
2297-
mapping = ArrowDtype
2298-
elif dtype_backend == "numpy_nullable":
2299-
from pandas.io._util import _arrow_dtype_mapping
2300-
2301-
mapping = _arrow_dtype_mapping().get
2302-
elif using_string_dtype():
2303-
from pandas.io._util import arrow_string_types_mapper
2304-
2305-
mapping = arrow_string_types_mapper()
2306-
else:
2307-
mapping = None
2308-
23092281
with self.con.cursor() as cur:
23102282
cur.execute(sql)
2311-
df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping)
2283+
pa_table = cur.fetch_arrow_table()
2284+
df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)
23122285

23132286
return _wrap_result_adbc(
23142287
df,

pandas/tests/io/test_sql.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -959,12 +959,12 @@ def sqlite_buildin_types(sqlite_buildin, types_data):
959959

960960
adbc_connectable_iris = [
961961
pytest.param("postgresql_adbc_iris", marks=pytest.mark.db),
962-
pytest.param("sqlite_adbc_iris", marks=pytest.mark.db),
962+
"sqlite_adbc_iris",
963963
]
964964

965965
adbc_connectable_types = [
966966
pytest.param("postgresql_adbc_types", marks=pytest.mark.db),
967-
pytest.param("sqlite_adbc_types", marks=pytest.mark.db),
967+
"sqlite_adbc_types",
968968
]
969969

970970

0 commit comments

Comments (0)