Skip to content

EHN: partially upgrade sql module for SQLAlchemy 2.0 compat #43116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,9 @@ def read_sql_table(
--------
>>> pd.read_sql_table('table_name', 'postgres:///db_name') # doctest:+SKIP
"""
from sqlalchemy.exc import InvalidRequestError

pandas_sql = pandasSQL_builder(con, schema=schema)
try:
pandas_sql.meta.reflect(only=[table_name], views=True)
except InvalidRequestError as err:
raise ValueError(f"Table {table_name} not found") from err
if not pandas_sql.has_table(table_name):
raise ValueError(f"Table {table_name} not found")

table = pandas_sql.read_table(
table_name,
Expand Down Expand Up @@ -580,7 +576,7 @@ def read_sql(
_is_table_name = False

if _is_table_name:
pandas_sql.meta.reflect(only=[sql])
pandas_sql.meta.reflect(bind=pandas_sql.connectable, only=[sql])
return pandas_sql.read_table(
sql,
index_col=index_col,
Expand Down Expand Up @@ -803,7 +799,7 @@ def _execute_create(self):
self.table = self.table.to_metadata(self.pd_sql.meta)
else:
self.table = self.table.tometadata(self.pd_sql.meta)
self.table.create()
self.table.create(bind=self.pd_sql.connectable)

def create(self):
if self.exists():
Expand Down Expand Up @@ -842,8 +838,12 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter):
and tables containing a few columns
but performance degrades quickly with increase of columns.
"""

from sqlalchemy import insert

data = [dict(zip(keys, row)) for row in data_iter]
conn.execute(self.table.insert(data))
stmt = insert(self.table).values(data)
conn.execute(stmt)

def insert_data(self):
if self.index is not None:
Expand Down Expand Up @@ -951,17 +951,16 @@ def _query_iterator(
yield self.frame

def read(self, coerce_float=True, parse_dates=None, columns=None, chunksize=None):
from sqlalchemy import select

if columns is not None and len(columns) > 0:
from sqlalchemy import select

cols = [self.table.c[n] for n in columns]
if self.index is not None:
for idx in self.index[::-1]:
cols.insert(0, self.table.c[idx])
sql_select = select(cols)
sql_select = select(*cols) if _gt14() else select(cols)
else:
sql_select = self.table.select()
sql_select = select(self.table) if _gt14() else self.table.select()

result = self.pd_sql.execute(sql_select)
column_names = result.keys()
Expand Down Expand Up @@ -1043,6 +1042,7 @@ def _create_table_setup(self):
PrimaryKeyConstraint,
Table,
)
from sqlalchemy.schema import MetaData

column_names_and_types = self._get_column_names_and_types(self._sqlalchemy_type)

Expand All @@ -1063,10 +1063,7 @@ def _create_table_setup(self):

# At this point, attach to new metadata, only attach to self.meta
# once table is created.
from sqlalchemy.schema import MetaData

meta = MetaData(self.pd_sql, schema=schema)

meta = MetaData()
return Table(self.name, meta, *columns, schema=schema)

def _harmonize_columns(self, parse_dates=None):
Expand Down Expand Up @@ -1355,15 +1352,19 @@ def __init__(self, engine, schema: str | None = None):
from sqlalchemy.schema import MetaData

self.connectable = engine
self.meta = MetaData(self.connectable, schema=schema)
self.meta = MetaData(schema=schema)
self.meta.reflect(bind=engine)

@contextmanager
def run_transaction(self):
with self.connectable.begin() as tx:
if hasattr(tx, "execute"):
yield tx
else:
yield self.connectable
from sqlalchemy.engine import Engine

if isinstance(self.connectable, Engine):
with self.connectable.connect() as conn:
with conn.begin():
yield conn
else:
yield self.connectable

def execute(self, *args, **kwargs):
"""Simple passthrough to SQLAlchemy connectable"""
Expand Down Expand Up @@ -1724,9 +1725,9 @@ def tables(self):

def has_table(self, name: str, schema: str | None = None):
if _gt14():
import sqlalchemy as sa
from sqlalchemy import inspect

insp = sa.inspect(self.connectable)
insp = inspect(self.connectable)
return insp.has_table(name, schema or self.meta.schema)
else:
return self.connectable.run_callable(
Expand All @@ -1752,8 +1753,8 @@ def get_table(self, table_name: str, schema: str | None = None):
def drop_table(self, table_name: str, schema: str | None = None):
schema = schema or self.meta.schema
if self.has_table(table_name, schema):
self.meta.reflect(only=[table_name], schema=schema)
self.get_table(table_name, schema).drop()
self.meta.reflect(bind=self.connectable, only=[table_name], schema=schema)
self.get_table(table_name, schema).drop(bind=self.connectable)
self.meta.clear()

def _create_sql_schema(
Expand Down
61 changes: 42 additions & 19 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
import pandas.io.sql as sql
from pandas.io.sql import (
SQLAlchemyEngine,
SQLDatabase,
SQLiteDatabase,
_gt14,
get_engine,
read_sql_query,
Expand Down Expand Up @@ -150,7 +152,8 @@ def create_and_load_iris(conn, iris_file: Path, dialect: str):
stmt = insert(iris).values(params)
if isinstance(conn, Engine):
with conn.connect() as conn:
conn.execute(stmt)
with conn.begin():
conn.execute(stmt)
else:
conn.execute(stmt)

Expand All @@ -167,7 +170,8 @@ def create_and_load_iris_view(conn):
stmt = text(stmt)
if isinstance(conn, Engine):
with conn.connect() as conn:
conn.execute(stmt)
with conn.begin():
conn.execute(stmt)
else:
conn.execute(stmt)

Expand Down Expand Up @@ -238,7 +242,8 @@ def create_and_load_types(conn, types_data: list[dict], dialect: str):
stmt = insert(types).values(types_data)
if isinstance(conn, Engine):
with conn.connect() as conn:
conn.execute(stmt)
with conn.begin():
conn.execute(stmt)
else:
conn.execute(stmt)

Expand Down Expand Up @@ -601,13 +606,24 @@ def _to_sql_save_index(self):

def _transaction_test(self):
with self.pandasSQL.run_transaction() as trans:
trans.execute("CREATE TABLE test_trans (A INT, B TEXT)")
stmt = "CREATE TABLE test_trans (A INT, B TEXT)"
if isinstance(self.pandasSQL, SQLiteDatabase):
trans.execute(stmt)
else:
from sqlalchemy import text

stmt = text(stmt)
trans.execute(stmt)

class DummyException(Exception):
pass

# Make sure when transaction is rolled back, no rows get inserted
ins_sql = "INSERT INTO test_trans (A,B) VALUES (1, 'blah')"
if isinstance(self.pandasSQL, SQLDatabase):
from sqlalchemy import text

ins_sql = text(ins_sql)
try:
with self.pandasSQL.run_transaction() as trans:
trans.execute(ins_sql)
Expand Down Expand Up @@ -1127,12 +1143,20 @@ def test_read_sql_delegate(self):

def test_not_reflect_all_tables(self):
from sqlalchemy import text
from sqlalchemy.engine import Engine

# create invalid table
qry = text("CREATE TABLE invalid (x INTEGER, y UNKNOWN);")
self.conn.execute(qry)
qry = text("CREATE TABLE other_table (x INTEGER, y INTEGER);")
self.conn.execute(qry)
query_list = [
text("CREATE TABLE invalid (x INTEGER, y UNKNOWN);"),
text("CREATE TABLE other_table (x INTEGER, y INTEGER);"),
]
for query in query_list:
if isinstance(self.conn, Engine):
with self.conn.connect() as conn:
with conn.begin():
conn.execute(query)
else:
self.conn.execute(query)

with tm.assert_produces_warning(None):
sql.read_sql_table("other_table", self.conn)
Expand Down Expand Up @@ -1858,7 +1882,8 @@ def test_get_schema_create_table(self, test_frame3):
create_sql = text(create_sql)
if isinstance(self.conn, Engine):
with self.conn.connect() as conn:
conn.execute(create_sql)
with conn.begin():
conn.execute(create_sql)
else:
self.conn.execute(create_sql)
returned_df = sql.read_sql_table(tbl, self.conn)
Expand Down Expand Up @@ -2203,11 +2228,11 @@ def test_default_type_conversion(self):
assert issubclass(df.BoolColWithNull.dtype.type, np.floating)

def test_read_procedure(self):
import pymysql
from sqlalchemy import text
from sqlalchemy.engine import Engine

# see GH7324. Although it is more an api test, it is added to the
# GH 7324
# Although it is more an api test, it is added to the
# mysql tests as sqlite does not have stored procedures
df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]})
df.to_sql("test_procedure", self.conn, index=False)
Expand All @@ -2220,14 +2245,12 @@ def test_read_procedure(self):
SELECT * FROM test_procedure;
END"""
proc = text(proc)
connection = self.conn.connect() if isinstance(self.conn, Engine) else self.conn
trans = connection.begin()
try:
_ = connection.execute(proc)
trans.commit()
except pymysql.Error:
trans.rollback()
raise
if isinstance(self.conn, Engine):
with self.conn.connect() as conn:
with conn.begin():
conn.execute(proc)
else:
self.conn.execute(proc)

res1 = sql.read_sql_query("CALL get_testdb();", self.conn)
tm.assert_frame_equal(df, res1)
Expand Down