@@ -62,7 +62,7 @@ def compile_big_int_sqlite(type_, compiler, **kw):
62
62
63
63
if _SQLALCHEMY_INSTALLED :
64
64
import sqlalchemy
65
- return isinstance (con , sqlalchemy .engine .Engine )
65
+ return isinstance (con , sqlalchemy .engine .Connectable )
66
66
else :
67
67
return False
68
68
@@ -637,7 +637,7 @@ def exists(self):
637
637
638
638
def sql_schema (self ):
639
639
from sqlalchemy .schema import CreateTable
640
- return str (CreateTable (self .table ).compile (self .pd_sql .engine ))
640
+ return str (CreateTable (self .table ).compile (self .pd_sql .connection ))
641
641
642
642
def _execute_create (self ):
643
643
# Inserting table into database, add to MetaData object
@@ -1006,20 +1006,31 @@ class SQLDatabase(PandasSQL):
1006
1006
1007
1007
"""
1008
1008
1009
- def __init__ (self , engine , schema = None , meta = None ):
1010
- self .engine = engine
1009
+ def __init__ (self , connection , schema = None , meta = None ):
1010
+ import sqlalchemy .engine
1011
+ if isinstance (connection , sqlalchemy .engine .Engine ):
1012
+ self .connection = connection .connect ()
1013
+ else :
1014
+ self .connection = connection
1011
1015
if not meta :
1012
1016
from sqlalchemy .schema import MetaData
1013
- meta = MetaData (self .engine , schema = schema )
1017
+ meta = MetaData (self .connection , schema = schema )
1014
1018
1015
1019
self .meta = meta
1016
1020
1021
+ @contextmanager
1017
1022
def run_transaction (self ):
1018
- return self .engine .begin ()
1023
+ trans = self .connection .begin ()
1024
+ try :
1025
+ yield self .connection
1026
+ trans .commit ()
1027
+ except :
1028
+ trans .rollback ()
1029
+ raise
1019
1030
1020
1031
def execute (self , * args , ** kwargs ):
1021
1032
"""Simple passthrough to SQLAlchemy engine"""
1022
- return self .engine .execute (* args , ** kwargs )
1033
+ return self .connection .execute (* args , ** kwargs )
1023
1034
1024
1035
def read_table (self , table_name , index_col = None , coerce_float = True ,
1025
1036
parse_dates = None , columns = None , schema = None ,
@@ -1187,7 +1198,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
1187
1198
table .create ()
1188
1199
table .insert (chunksize )
1189
1200
# check for potentially case sensitivity issues (GH7815)
1190
- if name not in self .engine .table_names (schema = schema or self .meta .schema ):
1201
+ if name not in self .connection . engine .table_names (schema = schema or self .meta .schema , connection = self . connection ):
1191
1202
warnings .warn ("The provided table name '{0}' is not found exactly "
1192
1203
"as such in the database after writing the table, "
1193
1204
"possibly due to case sensitivity issues. Consider "
@@ -1198,7 +1209,7 @@ def tables(self):
1198
1209
return self .meta .tables
1199
1210
1200
1211
def has_table (self , name , schema = None ):
1201
- return self .engine .has_table (name , schema or self .meta .schema )
1212
+ return self .connection . engine .has_table (name , schema or self .meta .schema )
1202
1213
1203
1214
def get_table (self , table_name , schema = None ):
1204
1215
schema = schema or self .meta .schema
@@ -1217,7 +1228,7 @@ def get_table(self, table_name, schema=None):
1217
1228
1218
1229
def drop_table (self , table_name , schema = None ):
1219
1230
schema = schema or self .meta .schema
1220
- if self .engine .has_table (table_name , schema ):
1231
+ if self .connection . engine .has_table (table_name , schema ):
1221
1232
self .meta .reflect (only = [table_name ], schema = schema )
1222
1233
self .get_table (table_name , schema ).drop ()
1223
1234
self .meta .clear ()
0 commit comments