Skip to content

ENH: Initial implementation of support for SQLAlchemy 2.0 in read_sql etc (WIP) #44794

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 61 additions & 48 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import warnings

import numpy as np
from sqlalchemy.sql.expression import text
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will break when sqlalchemy is not installed; see pandasSQL_builder, where import_optional_dependency is used.

Maybe a better option is to move the conversion to a SQLAlchemy text() construct into SQLDatabase.execute, since that method will always use SQLAlchemy.

Alternatively, it looks like _convert_params is never called with sqlalchemy=True, so maybe just remove that code path?


import pandas._libs.lib as lib
from pandas._typing import DtypeArg
Expand Down Expand Up @@ -66,9 +67,15 @@ def _gt14() -> bool:
return Version(sqlalchemy.__version__) >= Version("1.4.0")


def _convert_params(sql, params):
def _convert_params(sql, params, sqlalchemy=False):
"""Convert SQL and params args to DBAPI2.0 compliant format."""
args = [sql]
if sqlalchemy:
if isinstance(sql, str):
args = [text(sql)]
else:
args = [sql]
else:
args = [sql]
if params is not None:
if hasattr(params, "keys"): # test if params is a mapping
args += [params]
Expand Down Expand Up @@ -180,7 +187,11 @@ def execute(sql, con, params=None):
"""
pandas_sql = pandasSQL_builder(con)
args = _convert_params(sql, params)
return pandas_sql.execute(*args)
if isinstance(pandas_sql, SQLiteDatabase):
return pandas_sql.execute(*args)
else:
with pandas_sql.run_transaction() as conn:
return pandas_sql.execute(conn, *args)


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -963,29 +974,30 @@ def read(self, coerce_float=True, parse_dates=None, columns=None, chunksize=None
else:
sql_select = select(self.table) if _gt14() else self.table.select()

result = self.pd_sql.execute(sql_select)
column_names = result.keys()

if chunksize is not None:
return self._query_iterator(
result,
chunksize,
column_names,
coerce_float=coerce_float,
parse_dates=parse_dates,
)
else:
data = result.fetchall()
self.frame = DataFrame.from_records(
data, columns=column_names, coerce_float=coerce_float
)
with self.pd_sql.run_transaction() as conn:
result = self.pd_sql.execute(conn, sql_select)
column_names = result.keys()

if chunksize is not None:
return self._query_iterator(
result,
chunksize,
column_names,
coerce_float=coerce_float,
parse_dates=parse_dates,
)
else:
data = result.fetchall()
self.frame = DataFrame.from_records(
data, columns=column_names, coerce_float=coerce_float
)

self._harmonize_columns(parse_dates=parse_dates)
self._harmonize_columns(parse_dates=parse_dates)

if self.index is not None:
self.frame.set_index(self.index, inplace=True)
if self.index is not None:
self.frame.set_index(self.index, inplace=True)

return self.frame
return self.frame

def _index_name(self, index, index_label):
# for writing: index=True to include index in sql table
Expand Down Expand Up @@ -1367,9 +1379,9 @@ def run_transaction(self):
else:
yield self.connectable

def execute(self, *args, **kwargs):
def execute(self, conn, *args, **kwargs):
"""Simple passthrough to SQLAlchemy connectable"""
return self.connectable.execution_options().execute(*args, **kwargs)
return conn.execute(*args, **kwargs)

def read_table(
self,
Expand Down Expand Up @@ -1524,30 +1536,31 @@ def read_query(
"""
args = _convert_params(sql, params)

result = self.execute(*args)
columns = result.keys()
with self.run_transaction() as conn:
result = conn.execute(*args)
columns = result.keys()

if chunksize is not None:
return self._query_iterator(
result,
chunksize,
columns,
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype=dtype,
)
else:
data = result.fetchall()
frame = _wrap_result(
data,
columns,
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype=dtype,
)
return frame
if chunksize is not None:
return self._query_iterator(
result,
chunksize,
columns,
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype=dtype,
)
else:
data = result.fetchall()
frame = _wrap_result(
data,
columns,
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype=dtype,
)
return frame

read_sql = read_query

Expand Down