From 56248c0c919463161698b21c7b7236c6f1a6399b Mon Sep 17 00:00:00 2001 From: Chuck Cadman Date: Mon, 14 Nov 2022 23:33:04 -0800 Subject: [PATCH 1/2] DOC: Clarify behavior of DataFrame.to_sql --- pandas/core/generic.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index b90833bda82b5..997d2fa4ca924 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2751,8 +2751,11 @@ def to_sql( Using SQLAlchemy makes it possible to use any DB supported by that library. Legacy support is provided for sqlite3.Connection objects. The user is responsible for engine disposal and connection closure for the SQLAlchemy - connectable See `here \ - `_. + connectable. See `here \ + `_. + If passing a sqlalchemy.engine.Connection which is already in a transaction, + the transaction will not be committed. If passing a sqlite3.Connection, + it will not be possible to roll back the record insertion. schema : str, optional Specify the schema (if database flavor supports this). If None, use From 57a63c254cacef5a49fc24125082015a2b6ee784 Mon Sep 17 00:00:00 2001 From: Chuck Cadman Date: Thu, 17 Nov 2022 00:54:39 -0800 Subject: [PATCH 2/2] CLN: Make SQLDatabase only accept a sqlalchemy Connection. --- pandas/io/sql.py | 261 +++++++++++++++++++----------------- pandas/tests/io/test_sql.py | 58 +++----- 2 files changed, 160 insertions(+), 159 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 69160f2ca6e64..8c4f1c027fb0e 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -180,9 +180,9 @@ def execute(sql, con, params=None): ------- Results Iterable """ - pandas_sql = pandasSQL_builder(con) - args = _convert_params(sql, params) - return pandas_sql.execute(*args) + with pandasSQL_builder(con, need_transaction=True) as pandas_sql: + args = _convert_params(sql, params) + return pandas_sql.execute(*args) # ----------------------------------------------------------------------------- @@ -282,18 +282,18 @@ def read_sql_table( -------- >>> pd.read_sql_table('table_name', 'postgres:///db_name') # doctest:+SKIP """ - pandas_sql = pandasSQL_builder(con, schema=schema) - if not pandas_sql.has_table(table_name): - raise ValueError(f"Table {table_name} not found") - - table = pandas_sql.read_table( - table_name, - index_col=index_col, - coerce_float=coerce_float, - parse_dates=parse_dates, - columns=columns, - chunksize=chunksize, - ) + with pandasSQL_builder(con, schema=schema) as pandas_sql: + if not pandas_sql.has_table(table_name): + raise ValueError(f"Table {table_name} not found") + + table = pandas_sql.read_table( + table_name, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + columns=columns, + chunksize=chunksize, + ) if table is not None: return table @@ -396,16 +396,16 @@ def read_sql_query( Any datetime values with time zone information parsed via the `parse_dates` parameter will be converted to UTC. """ - pandas_sql = pandasSQL_builder(con) - return pandas_sql.read_query( - sql, - index_col=index_col, - params=params, - coerce_float=coerce_float, - parse_dates=parse_dates, - chunksize=chunksize, - dtype=dtype, - ) + with pandasSQL_builder(con) as pandas_sql: + return pandas_sql.read_query( + sql, + index_col=index_col, + params=params, + coerce_float=coerce_float, + parse_dates=parse_dates, + chunksize=chunksize, + dtype=dtype, + ) @overload @@ -561,42 +561,42 @@ def read_sql( 0 0 2012-11-10 1 1 2010-11-12 """ - pandas_sql = pandasSQL_builder(con) + with pandasSQL_builder(con) as pandas_sql: - if isinstance(pandas_sql, SQLiteDatabase): - return pandas_sql.read_query( - sql, - index_col=index_col, - params=params, - coerce_float=coerce_float, - parse_dates=parse_dates, - chunksize=chunksize, - ) + if isinstance(pandas_sql, SQLiteDatabase): + return pandas_sql.read_query( + sql, + index_col=index_col, + params=params, + coerce_float=coerce_float, + parse_dates=parse_dates, + chunksize=chunksize, + ) - try: - _is_table_name = pandas_sql.has_table(sql) - except Exception: - # using generic exception to catch errors from sql drivers (GH24988) - _is_table_name = False + try: + _is_table_name = pandas_sql.has_table(sql) + except Exception: + # using generic exception to catch errors from sql drivers (GH24988) + _is_table_name = False - if _is_table_name: - return pandas_sql.read_table( - sql, - index_col=index_col, - coerce_float=coerce_float, - parse_dates=parse_dates, - columns=columns, - chunksize=chunksize, - ) - else: - return pandas_sql.read_query( - sql, - index_col=index_col, - params=params, - coerce_float=coerce_float, - parse_dates=parse_dates, - chunksize=chunksize, - ) + if _is_table_name: + return pandas_sql.read_table( + sql, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + columns=columns, + chunksize=chunksize, + ) + else: + return pandas_sql.read_query( + sql, + index_col=index_col, + params=params, + coerce_float=coerce_float, + parse_dates=parse_dates, + chunksize=chunksize, + ) def to_sql( @@ -685,8 +685,6 @@ def to_sql( if if_exists not in ("fail", "replace", "append"): raise ValueError(f"'{if_exists}' is not valid for if_exists") - pandas_sql = pandasSQL_builder(con, schema=schema) - if isinstance(frame, Series): frame = frame.to_frame() elif not isinstance(frame, DataFrame): @@ -694,19 +692,20 @@ def to_sql( "'frame' argument should be either a Series or a DataFrame" ) - return pandas_sql.to_sql( - frame, - name, - if_exists=if_exists, - index=index, - index_label=index_label, - schema=schema, - chunksize=chunksize, - dtype=dtype, - method=method, - engine=engine, - **engine_kwargs, - ) + with pandasSQL_builder(con, schema=schema, need_transaction=True) as pandas_sql: + return pandas_sql.to_sql( + frame, + name, + if_exists=if_exists, + index=index, + index_label=index_label, + schema=schema, + chunksize=chunksize, + dtype=dtype, + method=method, + engine=engine, + **engine_kwargs, + ) def has_table(table_name: str, con, schema: str | None = None) -> bool: @@ -729,41 +728,66 @@ def has_table(table_name: str, con, schema: str | None = None) -> bool: ------- boolean """ - pandas_sql = pandasSQL_builder(con, schema=schema) - return pandas_sql.has_table(table_name) + with pandasSQL_builder(con, schema=schema) as pandas_sql: + return pandas_sql.has_table(table_name) table_exists = has_table -def pandasSQL_builder(con, schema: str | None = None) -> PandasSQL: +@contextmanager +def pandasSQL_builder( + con, + schema: str | None = None, + need_transaction: bool = False, +) -> Iterator[PandasSQL]: """ Convenience function to return the correct PandasSQL subclass based on the - provided parameters. + provided parameters. Also creates a sqlalchemy connection and transaction + if necessary. """ import sqlite3 if isinstance(con, sqlite3.Connection) or con is None: - return SQLiteDatabase(con) - - sqlalchemy = import_optional_dependency("sqlalchemy", errors="ignore") + yield SQLiteDatabase(con) + else: + sqlalchemy = import_optional_dependency("sqlalchemy", errors="ignore") - if isinstance(con, str): - if sqlalchemy is None: + if sqlalchemy is not None and isinstance( + con, (str, sqlalchemy.engine.Connectable) + ): + with _sqlalchemy_con(con, need_transaction) as con: + yield SQLDatabase(con, schema=schema) + elif isinstance(con, str) and sqlalchemy is None: raise ImportError("Using URI string without sqlalchemy installed.") - con = sqlalchemy.create_engine(con) + else: - if sqlalchemy is not None and isinstance(con, sqlalchemy.engine.Connectable): - return SQLDatabase(con, schema=schema) + warnings.warn( + "pandas only supports SQLAlchemy connectable (engine/connection) or " + "database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 " + "objects are not tested. Please consider using SQLAlchemy.", + UserWarning, + stacklevel=find_stack_level() + 2, + ) + yield SQLiteDatabase(con) + + +@contextmanager +def _sqlalchemy_con(connectable, need_transaction: bool): + """Create a sqlalchemy connection and a transaction if necessary.""" + sqlalchemy = import_optional_dependency("sqlalchemy", errors="raise") - warnings.warn( - "pandas only supports SQLAlchemy connectable (engine/connection) or " - "database string URI or sqlite3 DBAPI2 connection. " - "Other DBAPI2 objects are not tested. Please consider using SQLAlchemy.", - UserWarning, - stacklevel=find_stack_level(), - ) - return SQLiteDatabase(con) + if isinstance(connectable, str): + connectable = sqlalchemy.create_engine(connectable) + if isinstance(connectable, sqlalchemy.engine.Engine): + with connectable.connect() as con: + if need_transaction: + with con.begin(): + yield con + else: + yield con + else: + yield connectable class SQLTable(PandasObject): @@ -816,12 +840,12 @@ def exists(self): def sql_schema(self) -> str: from sqlalchemy.schema import CreateTable - return str(CreateTable(self.table).compile(self.pd_sql.connectable)) + return str(CreateTable(self.table).compile(self.pd_sql.con)) def _execute_create(self) -> None: # Inserting table into database, add to MetaData object self.table = self.table.to_metadata(self.pd_sql.meta) - self.table.create(bind=self.pd_sql.connectable) + self.table.create(bind=self.pd_sql.con) def create(self) -> None: if self.exists(): @@ -1410,8 +1434,8 @@ class SQLDatabase(PandasSQL): Parameters ---------- - engine : SQLAlchemy connectable - Connectable to connect with the database. Using SQLAlchemy makes it + con : SQLAlchemy Connection + Connection to connect with the database. Using SQLAlchemy makes it possible to use any DB supported by that library. schema : string, default None Name of SQL schema in database to write to (if database flavor @@ -1419,26 +1443,19 @@ class SQLDatabase(PandasSQL): """ - def __init__(self, engine, schema: str | None = None) -> None: + def __init__(self, con, schema: str | None = None) -> None: from sqlalchemy.schema import MetaData - self.connectable = engine + self.con = con self.meta = MetaData(schema=schema) @contextmanager def run_transaction(self): - from sqlalchemy.engine import Engine - - if isinstance(self.connectable, Engine): - with self.connectable.connect() as conn: - with conn.begin(): - yield conn - else: - yield self.connectable + yield self.con def execute(self, *args, **kwargs): """Simple passthrough to SQLAlchemy connectable""" - return self.connectable.execution_options().execute(*args, **kwargs) + return self.con.execute(*args, **kwargs) def read_table( self, @@ -1492,6 +1509,7 @@ def read_table( SQLDatabase.read_query """ + self.meta.reflect(bind=self.con, only=[table_name]) table = SQLTable(table_name, self, index=index_col, schema=schema) return table.read( coerce_float=coerce_float, @@ -1681,9 +1699,8 @@ def check_case_sensitive( # Only check when name is not a number and name is not lower case from sqlalchemy import inspect as sqlalchemy_inspect - with self.connectable.connect() as conn: - insp = sqlalchemy_inspect(conn) - table_names = insp.get_table_names(schema=schema or self.meta.schema) + insp = sqlalchemy_inspect(self.con) + table_names = insp.get_table_names(schema=schema or self.meta.schema) if name not in table_names: msg = ( f"The provided table name '{name}' is not found exactly as " @@ -1773,7 +1790,7 @@ def to_sql( total_inserted = sql_engine.insert_records( table=table, - con=self.connectable, + con=self.con, frame=frame, name=name, index=index, @@ -1790,10 +1807,10 @@ def to_sql( def tables(self): return self.meta.tables - def has_table(self, name: str, schema: str | None = None): + def has_table(self, name: str, schema: str | None = None) -> bool: from sqlalchemy import inspect as sqlalchemy_inspect - insp = sqlalchemy_inspect(self.connectable) + insp = sqlalchemy_inspect(self.con) return insp.has_table(name, schema or self.meta.schema) def get_table(self, table_name: str, schema: str | None = None) -> Table: @@ -1803,9 +1820,7 @@ def get_table(self, table_name: str, schema: str | None = None) -> Table: ) schema = schema or self.meta.schema - tbl = Table( - table_name, self.meta, autoload_with=self.connectable, schema=schema - ) + tbl = Table(table_name, self.meta, autoload_with=self.con, schema=schema) for column in tbl.columns: if isinstance(column.type, Numeric): column.type.asdecimal = False @@ -1814,8 +1829,8 @@ def get_table(self, table_name: str, schema: str | None = None) -> Table: def drop_table(self, table_name: str, schema: str | None = None) -> None: schema = schema or self.meta.schema if self.has_table(table_name, schema): - self.meta.reflect(bind=self.connectable, only=[table_name], schema=schema) - self.get_table(table_name, schema).drop(bind=self.connectable) + self.meta.reflect(bind=self.con, only=[table_name], schema=schema) + self.get_table(table_name, schema).drop(bind=self.con) self.meta.clear() def _create_sql_schema( @@ -2288,7 +2303,7 @@ def get_schema( .. versionadded:: 1.2.0 """ - pandas_sql = pandasSQL_builder(con=con) - return pandas_sql._create_sql_schema( - frame, name, keys=keys, dtype=dtype, schema=schema - ) + with pandasSQL_builder(con=con) as pandas_sql: + return pandas_sql._create_sql_schema( + frame, name, keys=keys, dtype=dtype, schema=schema + ) diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index f321ecc2f65ff..2b2771b7fccc3 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -505,9 +505,9 @@ def sqlite_buildin_iris(sqlite_buildin, iris_path): @pytest.mark.parametrize("method", [None, "multi"]) def test_to_sql(conn, method, test_frame1, request): conn = request.getfixturevalue(conn) - pandasSQL = pandasSQL_builder(conn) - pandasSQL.to_sql(test_frame1, "test_frame", method=method) - assert pandasSQL.has_table("test_frame") + with pandasSQL_builder(conn) as pandasSQL: + pandasSQL.to_sql(test_frame1, "test_frame", method=method) + assert pandasSQL.has_table("test_frame") assert count_rows(conn, "test_frame") == len(test_frame1) @@ -516,10 +516,10 @@ def test_to_sql(conn, method, test_frame1, request): @pytest.mark.parametrize("mode, num_row_coef", [("replace", 1), ("append", 2)]) def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request): conn = request.getfixturevalue(conn) - pandasSQL = pandasSQL_builder(conn) - pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") - pandasSQL.to_sql(test_frame1, "test_frame", if_exists=mode) - assert pandasSQL.has_table("test_frame") + with pandasSQL_builder(conn) as pandasSQL: + pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") + pandasSQL.to_sql(test_frame1, "test_frame", if_exists=mode) + assert pandasSQL.has_table("test_frame") assert count_rows(conn, "test_frame") == num_row_coef * len(test_frame1) @@ -527,21 +527,21 @@ def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request): @pytest.mark.parametrize("conn", all_connectable) def test_to_sql_exist_fail(conn, test_frame1, request): conn = request.getfixturevalue(conn) - pandasSQL = pandasSQL_builder(conn) - pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") - assert pandasSQL.has_table("test_frame") - - msg = "Table 'test_frame' already exists" - with pytest.raises(ValueError, match=msg): + with pandasSQL_builder(conn) as pandasSQL: pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") + assert pandasSQL.has_table("test_frame") + + msg = "Table 'test_frame' already exists" + with pytest.raises(ValueError, match=msg): + pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") @pytest.mark.db @pytest.mark.parametrize("conn", all_connectable_iris) def test_read_iris(conn, request): conn = request.getfixturevalue(conn) - pandasSQL = pandasSQL_builder(conn) - iris_frame = pandasSQL.read_query("SELECT * FROM iris") + with pandasSQL_builder(conn) as pandasSQL: + iris_frame = pandasSQL.read_query("SELECT * FROM iris") check_iris_frame(iris_frame) @@ -549,7 +549,6 @@ def test_read_iris(conn, request): @pytest.mark.parametrize("conn", sqlalchemy_connectable) def test_to_sql_callable(conn, test_frame1, request): conn = request.getfixturevalue(conn) - pandasSQL = pandasSQL_builder(conn) check = [] # used to double check function below is really being used @@ -558,8 +557,9 @@ def sample(pd_table, conn, keys, data_iter): data = [dict(zip(keys, row)) for row in data_iter] conn.execute(pd_table.table.insert(), data) - pandasSQL.to_sql(test_frame1, "test_frame", method=sample) - assert pandasSQL.has_table("test_frame") + with pandasSQL_builder(conn) as pandasSQL: + pandasSQL.to_sql(test_frame1, "test_frame", method=sample) + assert pandasSQL.has_table("test_frame") assert check == [1] assert count_rows(conn, "test_frame") == len(test_frame1) @@ -694,7 +694,7 @@ def _get_all_tables(self): return inspect(self.conn).get_table_names() def _close_conn(self): - # https://docs.sqlalchemy.org/en/13/core/connections.html#engine-disposal + # https://docs.sqlalchemy.org/en/14/core/connections.html#engine-disposal self.conn.dispose() @@ -1287,8 +1287,7 @@ def test_escaped_table_name(self): tm.assert_frame_equal(res, df) -@pytest.mark.skipif(not SQLALCHEMY_INSTALLED, reason="SQLAlchemy not installed") -class TestSQLApi(SQLAlchemyMixIn, _TestSQLApi): +class _TestSQLApiEngine(SQLAlchemyMixIn, _TestSQLApi): """ Test the public API as it would be used directly @@ -1512,7 +1511,8 @@ def setup_method(self, load_iris_data, load_types_data): self.pandasSQL = sql.SQLDatabase(self.__engine) -class TestSQLApiConn(_EngineToConnMixin, TestSQLApi): +@pytest.mark.skipif(not SQLALCHEMY_INSTALLED, reason="SQLAlchemy not installed") +class TestSQLApiConn(_EngineToConnMixin, _TestSQLApiEngine): pass @@ -2525,30 +2525,16 @@ def test_schema_support(self): tm.assert_frame_equal(res1, res2) -@pytest.mark.db -class TestMySQLAlchemy(_TestMySQLAlchemy, _TestSQLAlchemy): - pass - - @pytest.mark.db class TestMySQLAlchemyConn(_TestMySQLAlchemy, _TestSQLAlchemyConn): pass -@pytest.mark.db -class TestPostgreSQLAlchemy(_TestPostgreSQLAlchemy, _TestSQLAlchemy): - pass - - @pytest.mark.db class TestPostgreSQLAlchemyConn(_TestPostgreSQLAlchemy, _TestSQLAlchemyConn): pass -class TestSQLiteAlchemy(_TestSQLiteAlchemy, _TestSQLAlchemy): - pass - - class TestSQLiteAlchemyConn(_TestSQLiteAlchemy, _TestSQLAlchemyConn): pass