Skip to content

Backport PR #60324: REF: centralize pyarrow Table to pandas conversions and types_mapper handling #60332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions pandas/io/_util.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
from __future__ import annotations

from typing import Callable
from typing import (
TYPE_CHECKING,
Literal,
)

import numpy as np

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas.compat import pa_version_under18p0
from pandas.compat._optional import import_optional_dependency

import pandas as pd

if TYPE_CHECKING:
from collections.abc import Callable

import pyarrow

from pandas._typing import DtypeBackend


def _arrow_dtype_mapping() -> dict:
pa = import_optional_dependency("pyarrow")
Expand All @@ -30,7 +43,7 @@ def _arrow_dtype_mapping() -> dict:
}


def arrow_string_types_mapper() -> Callable:
def _arrow_string_types_mapper() -> Callable:
pa = import_optional_dependency("pyarrow")

mapping = {
Expand All @@ -41,3 +54,35 @@ def arrow_string_types_mapper() -> Callable:
mapping[pa.string_view()] = pd.StringDtype(na_value=np.nan)

return mapping.get


def arrow_table_to_pandas(
    table: pyarrow.Table,
    dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault = lib.no_default,
    null_to_int64: bool = False,
    to_pandas_kwargs: dict | None = None,
) -> pd.DataFrame:
    """
    Convert a pyarrow Table to a pandas DataFrame, centralizing the
    ``types_mapper`` selection for the requested dtype backend.

    Parameters
    ----------
    table : pyarrow.Table
        The Arrow table to convert.
    dtype_backend : {"numpy_nullable", "pyarrow", "numpy"} or lib.no_default
        Which pandas dtype backend the resulting DataFrame should use.
    null_to_int64 : bool, default False
        With the "numpy_nullable" backend, additionally map pyarrow null
        columns to ``Int64`` (matches the behavior of the other CSV engines).
    to_pandas_kwargs : dict, optional
        Extra keyword arguments forwarded to ``Table.to_pandas``.

    Returns
    -------
    pd.DataFrame

    Raises
    ------
    NotImplementedError
        If ``dtype_backend`` is an unrecognized value.
    """
    kwargs = {} if to_pandas_kwargs is None else to_pandas_kwargs

    pa = import_optional_dependency("pyarrow")

    types_mapper: type[pd.ArrowDtype] | None | Callable
    if dtype_backend == "numpy_nullable":
        dtype_map = _arrow_dtype_mapping()
        if null_to_int64:
            # Extend the default mapping so null columns become Int64
            # (only the CSV parser requests this, for engine parity).
            dtype_map[pa.null()] = pd.Int64Dtype()
        types_mapper = dtype_map.get
    elif dtype_backend == "pyarrow":
        types_mapper = pd.ArrowDtype
    elif using_string_dtype():
        # No explicit backend requested: honor the global string-dtype option.
        types_mapper = _arrow_string_types_mapper()
    elif dtype_backend is lib.no_default or dtype_backend == "numpy":
        types_mapper = None
    else:
        raise NotImplementedError

    return table.to_pandas(types_mapper=types_mapper, **kwargs)
17 changes: 2 additions & 15 deletions pandas/io/feather_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
from pandas.util._decorators import doc
from pandas.util._validators import check_dtype_backend

import pandas as pd
from pandas.core.api import DataFrame
from pandas.core.shared_docs import _shared_docs

from pandas.io._util import arrow_string_types_mapper
from pandas.io._util import arrow_table_to_pandas
from pandas.io.common import get_handle

if TYPE_CHECKING:
Expand Down Expand Up @@ -128,16 +127,4 @@ def read_feather(
pa_table = feather.read_table(
handles.handle, columns=columns, use_threads=bool(use_threads)
)

if dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

return pa_table.to_pandas(types_mapper=_arrow_dtype_mapping().get)

elif dtype_backend == "pyarrow":
return pa_table.to_pandas(types_mapper=pd.ArrowDtype)

elif using_string_dtype():
return pa_table.to_pandas(types_mapper=arrow_string_types_mapper())
else:
raise NotImplementedError
return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)
15 changes: 2 additions & 13 deletions pandas/io/json/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from pandas.core.dtypes.dtypes import PeriodDtype

from pandas import (
ArrowDtype,
DataFrame,
Index,
MultiIndex,
Expand All @@ -52,6 +51,7 @@
from pandas.core.reshape.concat import concat
from pandas.core.shared_docs import _shared_docs

from pandas.io._util import arrow_table_to_pandas
from pandas.io.common import (
IOHandles,
dedup_names,
Expand Down Expand Up @@ -997,18 +997,7 @@ def read(self) -> DataFrame | Series:
if self.engine == "pyarrow":
pyarrow_json = import_optional_dependency("pyarrow.json")
pa_table = pyarrow_json.read_json(self.data)

mapping: type[ArrowDtype] | None | Callable
if self.dtype_backend == "pyarrow":
mapping = ArrowDtype
elif self.dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping().get
else:
mapping = None

return pa_table.to_pandas(types_mapper=mapping)
return arrow_table_to_pandas(pa_table, dtype_backend=self.dtype_backend)
elif self.engine == "ujson":
if self.lines:
if self.chunksize:
Expand Down
21 changes: 2 additions & 19 deletions pandas/io/orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@
Literal,
)

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas.compat._optional import import_optional_dependency
from pandas.util._validators import check_dtype_backend

import pandas as pd
from pandas.core.indexes.api import default_index

from pandas.io._util import arrow_string_types_mapper
from pandas.io._util import arrow_table_to_pandas
from pandas.io.common import (
get_handle,
is_fsspec_url,
Expand Down Expand Up @@ -117,21 +114,7 @@ def read_orc(
pa_table = orc.read_table(
source=source, columns=columns, filesystem=filesystem, **kwargs
)
if dtype_backend is not lib.no_default:
if dtype_backend == "pyarrow":
df = pa_table.to_pandas(types_mapper=pd.ArrowDtype)
else:
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping()
df = pa_table.to_pandas(types_mapper=mapping.get)
return df
else:
if using_string_dtype():
types_mapper = arrow_string_types_mapper()
else:
types_mapper = None
return pa_table.to_pandas(types_mapper=types_mapper)
return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)


def to_orc(
Expand Down
34 changes: 18 additions & 16 deletions pandas/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
Literal,
)
import warnings
from warnings import catch_warnings
from warnings import (
catch_warnings,
filterwarnings,
)

from pandas._config import using_string_dtype
from pandas._config.config import _get_option

from pandas._libs import lib
Expand All @@ -22,14 +24,13 @@
from pandas.util._exceptions import find_stack_level
from pandas.util._validators import check_dtype_backend

import pandas as pd
from pandas import (
DataFrame,
get_option,
)
from pandas.core.shared_docs import _shared_docs

from pandas.io._util import arrow_string_types_mapper
from pandas.io._util import arrow_table_to_pandas
from pandas.io.common import (
IOHandles,
get_handle,
Expand Down Expand Up @@ -250,20 +251,10 @@ def read(
kwargs["use_pandas_metadata"] = True

to_pandas_kwargs = {}
if dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping()
to_pandas_kwargs["types_mapper"] = mapping.get
elif dtype_backend == "pyarrow":
to_pandas_kwargs["types_mapper"] = pd.ArrowDtype # type: ignore[assignment]
elif using_string_dtype():
to_pandas_kwargs["types_mapper"] = arrow_string_types_mapper()

manager = _get_option("mode.data_manager", silent=True)
if manager == "array":
to_pandas_kwargs["split_blocks"] = True # type: ignore[assignment]

to_pandas_kwargs["split_blocks"] = True
path_or_handle, handles, filesystem = _get_path_or_handle(
path,
filesystem,
Expand All @@ -278,7 +269,18 @@ def read(
filters=filters,
**kwargs,
)
result = pa_table.to_pandas(**to_pandas_kwargs)

with catch_warnings():
filterwarnings(
"ignore",
"make_block is deprecated",
DeprecationWarning,
)
result = arrow_table_to_pandas(
pa_table,
dtype_backend=dtype_backend,
to_pandas_kwargs=to_pandas_kwargs,
)

if manager == "array":
result = result._as_manager("array", copy=False)
Expand Down
33 changes: 12 additions & 21 deletions pandas/io/parsers/arrow_parser_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from typing import TYPE_CHECKING
import warnings

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas.compat._optional import import_optional_dependency
from pandas.errors import (
Expand All @@ -16,18 +14,14 @@
from pandas.core.dtypes.common import pandas_dtype
from pandas.core.dtypes.inference import is_integer

import pandas as pd
from pandas import DataFrame

from pandas.io._util import (
_arrow_dtype_mapping,
arrow_string_types_mapper,
)
from pandas.io._util import arrow_table_to_pandas
from pandas.io.parsers.base_parser import ParserBase

if TYPE_CHECKING:
from pandas._typing import ReadBuffer

from pandas import DataFrame


class ArrowParserWrapper(ParserBase):
"""
Expand Down Expand Up @@ -287,17 +281,14 @@ def read(self) -> DataFrame:

table = table.cast(new_schema)

if dtype_backend == "pyarrow":
frame = table.to_pandas(types_mapper=pd.ArrowDtype)
elif dtype_backend == "numpy_nullable":
# Modify the default mapping to also
# map null to Int64 (to match other engines)
dtype_mapping = _arrow_dtype_mapping()
dtype_mapping[pa.null()] = pd.Int64Dtype()
frame = table.to_pandas(types_mapper=dtype_mapping.get)
elif using_string_dtype():
frame = table.to_pandas(types_mapper=arrow_string_types_mapper())
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
"make_block is deprecated",
DeprecationWarning,
)
frame = arrow_table_to_pandas(
table, dtype_backend=dtype_backend, null_to_int64=True
)

else:
frame = table.to_pandas()
return self._finalize_pandas_output(frame)
41 changes: 7 additions & 34 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@
is_object_dtype,
is_string_dtype,
)
from pandas.core.dtypes.dtypes import (
ArrowDtype,
DatetimeTZDtype,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.missing import isna

from pandas import get_option
Expand All @@ -68,6 +65,8 @@
from pandas.core.internals.construction import convert_object_array
from pandas.core.tools.datetimes import to_datetime

from pandas.io._util import arrow_table_to_pandas

if TYPE_CHECKING:
from collections.abc import (
Iterator,
Expand Down Expand Up @@ -2221,23 +2220,10 @@ def read_table(
else:
stmt = f"SELECT {select_list} FROM {table_name}"

mapping: type[ArrowDtype] | None | Callable
if dtype_backend == "pyarrow":
mapping = ArrowDtype
elif dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping().get
elif using_string_dtype():
from pandas.io._util import arrow_string_types_mapper

mapping = arrow_string_types_mapper()
else:
mapping = None

with self.con.cursor() as cur:
cur.execute(stmt)
df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping)
pa_table = cur.fetch_arrow_table()
df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)

return _wrap_result_adbc(
df,
Expand Down Expand Up @@ -2305,23 +2291,10 @@ def read_query(
if chunksize:
raise NotImplementedError("'chunksize' is not implemented for ADBC drivers")

mapping: type[ArrowDtype] | None | Callable
if dtype_backend == "pyarrow":
mapping = ArrowDtype
elif dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping().get
elif using_string_dtype():
from pandas.io._util import arrow_string_types_mapper

mapping = arrow_string_types_mapper()
else:
mapping = None

with self.con.cursor() as cur:
cur.execute(sql)
df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping)
pa_table = cur.fetch_arrow_table()
df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)

return _wrap_result_adbc(
df,
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,12 +959,12 @@ def sqlite_buildin_types(sqlite_buildin, types_data):

adbc_connectable_iris = [
pytest.param("postgresql_adbc_iris", marks=pytest.mark.db),
pytest.param("sqlite_adbc_iris", marks=pytest.mark.db),
"sqlite_adbc_iris",
]

adbc_connectable_types = [
pytest.param("postgresql_adbc_types", marks=pytest.mark.db),
pytest.param("sqlite_adbc_types", marks=pytest.mark.db),
"sqlite_adbc_types",
]


Expand Down
Loading