Skip to content

Commit 83ef3e4

Browse files
author
Chuck Cadman
committed
TST: Upgrade tests to work with sqlalchemy 2.0
1 parent 848dc71 commit 83ef3e4

File tree

1 file changed

+54
-22
lines changed

1 file changed

+54
-22
lines changed

pandas/tests/io/test_sql.py

+54-22
Original file line numberDiff line numberDiff line change
@@ -269,25 +269,21 @@ def count_rows(conn, table_name: str):
269269
cur = conn.cursor()
270270
return cur.execute(stmt).fetchone()[0]
271271
else:
272-
from sqlalchemy import (
273-
create_engine,
274-
text,
275-
)
272+
from sqlalchemy import create_engine
276273
from sqlalchemy.engine import Engine
277274

278-
stmt = text(stmt)
279275
if isinstance(conn, str):
280276
try:
281277
engine = create_engine(conn)
282278
with engine.connect() as conn:
283-
return conn.execute(stmt).scalar_one()
279+
return conn.exec_driver_sql(stmt).scalar_one()
284280
finally:
285281
engine.dispose()
286282
elif isinstance(conn, Engine):
287283
with conn.connect() as conn:
288-
return conn.execute(stmt).scalar_one()
284+
return conn.exec_driver_sql(stmt).scalar_one()
289285
else:
290-
return conn.execute(stmt).scalar_one()
286+
return conn.exec_driver_sql(stmt).scalar_one()
291287

292288

293289
@pytest.fixture
@@ -417,7 +413,8 @@ def mysql_pymysql_engine(iris_path, types_data):
417413

418414
@pytest.fixture
419415
def mysql_pymysql_conn(mysql_pymysql_engine):
420-
yield mysql_pymysql_engine.connect()
416+
with mysql_pymysql_engine.connect() as conn:
417+
yield conn
421418

422419

423420
@pytest.fixture
@@ -443,7 +440,8 @@ def postgresql_psycopg2_engine(iris_path, types_data):
443440

444441
@pytest.fixture
445442
def postgresql_psycopg2_conn(postgresql_psycopg2_engine):
446-
yield postgresql_psycopg2_engine.connect()
443+
with postgresql_psycopg2_engine.connect() as conn:
444+
yield conn
447445

448446

449447
@pytest.fixture
@@ -463,7 +461,8 @@ def sqlite_engine(sqlite_str):
463461

464462
@pytest.fixture
465463
def sqlite_conn(sqlite_engine):
466-
yield sqlite_engine.connect()
464+
with sqlite_engine.connect() as conn:
465+
yield conn
467466

468467

469468
@pytest.fixture
@@ -483,7 +482,8 @@ def sqlite_iris_engine(sqlite_engine, iris_path):
483482

484483
@pytest.fixture
485484
def sqlite_iris_conn(sqlite_iris_engine):
486-
yield sqlite_iris_engine.connect()
485+
with sqlite_iris_engine.connect() as conn:
486+
yield conn
487487

488488

489489
@pytest.fixture
@@ -538,7 +538,7 @@ def sqlite_buildin_iris(sqlite_buildin, iris_path):
538538
@pytest.mark.parametrize("method", [None, "multi"])
539539
def test_to_sql(conn, method, test_frame1, request):
540540
conn = request.getfixturevalue(conn)
541-
with pandasSQL_builder(conn) as pandasSQL:
541+
with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
542542
pandasSQL.to_sql(test_frame1, "test_frame", method=method)
543543
assert pandasSQL.has_table("test_frame")
544544
assert count_rows(conn, "test_frame") == len(test_frame1)
@@ -549,7 +549,7 @@ def test_to_sql(conn, method, test_frame1, request):
549549
@pytest.mark.parametrize("mode, num_row_coef", [("replace", 1), ("append", 2)])
550550
def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request):
551551
conn = request.getfixturevalue(conn)
552-
with pandasSQL_builder(conn) as pandasSQL:
552+
with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
553553
pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail")
554554
pandasSQL.to_sql(test_frame1, "test_frame", if_exists=mode)
555555
assert pandasSQL.has_table("test_frame")
@@ -560,7 +560,7 @@ def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request):
560560
@pytest.mark.parametrize("conn", all_connectable)
561561
def test_to_sql_exist_fail(conn, test_frame1, request):
562562
conn = request.getfixturevalue(conn)
563-
with pandasSQL_builder(conn) as pandasSQL:
563+
with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
564564
pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail")
565565
assert pandasSQL.has_table("test_frame")
566566

@@ -613,6 +613,8 @@ def test_read_iris_query_expression_with_parameter(conn, request):
613613
select(iris), conn, params={"name": "Iris-setosa", "length": 5.1}
614614
)
615615
check_iris_frame(iris_frame)
616+
if isinstance(conn, str):
617+
autoload_con.dispose()
616618

617619

618620
@pytest.mark.db
@@ -658,7 +660,7 @@ def sample(pd_table, conn, keys, data_iter):
658660
data = [dict(zip(keys, row)) for row in data_iter]
659661
conn.execute(pd_table.table.insert(), data)
660662

661-
with pandasSQL_builder(conn) as pandasSQL:
663+
with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
662664
pandasSQL.to_sql(test_frame1, "test_frame", method=sample)
663665
assert pandasSQL.has_table("test_frame")
664666
assert check == [1]
@@ -778,6 +780,8 @@ def teardown_method(self):
778780
pass
779781
else:
780782
with conn:
783+
for view in self._get_all_views(conn):
784+
self.drop_view(view, conn)
781785
for tbl in self._get_all_tables(conn):
782786
self.drop_table(tbl, conn)
783787

@@ -794,6 +798,14 @@ def _get_all_tables(self, conn):
794798
c = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
795799
return [table[0] for table in c.fetchall()]
796800

801+
def drop_view(self, view_name, conn):
802+
conn.execute(f"DROP VIEW IF EXISTS {sql._get_valid_sqlite_name(view_name)}")
803+
conn.commit()
804+
805+
def _get_all_views(self, conn):
806+
c = conn.execute("SELECT name FROM sqlite_master WHERE type='view'")
807+
return [view[0] for view in c.fetchall()]
808+
797809

798810
class SQLAlchemyMixIn(MixInBase):
799811
@classmethod
@@ -804,6 +816,8 @@ def connect(self):
804816
return self.engine.connect()
805817

806818
def drop_table(self, table_name, conn):
819+
if conn.in_transaction():
820+
conn.get_transaction().rollback()
807821
with conn.begin():
808822
sql.SQLDatabase(conn).drop_table(table_name)
809823

@@ -812,6 +826,20 @@ def _get_all_tables(self, conn):
812826

813827
return inspect(conn).get_table_names()
814828

829+
def drop_view(self, view_name, conn):
830+
quoted_view = conn.engine.dialect.identifier_preparer.quote_identifier(
831+
view_name
832+
)
833+
if conn.in_transaction():
834+
conn.get_transaction().rollback()
835+
with conn.begin():
836+
conn.exec_driver_sql(f"DROP VIEW IF EXISTS {quoted_view}")
837+
838+
def _get_all_views(self, conn):
839+
from sqlalchemy import inspect
840+
841+
return inspect(conn).get_view_names()
842+
815843

816844
class PandasSQLTest:
817845
"""
@@ -1745,8 +1773,8 @@ def test_create_table(self):
17451773
temp_frame = DataFrame(
17461774
{"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}
17471775
)
1748-
pandasSQL = sql.SQLDatabase(temp_conn)
1749-
assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4
1776+
with sql.SQLDatabase(temp_conn, need_transaction=True) as pandasSQL:
1777+
assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4
17501778

17511779
insp = inspect(temp_conn)
17521780
assert insp.has_table("temp_frame")
@@ -1765,6 +1793,10 @@ def test_drop_table(self):
17651793
assert insp.has_table("temp_frame")
17661794

17671795
pandasSQL.drop_table("temp_frame")
1796+
try:
1797+
insp.clear_cache() # needed with SQLAlchemy 2.0, unavailable prior
1798+
except AttributeError:
1799+
pass
17681800
assert not insp.has_table("temp_frame")
17691801

17701802
def test_roundtrip(self, test_frame1):
@@ -2628,8 +2660,8 @@ def test_schema_support(self):
26282660
df = DataFrame({"col1": [1, 2], "col2": [0.1, 0.2], "col3": ["a", "n"]})
26292661

26302662
# create a schema
2631-
self.conn.execute("DROP SCHEMA IF EXISTS other CASCADE;")
2632-
self.conn.execute("CREATE SCHEMA other;")
2663+
self.conn.exec_driver_sql("DROP SCHEMA IF EXISTS other CASCADE;")
2664+
self.conn.exec_driver_sql("CREATE SCHEMA other;")
26332665

26342666
# write dataframe to different schema's
26352667
assert df.to_sql("test_schema_public", self.conn, index=False) == 2
@@ -2661,8 +2693,8 @@ def test_schema_support(self):
26612693
# different if_exists options
26622694

26632695
# create a schema
2664-
self.conn.execute("DROP SCHEMA IF EXISTS other CASCADE;")
2665-
self.conn.execute("CREATE SCHEMA other;")
2696+
self.conn.exec_driver_sql("DROP SCHEMA IF EXISTS other CASCADE;")
2697+
self.conn.exec_driver_sql("CREATE SCHEMA other;")
26662698

26672699
# write dataframe with different if_exists options
26682700
assert (

0 commit comments

Comments
 (0)