diff --git a/pandas/io/sql.py b/pandas/io/sql.py index afd045bd8bb2b..45444852c99a6 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -279,13 +279,9 @@ def read_sql_table( -------- >>> pd.read_sql_table('table_name', 'postgres:///db_name') # doctest:+SKIP """ - from sqlalchemy.exc import InvalidRequestError - pandas_sql = pandasSQL_builder(con, schema=schema) - try: - pandas_sql.meta.reflect(only=[table_name], views=True) - except InvalidRequestError as err: - raise ValueError(f"Table {table_name} not found") from err + if not pandas_sql.has_table(table_name): + raise ValueError(f"Table {table_name} not found") table = pandas_sql.read_table( table_name, @@ -580,7 +576,7 @@ def read_sql( _is_table_name = False if _is_table_name: - pandas_sql.meta.reflect(only=[sql]) + pandas_sql.meta.reflect(bind=pandas_sql.connectable, only=[sql]) return pandas_sql.read_table( sql, index_col=index_col, @@ -803,7 +799,7 @@ def _execute_create(self): self.table = self.table.to_metadata(self.pd_sql.meta) else: self.table = self.table.tometadata(self.pd_sql.meta) - self.table.create() + self.table.create(bind=self.pd_sql.connectable) def create(self): if self.exists(): @@ -842,8 +838,12 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter): and tables containing a few columns but performance degrades quickly with increase of columns. """ + + from sqlalchemy import insert + data = [dict(zip(keys, row)) for row in data_iter] - conn.execute(self.table.insert(data)) + stmt = insert(self.table).values(data) + conn.execute(stmt) def insert_data(self): if self.index is not None: @@ -951,17 +951,16 @@ def _query_iterator( yield self.frame def read(self, coerce_float=True, parse_dates=None, columns=None, chunksize=None): + from sqlalchemy import select if columns is not None and len(columns) > 0: - from sqlalchemy import select - cols = [self.table.c[n] for n in columns] if self.index is not None: for idx in self.index[::-1]: cols.insert(0, self.table.c[idx]) - sql_select = select(cols) + sql_select = select(*cols) if _gt14() else select(cols) else: - sql_select = self.table.select() + sql_select = select(self.table) if _gt14() else self.table.select() result = self.pd_sql.execute(sql_select) column_names = result.keys() @@ -1043,6 +1042,7 @@ def _create_table_setup(self): PrimaryKeyConstraint, Table, ) + from sqlalchemy.schema import MetaData column_names_and_types = self._get_column_names_and_types(self._sqlalchemy_type) @@ -1063,10 +1063,7 @@ def _create_table_setup(self): # At this point, attach to new metadata, only attach to self.meta # once table is created. - from sqlalchemy.schema import MetaData - - meta = MetaData(self.pd_sql, schema=schema) - + meta = MetaData() return Table(self.name, meta, *columns, schema=schema) def _harmonize_columns(self, parse_dates=None): @@ -1355,15 +1352,19 @@ def __init__(self, engine, schema: str | None = None): from sqlalchemy.schema import MetaData self.connectable = engine - self.meta = MetaData(self.connectable, schema=schema) + self.meta = MetaData(schema=schema) + self.meta.reflect(bind=engine) @contextmanager def run_transaction(self): - with self.connectable.begin() as tx: - if hasattr(tx, "execute"): - yield tx - else: - yield self.connectable + from sqlalchemy.engine import Engine + + if isinstance(self.connectable, Engine): + with self.connectable.connect() as conn: + with conn.begin(): + yield conn + else: + yield self.connectable def execute(self, *args, **kwargs): """Simple passthrough to SQLAlchemy connectable""" @@ -1724,9 +1725,9 @@ def tables(self): def has_table(self, name: str, schema: str | None = None): if _gt14(): - import sqlalchemy as sa + from sqlalchemy import inspect - insp = sa.inspect(self.connectable) + insp = inspect(self.connectable) return insp.has_table(name, schema or self.meta.schema) else: return self.connectable.run_callable( @@ -1752,8 +1753,8 @@ def get_table(self, table_name: str, schema: str | None = None): def drop_table(self, table_name: str, schema: str | None = None): schema = schema or self.meta.schema if self.has_table(table_name, schema): - self.meta.reflect(only=[table_name], schema=schema) - self.get_table(table_name, schema).drop() + self.meta.reflect(bind=self.connectable, only=[table_name], schema=schema) + self.get_table(table_name, schema).drop(bind=self.connectable) self.meta.clear() def _create_sql_schema( diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 217c1b28c61a5..7f73b4f12c2fb 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -54,6 +54,8 @@ import pandas.io.sql as sql from pandas.io.sql import ( SQLAlchemyEngine, + SQLDatabase, + SQLiteDatabase, _gt14, get_engine, read_sql_query, @@ -150,7 +152,8 @@ def create_and_load_iris(conn, iris_file: Path, dialect: str): stmt = insert(iris).values(params) if isinstance(conn, Engine): with conn.connect() as conn: - conn.execute(stmt) + with conn.begin(): + conn.execute(stmt) else: conn.execute(stmt) @@ -167,7 +170,8 @@ def create_and_load_iris_view(conn): stmt = text(stmt) if isinstance(conn, Engine): with conn.connect() as conn: - conn.execute(stmt) + with conn.begin(): + conn.execute(stmt) else: conn.execute(stmt) @@ -238,7 +242,8 @@ def create_and_load_types(conn, types_data: list[dict], dialect: str): stmt = insert(types).values(types_data) if isinstance(conn, Engine): with conn.connect() as conn: - conn.execute(stmt) + with conn.begin(): + conn.execute(stmt) else: conn.execute(stmt) @@ -601,13 +606,24 @@ def _to_sql_save_index(self): def _transaction_test(self): with self.pandasSQL.run_transaction() as trans: - trans.execute("CREATE TABLE test_trans (A INT, B TEXT)") + stmt = "CREATE TABLE test_trans (A INT, B TEXT)" + if isinstance(self.pandasSQL, SQLiteDatabase): + trans.execute(stmt) + else: + from sqlalchemy import text + + stmt = text(stmt) + trans.execute(stmt) class DummyException(Exception): pass # Make sure when transaction is rolled back, no rows get inserted ins_sql = "INSERT INTO test_trans (A,B) VALUES (1, 'blah')" + if isinstance(self.pandasSQL, SQLDatabase): + from sqlalchemy import text + + ins_sql = text(ins_sql) try: with self.pandasSQL.run_transaction() as trans: trans.execute(ins_sql) @@ -1127,12 +1143,20 @@ def test_read_sql_delegate(self): def test_not_reflect_all_tables(self): from sqlalchemy import text + from sqlalchemy.engine import Engine # create invalid table - qry = text("CREATE TABLE invalid (x INTEGER, y UNKNOWN);") - self.conn.execute(qry) - qry = text("CREATE TABLE other_table (x INTEGER, y INTEGER);") - self.conn.execute(qry) + query_list = [ + text("CREATE TABLE invalid (x INTEGER, y UNKNOWN);"), + text("CREATE TABLE other_table (x INTEGER, y INTEGER);"), + ] + for query in query_list: + if isinstance(self.conn, Engine): + with self.conn.connect() as conn: + with conn.begin(): + conn.execute(query) + else: + self.conn.execute(query) with tm.assert_produces_warning(None): sql.read_sql_table("other_table", self.conn) @@ -1858,7 +1882,8 @@ def test_get_schema_create_table(self, test_frame3): create_sql = text(create_sql) if isinstance(self.conn, Engine): with self.conn.connect() as conn: - conn.execute(create_sql) + with conn.begin(): + conn.execute(create_sql) else: self.conn.execute(create_sql) returned_df = sql.read_sql_table(tbl, self.conn) @@ -2203,11 +2228,11 @@ def test_default_type_conversion(self): assert issubclass(df.BoolColWithNull.dtype.type, np.floating) def test_read_procedure(self): - import pymysql from sqlalchemy import text from sqlalchemy.engine import Engine - # see GH7324. Although it is more an api test, it is added to the + # GH 7324 + # Although it is more an api test, it is added to the # mysql tests as sqlite does not have stored procedures df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]}) df.to_sql("test_procedure", self.conn, index=False) @@ -2220,14 +2245,12 @@ def test_read_procedure(self): SELECT * FROM test_procedure; END""" proc = text(proc) - connection = self.conn.connect() if isinstance(self.conn, Engine) else self.conn - trans = connection.begin() - try: - _ = connection.execute(proc) - trans.commit() - except pymysql.Error: - trans.rollback() - raise + if isinstance(self.conn, Engine): + with self.conn.connect() as conn: + with conn.begin(): + conn.execute(proc) + else: + self.conn.execute(proc) res1 = sql.read_sql_query("CALL get_testdb();", self.conn) tm.assert_frame_equal(df, res1)