Skip to content

Commit c4add5f

Browse files
fangchenli authored and feefladder committed
ENH: partially upgrade sql module for SQLAlchemy 2.0 compat (pandas-dev#43116)
1 parent 7a9bb25 commit c4add5f

File tree

2 files changed

+70
-46
lines changed

2 files changed

+70
-46
lines changed

pandas/io/sql.py

+28 -27
Original file line number | Diff line number | Diff line change
@@ -279,13 +279,9 @@ def read_sql_table(
279279
--------
280280
>>> pd.read_sql_table('table_name', 'postgres:///db_name') # doctest:+SKIP
281281
"""
282-
from sqlalchemy.exc import InvalidRequestError
283-
284282
pandas_sql = pandasSQL_builder(con, schema=schema)
285-
try:
286-
pandas_sql.meta.reflect(only=[table_name], views=True)
287-
except InvalidRequestError as err:
288-
raise ValueError(f"Table {table_name} not found") from err
283+
if not pandas_sql.has_table(table_name):
284+
raise ValueError(f"Table {table_name} not found")
289285

290286
table = pandas_sql.read_table(
291287
table_name,
@@ -580,7 +576,7 @@ def read_sql(
580576
_is_table_name = False
581577

582578
if _is_table_name:
583-
pandas_sql.meta.reflect(only=[sql])
579+
pandas_sql.meta.reflect(bind=pandas_sql.connectable, only=[sql])
584580
return pandas_sql.read_table(
585581
sql,
586582
index_col=index_col,
@@ -803,7 +799,7 @@ def _execute_create(self):
803799
self.table = self.table.to_metadata(self.pd_sql.meta)
804800
else:
805801
self.table = self.table.tometadata(self.pd_sql.meta)
806-
self.table.create()
802+
self.table.create(bind=self.pd_sql.connectable)
807803

808804
def create(self):
809805
if self.exists():
@@ -842,8 +838,12 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter):
842838
and tables containing a few columns
843839
but performance degrades quickly with increase of columns.
844840
"""
841+
842+
from sqlalchemy import insert
843+
845844
data = [dict(zip(keys, row)) for row in data_iter]
846-
conn.execute(self.table.insert(data))
845+
stmt = insert(self.table).values(data)
846+
conn.execute(stmt)
847847

848848
def insert_data(self):
849849
if self.index is not None:
@@ -951,17 +951,16 @@ def _query_iterator(
951951
yield self.frame
952952

953953
def read(self, coerce_float=True, parse_dates=None, columns=None, chunksize=None):
954+
from sqlalchemy import select
954955

955956
if columns is not None and len(columns) > 0:
956-
from sqlalchemy import select
957-
958957
cols = [self.table.c[n] for n in columns]
959958
if self.index is not None:
960959
for idx in self.index[::-1]:
961960
cols.insert(0, self.table.c[idx])
962-
sql_select = select(cols)
961+
sql_select = select(*cols) if _gt14() else select(cols)
963962
else:
964-
sql_select = self.table.select()
963+
sql_select = select(self.table) if _gt14() else self.table.select()
965964

966965
result = self.pd_sql.execute(sql_select)
967966
column_names = result.keys()
@@ -1043,6 +1042,7 @@ def _create_table_setup(self):
10431042
PrimaryKeyConstraint,
10441043
Table,
10451044
)
1045+
from sqlalchemy.schema import MetaData
10461046

10471047
column_names_and_types = self._get_column_names_and_types(self._sqlalchemy_type)
10481048

@@ -1063,10 +1063,7 @@ def _create_table_setup(self):
10631063

10641064
# At this point, attach to new metadata, only attach to self.meta
10651065
# once table is created.
1066-
from sqlalchemy.schema import MetaData
1067-
1068-
meta = MetaData(self.pd_sql, schema=schema)
1069-
1066+
meta = MetaData()
10701067
return Table(self.name, meta, *columns, schema=schema)
10711068

10721069
def _harmonize_columns(self, parse_dates=None):
@@ -1355,15 +1352,19 @@ def __init__(self, engine, schema: str | None = None):
13551352
from sqlalchemy.schema import MetaData
13561353

13571354
self.connectable = engine
1358-
self.meta = MetaData(self.connectable, schema=schema)
1355+
self.meta = MetaData(schema=schema)
1356+
self.meta.reflect(bind=engine)
13591357

13601358
@contextmanager
13611359
def run_transaction(self):
1362-
with self.connectable.begin() as tx:
1363-
if hasattr(tx, "execute"):
1364-
yield tx
1365-
else:
1366-
yield self.connectable
1360+
from sqlalchemy.engine import Engine
1361+
1362+
if isinstance(self.connectable, Engine):
1363+
with self.connectable.connect() as conn:
1364+
with conn.begin():
1365+
yield conn
1366+
else:
1367+
yield self.connectable
13671368

13681369
def execute(self, *args, **kwargs):
13691370
"""Simple passthrough to SQLAlchemy connectable"""
@@ -1724,9 +1725,9 @@ def tables(self):
17241725

17251726
def has_table(self, name: str, schema: str | None = None):
17261727
if _gt14():
1727-
import sqlalchemy as sa
1728+
from sqlalchemy import inspect
17281729

1729-
insp = sa.inspect(self.connectable)
1730+
insp = inspect(self.connectable)
17301731
return insp.has_table(name, schema or self.meta.schema)
17311732
else:
17321733
return self.connectable.run_callable(
@@ -1752,8 +1753,8 @@ def get_table(self, table_name: str, schema: str | None = None):
17521753
def drop_table(self, table_name: str, schema: str | None = None):
17531754
schema = schema or self.meta.schema
17541755
if self.has_table(table_name, schema):
1755-
self.meta.reflect(only=[table_name], schema=schema)
1756-
self.get_table(table_name, schema).drop()
1756+
self.meta.reflect(bind=self.connectable, only=[table_name], schema=schema)
1757+
self.get_table(table_name, schema).drop(bind=self.connectable)
17571758
self.meta.clear()
17581759

17591760
def _create_sql_schema(

pandas/tests/io/test_sql.py

+42 -19
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,8 @@
5454
import pandas.io.sql as sql
5555
from pandas.io.sql import (
5656
SQLAlchemyEngine,
57+
SQLDatabase,
58+
SQLiteDatabase,
5759
_gt14,
5860
get_engine,
5961
read_sql_query,
@@ -150,7 +152,8 @@ def create_and_load_iris(conn, iris_file: Path, dialect: str):
150152
stmt = insert(iris).values(params)
151153
if isinstance(conn, Engine):
152154
with conn.connect() as conn:
153-
conn.execute(stmt)
155+
with conn.begin():
156+
conn.execute(stmt)
154157
else:
155158
conn.execute(stmt)
156159

@@ -167,7 +170,8 @@ def create_and_load_iris_view(conn):
167170
stmt = text(stmt)
168171
if isinstance(conn, Engine):
169172
with conn.connect() as conn:
170-
conn.execute(stmt)
173+
with conn.begin():
174+
conn.execute(stmt)
171175
else:
172176
conn.execute(stmt)
173177

@@ -238,7 +242,8 @@ def create_and_load_types(conn, types_data: list[dict], dialect: str):
238242
stmt = insert(types).values(types_data)
239243
if isinstance(conn, Engine):
240244
with conn.connect() as conn:
241-
conn.execute(stmt)
245+
with conn.begin():
246+
conn.execute(stmt)
242247
else:
243248
conn.execute(stmt)
244249

@@ -601,13 +606,24 @@ def _to_sql_save_index(self):
601606

602607
def _transaction_test(self):
603608
with self.pandasSQL.run_transaction() as trans:
604-
trans.execute("CREATE TABLE test_trans (A INT, B TEXT)")
609+
stmt = "CREATE TABLE test_trans (A INT, B TEXT)"
610+
if isinstance(self.pandasSQL, SQLiteDatabase):
611+
trans.execute(stmt)
612+
else:
613+
from sqlalchemy import text
614+
615+
stmt = text(stmt)
616+
trans.execute(stmt)
605617

606618
class DummyException(Exception):
607619
pass
608620

609621
# Make sure when transaction is rolled back, no rows get inserted
610622
ins_sql = "INSERT INTO test_trans (A,B) VALUES (1, 'blah')"
623+
if isinstance(self.pandasSQL, SQLDatabase):
624+
from sqlalchemy import text
625+
626+
ins_sql = text(ins_sql)
611627
try:
612628
with self.pandasSQL.run_transaction() as trans:
613629
trans.execute(ins_sql)
@@ -1127,12 +1143,20 @@ def test_read_sql_delegate(self):
11271143

11281144
def test_not_reflect_all_tables(self):
11291145
from sqlalchemy import text
1146+
from sqlalchemy.engine import Engine
11301147

11311148
# create invalid table
1132-
qry = text("CREATE TABLE invalid (x INTEGER, y UNKNOWN);")
1133-
self.conn.execute(qry)
1134-
qry = text("CREATE TABLE other_table (x INTEGER, y INTEGER);")
1135-
self.conn.execute(qry)
1149+
query_list = [
1150+
text("CREATE TABLE invalid (x INTEGER, y UNKNOWN);"),
1151+
text("CREATE TABLE other_table (x INTEGER, y INTEGER);"),
1152+
]
1153+
for query in query_list:
1154+
if isinstance(self.conn, Engine):
1155+
with self.conn.connect() as conn:
1156+
with conn.begin():
1157+
conn.execute(query)
1158+
else:
1159+
self.conn.execute(query)
11361160

11371161
with tm.assert_produces_warning(None):
11381162
sql.read_sql_table("other_table", self.conn)
@@ -1858,7 +1882,8 @@ def test_get_schema_create_table(self, test_frame3):
18581882
create_sql = text(create_sql)
18591883
if isinstance(self.conn, Engine):
18601884
with self.conn.connect() as conn:
1861-
conn.execute(create_sql)
1885+
with conn.begin():
1886+
conn.execute(create_sql)
18621887
else:
18631888
self.conn.execute(create_sql)
18641889
returned_df = sql.read_sql_table(tbl, self.conn)
@@ -2203,11 +2228,11 @@ def test_default_type_conversion(self):
22032228
assert issubclass(df.BoolColWithNull.dtype.type, np.floating)
22042229

22052230
def test_read_procedure(self):
2206-
import pymysql
22072231
from sqlalchemy import text
22082232
from sqlalchemy.engine import Engine
22092233

2210-
# see GH7324. Although it is more an api test, it is added to the
2234+
# GH 7324
2235+
# Although it is more an api test, it is added to the
22112236
# mysql tests as sqlite does not have stored procedures
22122237
df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]})
22132238
df.to_sql("test_procedure", self.conn, index=False)
@@ -2220,14 +2245,12 @@ def test_read_procedure(self):
22202245
SELECT * FROM test_procedure;
22212246
END"""
22222247
proc = text(proc)
2223-
connection = self.conn.connect() if isinstance(self.conn, Engine) else self.conn
2224-
trans = connection.begin()
2225-
try:
2226-
_ = connection.execute(proc)
2227-
trans.commit()
2228-
except pymysql.Error:
2229-
trans.rollback()
2230-
raise
2248+
if isinstance(self.conn, Engine):
2249+
with self.conn.connect() as conn:
2250+
with conn.begin():
2251+
conn.execute(proc)
2252+
else:
2253+
self.conn.execute(proc)
22312254

22322255
res1 = sql.read_sql_query("CALL get_testdb();", self.conn)
22332256
tm.assert_frame_equal(df, res1)

0 commit comments

Comments
 (0)