From c01b207797b9bf9d2f3de03c9b11671ceac3d328 Mon Sep 17 00:00:00 2001 From: Robin Wilson Date: Mon, 6 Dec 2021 20:57:21 +0000 Subject: [PATCH 1/2] Initial implementation of support for SQLAlchemy 2.0. Currently tested with SQLAlchemy 1.4 using the 'future=True' flag to an engine. --- pandas/io/sql.py | 110 ++++++++++++++++++++++++++--------------------- 1 file changed, 62 insertions(+), 48 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 26869a660f4b4..5e1a8844b748a 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -22,6 +22,8 @@ ) import warnings +from sqlalchemy.sql.expression import text + import numpy as np import pandas._libs.lib as lib @@ -66,9 +68,15 @@ def _gt14() -> bool: return Version(sqlalchemy.__version__) >= Version("1.4.0") -def _convert_params(sql, params): +def _convert_params(sql, params, sqlalchemy=False): """Convert SQL and params args to DBAPI2.0 compliant format.""" - args = [sql] + if sqlalchemy: + if isinstance(sql, str): + args = [text(sql)] + else: + args = [sql] + else: + args = [sql] if params is not None: if hasattr(params, "keys"): # test if params is a mapping args += [params] @@ -180,7 +188,11 @@ def execute(sql, con, params=None): """ pandas_sql = pandasSQL_builder(con) args = _convert_params(sql, params) - return pandas_sql.execute(*args) + if isinstance(pandas_sql, SQLiteDatabase): + return pandas_sql.execute(*args) + else: + with pandas_sql.run_transaction() as conn: + return pandas_sql.execute(conn, *args) # ----------------------------------------------------------------------------- @@ -963,29 +975,30 @@ def read(self, coerce_float=True, parse_dates=None, columns=None, chunksize=None else: sql_select = select(self.table) if _gt14() else self.table.select() - result = self.pd_sql.execute(sql_select) - column_names = result.keys() - - if chunksize is not None: - return self._query_iterator( - result, - chunksize, - column_names, - coerce_float=coerce_float, - parse_dates=parse_dates, - ) - else: - data = result.fetchall() - self.frame = DataFrame.from_records( - data, columns=column_names, coerce_float=coerce_float - ) + with self.pd_sql.run_transaction() as conn: + result = self.pd_sql.execute(conn, sql_select) + column_names = result.keys() + + if chunksize is not None: + return self._query_iterator( + result, + chunksize, + column_names, + coerce_float=coerce_float, + parse_dates=parse_dates, + ) + else: + data = result.fetchall() + self.frame = DataFrame.from_records( + data, columns=column_names, coerce_float=coerce_float + ) - self._harmonize_columns(parse_dates=parse_dates) + self._harmonize_columns(parse_dates=parse_dates) - if self.index is not None: - self.frame.set_index(self.index, inplace=True) + if self.index is not None: + self.frame.set_index(self.index, inplace=True) - return self.frame + return self.frame def _index_name(self, index, index_label): # for writing: index=True to include index in sql table @@ -1367,9 +1380,9 @@ def run_transaction(self): else: yield self.connectable - def execute(self, *args, **kwargs): + def execute(self, conn, *args, **kwargs): """Simple passthrough to SQLAlchemy connectable""" - return self.connectable.execution_options().execute(*args, **kwargs) + return conn.execute(*args, **kwargs) def read_table( self, @@ -1524,30 +1537,31 @@ def read_query( """ args = _convert_params(sql, params) - result = self.execute(*args) - columns = result.keys() + with self.run_transaction() as conn: + result = conn.execute(*args) + columns = result.keys() - if chunksize is not None: - return self._query_iterator( - result, - chunksize, - columns, - index_col=index_col, - coerce_float=coerce_float, - parse_dates=parse_dates, - dtype=dtype, - ) - else: - data = result.fetchall() - frame = _wrap_result( - data, - columns, - index_col=index_col, - coerce_float=coerce_float, - parse_dates=parse_dates, - dtype=dtype, - ) - return frame + if chunksize is not None: + return self._query_iterator( + result, + chunksize, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + ) + else: + data = result.fetchall() + frame = _wrap_result( + data, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + ) + return frame read_sql = read_query From 921bcd4e65ae8319c55211beb9099fae435bd569 Mon Sep 17 00:00:00 2001 From: Robin Wilson Date: Mon, 6 Dec 2021 21:11:04 +0000 Subject: [PATCH 2/2] Run pre-commit hooks --- pandas/io/sql.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 5e1a8844b748a..893bff116ad22 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -22,9 +22,8 @@ ) import warnings -from sqlalchemy.sql.expression import text - import numpy as np +from sqlalchemy.sql.expression import text import pandas._libs.lib as lib from pandas._typing import DtypeArg