Skip to content

ENH: Add dtype_backend support to read_sql #50985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ to select the nullable dtypes implementation.
* :func:`read_html`
* :func:`read_xml`
* :func:`read_json`
* :func:`read_sql`
* :func:`read_sql_query`
* :func:`read_sql_table`
* :func:`read_parquet`
* :func:`read_orc`
* :func:`read_feather`
Expand Down
39 changes: 39 additions & 0 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
DataFrame,
Series,
)
from pandas.core.arrays import ArrowExtensionArray
from pandas.core.base import PandasObject
import pandas.core.common as com
from pandas.core.internals.construction import convert_object_array
Expand Down Expand Up @@ -155,6 +156,12 @@ def _convert_arrays_to_dataframe(
coerce_float=coerce_float,
use_nullable_dtypes=use_nullable_dtypes,
)
dtype_backend = get_option("mode.dtype_backend")
if dtype_backend == "pyarrow":
pa = import_optional_dependency("pyarrow")
arrays = [
ArrowExtensionArray(pa.array(arr, from_pandas=True)) for arr in arrays
]
if arrays:
return DataFrame(dict(zip(columns, arrays)))
else:
Expand Down Expand Up @@ -303,6 +310,14 @@ def read_sql_table(
set to True, nullable dtypes are used for all dtypes that have a nullable
implementation, even if no nulls are present.

.. note::

The nullable dtype implementation can be configured by calling
``pd.set_option("mode.dtype_backend", "pandas")`` to use
numpy-backed nullable dtypes or
``pd.set_option("mode.dtype_backend", "pyarrow")`` to use
pyarrow-backed nullable dtypes (using ``pd.ArrowDtype``).

.. versionadded:: 2.0

Returns
Expand Down Expand Up @@ -438,6 +453,14 @@ def read_sql_query(
set to True, nullable dtypes are used for all dtypes that have a nullable
implementation, even if no nulls are present.

.. note::

The nullable dtype implementation can be configured by calling
``pd.set_option("mode.dtype_backend", "pandas")`` to use
numpy-backed nullable dtypes or
``pd.set_option("mode.dtype_backend", "pyarrow")`` to use
pyarrow-backed nullable dtypes (using ``pd.ArrowDtype``).

.. versionadded:: 2.0

Returns
Expand Down Expand Up @@ -568,6 +591,14 @@ def read_sql(
set to True, nullable dtypes are used for all dtypes that have a nullable
implementation, even if no nulls are present.

.. note::

The nullable dtype implementation can be configured by calling
``pd.set_option("mode.dtype_backend", "pandas")`` to use
numpy-backed nullable dtypes or
``pd.set_option("mode.dtype_backend", "pyarrow")`` to use
pyarrow-backed nullable dtypes (using ``pd.ArrowDtype``).

.. versionadded:: 2.0
dtype : Type name or dict of columns
Data type for data or columns. E.g. np.float64 or
Expand Down Expand Up @@ -1609,6 +1640,14 @@ def read_table(
set to True, nullable dtypes are used for all dtypes that have a nullable
implementation, even if no nulls are present.

.. note::

The nullable dtype implementation can be configured by calling
``pd.set_option("mode.dtype_backend", "pandas")`` to use
numpy-backed nullable dtypes or
``pd.set_option("mode.dtype_backend", "pyarrow")`` to use
pyarrow-backed nullable dtypes (using ``pd.ArrowDtype``).

.. versionadded:: 2.0

Returns
Expand Down
116 changes: 76 additions & 40 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2293,61 +2293,73 @@ def test_get_engine_auto_error_message(self):

@pytest.mark.parametrize("option", [True, False])
@pytest.mark.parametrize("func", ["read_sql", "read_sql_query"])
def test_read_sql_nullable_dtypes(self, string_storage, func, option):
@pytest.mark.parametrize("dtype_backend", ["pandas", "pyarrow"])
def test_read_sql_nullable_dtypes(
self, string_storage, func, option, dtype_backend
):
# GH#50048
table = "test"
df = self.nullable_data()
df.to_sql(table, self.conn, index=False, if_exists="replace")

with pd.option_context("mode.string_storage", string_storage):
if option:
with pd.option_context("mode.nullable_dtypes", True):
result = getattr(pd, func)(f"Select * from {table}", self.conn)
else:
result = getattr(pd, func)(
f"Select * from {table}", self.conn, use_nullable_dtypes=True
)
expected = self.nullable_expected(string_storage)
with pd.option_context("mode.dtype_backend", dtype_backend):
if option:
with pd.option_context("mode.nullable_dtypes", True):
result = getattr(pd, func)(f"Select * from {table}", self.conn)
else:
result = getattr(pd, func)(
f"Select * from {table}", self.conn, use_nullable_dtypes=True
)
expected = self.nullable_expected(string_storage, dtype_backend)
tm.assert_frame_equal(result, expected)

with pd.option_context("mode.string_storage", string_storage):
iterator = getattr(pd, func)(
f"Select * from {table}",
self.conn,
use_nullable_dtypes=True,
chunksize=3,
)
expected = self.nullable_expected(string_storage)
for result in iterator:
tm.assert_frame_equal(result, expected)
with pd.option_context("mode.dtype_backend", dtype_backend):
iterator = getattr(pd, func)(
f"Select * from {table}",
self.conn,
use_nullable_dtypes=True,
chunksize=3,
)
expected = self.nullable_expected(string_storage, dtype_backend)
for result in iterator:
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize("option", [True, False])
@pytest.mark.parametrize("func", ["read_sql", "read_sql_table"])
def test_read_sql_nullable_dtypes_table(self, string_storage, func, option):
@pytest.mark.parametrize("dtype_backend", ["pandas", "pyarrow"])
def test_read_sql_nullable_dtypes_table(
self, string_storage, func, option, dtype_backend
):
# GH#50048
table = "test"
df = self.nullable_data()
df.to_sql(table, self.conn, index=False, if_exists="replace")

with pd.option_context("mode.string_storage", string_storage):
if option:
with pd.option_context("mode.nullable_dtypes", True):
result = getattr(pd, func)(table, self.conn)
else:
result = getattr(pd, func)(table, self.conn, use_nullable_dtypes=True)
expected = self.nullable_expected(string_storage)
with pd.option_context("mode.dtype_backend", dtype_backend):
if option:
with pd.option_context("mode.nullable_dtypes", True):
result = getattr(pd, func)(table, self.conn)
else:
result = getattr(pd, func)(
table, self.conn, use_nullable_dtypes=True
)
expected = self.nullable_expected(string_storage, dtype_backend)
tm.assert_frame_equal(result, expected)

with pd.option_context("mode.string_storage", string_storage):
iterator = getattr(pd, func)(
table,
self.conn,
use_nullable_dtypes=True,
chunksize=3,
)
expected = self.nullable_expected(string_storage)
for result in iterator:
tm.assert_frame_equal(result, expected)
with pd.option_context("mode.dtype_backend", dtype_backend):
iterator = getattr(pd, func)(
table,
self.conn,
use_nullable_dtypes=True,
chunksize=3,
)
expected = self.nullable_expected(string_storage, dtype_backend)
for result in iterator:
tm.assert_frame_equal(result, expected)

def nullable_data(self) -> DataFrame:
return DataFrame(
Expand All @@ -2363,7 +2375,7 @@ def nullable_data(self) -> DataFrame:
}
)

def nullable_expected(self, storage) -> DataFrame:
def nullable_expected(self, storage, dtype_backend) -> DataFrame:

string_array: StringArray | ArrowStringArray
string_array_na: StringArray | ArrowStringArray
Expand All @@ -2376,7 +2388,7 @@ def nullable_expected(self, storage) -> DataFrame:
string_array = ArrowStringArray(pa.array(["a", "b", "c"]))
string_array_na = ArrowStringArray(pa.array(["a", "b", None]))

return DataFrame(
df = DataFrame(
{
"a": Series([1, np.nan, 3], dtype="Int64"),
"b": Series([1, 2, 3], dtype="Int64"),
Expand All @@ -2388,6 +2400,18 @@ def nullable_expected(self, storage) -> DataFrame:
"h": string_array_na,
}
)
if dtype_backend == "pyarrow":
pa = pytest.importorskip("pyarrow")

from pandas.arrays import ArrowExtensionArray

df = DataFrame(
{
col: ArrowExtensionArray(pa.array(df[col], from_pandas=True))
for col in df.columns
}
)
return df

def test_chunksize_empty_dtypes(self):
# GH#50245
Expand Down Expand Up @@ -2511,8 +2535,14 @@ class Test(BaseModel):

assert list(df.columns) == ["id", "string_column"]

def nullable_expected(self, storage) -> DataFrame:
return super().nullable_expected(storage).astype({"e": "Int64", "f": "Int64"})
def nullable_expected(self, storage, dtype_backend) -> DataFrame:
    """Expected frame for nullable-dtype round-trips on this backend.

    Columns "e" and "f" come back as integers for this driver, so the
    base-class expectation is cast to integer dtypes matching the
    selected ``dtype_backend`` ("pandas" → masked Int64, otherwise
    pyarrow-backed int64).
    """
    int_casts = (
        {"e": "Int64", "f": "Int64"}
        if dtype_backend == "pandas"
        else {"e": "int64[pyarrow]", "f": "int64[pyarrow]"}
    )
    return super().nullable_expected(storage, dtype_backend).astype(int_casts)

@pytest.mark.parametrize("func", ["read_sql", "read_sql_table"])
def test_read_sql_nullable_dtypes_table(self, string_storage, func):
Expand Down Expand Up @@ -2546,8 +2576,14 @@ def setup_driver(cls):
def test_default_type_conversion(self):
pass

def nullable_expected(self, storage) -> DataFrame:
return super().nullable_expected(storage).astype({"e": "Int64", "f": "Int64"})
def nullable_expected(self, storage, dtype_backend) -> DataFrame:
    """Expected frame for nullable-dtype round-trips on this backend.

    Overrides the base expectation: columns "e" and "f" round-trip as
    integers here, cast to the dtype family implied by ``dtype_backend``.
    """
    base = super().nullable_expected(storage, dtype_backend)
    if dtype_backend == "pandas":
        return base.astype({"e": "Int64", "f": "Int64"})
    return base.astype({"e": "int64[pyarrow]", "f": "int64[pyarrow]"})


@pytest.mark.db
Expand Down