@@ -269,25 +269,21 @@ def count_rows(conn, table_name: str):
269
269
cur = conn .cursor ()
270
270
return cur .execute (stmt ).fetchone ()[0 ]
271
271
else :
272
- from sqlalchemy import (
273
- create_engine ,
274
- text ,
275
- )
272
+ from sqlalchemy import create_engine
276
273
from sqlalchemy .engine import Engine
277
274
278
- stmt = text (stmt )
279
275
if isinstance (conn , str ):
280
276
try :
281
277
engine = create_engine (conn )
282
278
with engine .connect () as conn :
283
- return conn .execute (stmt ).scalar_one ()
279
+ return conn .exec_driver_sql (stmt ).scalar_one ()
284
280
finally :
285
281
engine .dispose ()
286
282
elif isinstance (conn , Engine ):
287
283
with conn .connect () as conn :
288
- return conn .execute (stmt ).scalar_one ()
284
+ return conn .exec_driver_sql (stmt ).scalar_one ()
289
285
else :
290
- return conn .execute (stmt ).scalar_one ()
286
+ return conn .exec_driver_sql (stmt ).scalar_one ()
291
287
292
288
293
289
@pytest .fixture
@@ -417,7 +413,8 @@ def mysql_pymysql_engine(iris_path, types_data):
417
413
418
414
@pytest .fixture
419
415
def mysql_pymysql_conn (mysql_pymysql_engine ):
420
- yield mysql_pymysql_engine .connect ()
416
+ with mysql_pymysql_engine .connect () as conn :
417
+ yield conn
421
418
422
419
423
420
@pytest .fixture
@@ -443,7 +440,8 @@ def postgresql_psycopg2_engine(iris_path, types_data):
443
440
444
441
@pytest .fixture
445
442
def postgresql_psycopg2_conn (postgresql_psycopg2_engine ):
446
- yield postgresql_psycopg2_engine .connect ()
443
+ with postgresql_psycopg2_engine .connect () as conn :
444
+ yield conn
447
445
448
446
449
447
@pytest .fixture
@@ -463,7 +461,8 @@ def sqlite_engine(sqlite_str):
463
461
464
462
@pytest .fixture
465
463
def sqlite_conn (sqlite_engine ):
466
- yield sqlite_engine .connect ()
464
+ with sqlite_engine .connect () as conn :
465
+ yield conn
467
466
468
467
469
468
@pytest .fixture
@@ -483,7 +482,8 @@ def sqlite_iris_engine(sqlite_engine, iris_path):
483
482
484
483
@pytest .fixture
485
484
def sqlite_iris_conn (sqlite_iris_engine ):
486
- yield sqlite_iris_engine .connect ()
485
+ with sqlite_iris_engine .connect () as conn :
486
+ yield conn
487
487
488
488
489
489
@pytest .fixture
@@ -538,7 +538,7 @@ def sqlite_buildin_iris(sqlite_buildin, iris_path):
538
538
@pytest .mark .parametrize ("method" , [None , "multi" ])
539
539
def test_to_sql (conn , method , test_frame1 , request ):
540
540
conn = request .getfixturevalue (conn )
541
- with pandasSQL_builder (conn ) as pandasSQL :
541
+ with pandasSQL_builder (conn , need_transaction = True ) as pandasSQL :
542
542
pandasSQL .to_sql (test_frame1 , "test_frame" , method = method )
543
543
assert pandasSQL .has_table ("test_frame" )
544
544
assert count_rows (conn , "test_frame" ) == len (test_frame1 )
@@ -549,7 +549,7 @@ def test_to_sql(conn, method, test_frame1, request):
549
549
@pytest .mark .parametrize ("mode, num_row_coef" , [("replace" , 1 ), ("append" , 2 )])
550
550
def test_to_sql_exist (conn , mode , num_row_coef , test_frame1 , request ):
551
551
conn = request .getfixturevalue (conn )
552
- with pandasSQL_builder (conn ) as pandasSQL :
552
+ with pandasSQL_builder (conn , need_transaction = True ) as pandasSQL :
553
553
pandasSQL .to_sql (test_frame1 , "test_frame" , if_exists = "fail" )
554
554
pandasSQL .to_sql (test_frame1 , "test_frame" , if_exists = mode )
555
555
assert pandasSQL .has_table ("test_frame" )
@@ -560,7 +560,7 @@ def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request):
560
560
@pytest .mark .parametrize ("conn" , all_connectable )
561
561
def test_to_sql_exist_fail (conn , test_frame1 , request ):
562
562
conn = request .getfixturevalue (conn )
563
- with pandasSQL_builder (conn ) as pandasSQL :
563
+ with pandasSQL_builder (conn , need_transaction = True ) as pandasSQL :
564
564
pandasSQL .to_sql (test_frame1 , "test_frame" , if_exists = "fail" )
565
565
assert pandasSQL .has_table ("test_frame" )
566
566
@@ -613,6 +613,8 @@ def test_read_iris_query_expression_with_parameter(conn, request):
613
613
select (iris ), conn , params = {"name" : "Iris-setosa" , "length" : 5.1 }
614
614
)
615
615
check_iris_frame (iris_frame )
616
+ if isinstance (conn , str ):
617
+ autoload_con .dispose ()
616
618
617
619
618
620
@pytest .mark .db
@@ -658,7 +660,7 @@ def sample(pd_table, conn, keys, data_iter):
658
660
data = [dict (zip (keys , row )) for row in data_iter ]
659
661
conn .execute (pd_table .table .insert (), data )
660
662
661
- with pandasSQL_builder (conn ) as pandasSQL :
663
+ with pandasSQL_builder (conn , need_transaction = True ) as pandasSQL :
662
664
pandasSQL .to_sql (test_frame1 , "test_frame" , method = sample )
663
665
assert pandasSQL .has_table ("test_frame" )
664
666
assert check == [1 ]
@@ -778,6 +780,8 @@ def teardown_method(self):
778
780
pass
779
781
else :
780
782
with conn :
783
+ for view in self ._get_all_views (conn ):
784
+ self .drop_view (view , conn )
781
785
for tbl in self ._get_all_tables (conn ):
782
786
self .drop_table (tbl , conn )
783
787
@@ -794,6 +798,14 @@ def _get_all_tables(self, conn):
794
798
c = conn .execute ("SELECT name FROM sqlite_master WHERE type='table'" )
795
799
return [table [0 ] for table in c .fetchall ()]
796
800
801
+ def drop_view (self , view_name , conn ):
802
+ conn .execute (f"DROP VIEW IF EXISTS { sql ._get_valid_sqlite_name (view_name )} " )
803
+ conn .commit ()
804
+
805
+ def _get_all_views (self , conn ):
806
+ c = conn .execute ("SELECT name FROM sqlite_master WHERE type='view'" )
807
+ return [view [0 ] for view in c .fetchall ()]
808
+
797
809
798
810
class SQLAlchemyMixIn (MixInBase ):
799
811
@classmethod
@@ -804,6 +816,8 @@ def connect(self):
804
816
return self .engine .connect ()
805
817
806
818
def drop_table (self , table_name , conn ):
819
+ if conn .in_transaction ():
820
+ conn .get_transaction ().rollback ()
807
821
with conn .begin ():
808
822
sql .SQLDatabase (conn ).drop_table (table_name )
809
823
@@ -812,6 +826,20 @@ def _get_all_tables(self, conn):
812
826
813
827
return inspect (conn ).get_table_names ()
814
828
829
+ def drop_view (self , view_name , conn ):
830
+ quoted_view = conn .engine .dialect .identifier_preparer .quote_identifier (
831
+ view_name
832
+ )
833
+ if conn .in_transaction ():
834
+ conn .get_transaction ().rollback ()
835
+ with conn .begin ():
836
+ conn .exec_driver_sql (f"DROP VIEW IF EXISTS { quoted_view } " )
837
+
838
+ def _get_all_views (self , conn ):
839
+ from sqlalchemy import inspect
840
+
841
+ return inspect (conn ).get_view_names ()
842
+
815
843
816
844
class PandasSQLTest :
817
845
"""
@@ -1745,8 +1773,8 @@ def test_create_table(self):
1745
1773
temp_frame = DataFrame (
1746
1774
{"one" : [1.0 , 2.0 , 3.0 , 4.0 ], "two" : [4.0 , 3.0 , 2.0 , 1.0 ]}
1747
1775
)
1748
- pandasSQL = sql .SQLDatabase (temp_conn )
1749
- assert pandasSQL .to_sql (temp_frame , "temp_frame" ) == 4
1776
+ with sql .SQLDatabase (temp_conn , need_transaction = True ) as pandasSQL :
1777
+ assert pandasSQL .to_sql (temp_frame , "temp_frame" ) == 4
1750
1778
1751
1779
insp = inspect (temp_conn )
1752
1780
assert insp .has_table ("temp_frame" )
@@ -1765,6 +1793,10 @@ def test_drop_table(self):
1765
1793
assert insp .has_table ("temp_frame" )
1766
1794
1767
1795
pandasSQL .drop_table ("temp_frame" )
1796
+ try :
1797
+ insp .clear_cache () # needed with SQLAlchemy 2.0, unavailable prior
1798
+ except AttributeError :
1799
+ pass
1768
1800
assert not insp .has_table ("temp_frame" )
1769
1801
1770
1802
def test_roundtrip (self , test_frame1 ):
@@ -2628,8 +2660,8 @@ def test_schema_support(self):
2628
2660
df = DataFrame ({"col1" : [1 , 2 ], "col2" : [0.1 , 0.2 ], "col3" : ["a" , "n" ]})
2629
2661
2630
2662
# create a schema
2631
- self .conn .execute ("DROP SCHEMA IF EXISTS other CASCADE;" )
2632
- self .conn .execute ("CREATE SCHEMA other;" )
2663
+ self .conn .exec_driver_sql ("DROP SCHEMA IF EXISTS other CASCADE;" )
2664
+ self .conn .exec_driver_sql ("CREATE SCHEMA other;" )
2633
2665
2634
2666
# write dataframe to different schema's
2635
2667
assert df .to_sql ("test_schema_public" , self .conn , index = False ) == 2
@@ -2661,8 +2693,8 @@ def test_schema_support(self):
2661
2693
# different if_exists options
2662
2694
2663
2695
# create a schema
2664
- self .conn .execute ("DROP SCHEMA IF EXISTS other CASCADE;" )
2665
- self .conn .execute ("CREATE SCHEMA other;" )
2696
+ self .conn .exec_driver_sql ("DROP SCHEMA IF EXISTS other CASCADE;" )
2697
+ self .conn .exec_driver_sql ("CREATE SCHEMA other;" )
2666
2698
2667
2699
# write dataframe with different if_exists options
2668
2700
assert (
0 commit comments