@@ -38,7 +38,7 @@ class DatabaseError(IOError):
38
38
_SQLALCHEMY_INSTALLED = None
39
39
40
40
41
- def _is_sqlalchemy_engine (con ):
41
+ def _is_sqlalchemy_connectable (con ):
42
42
global _SQLALCHEMY_INSTALLED
43
43
if _SQLALCHEMY_INSTALLED is None :
44
44
try :
@@ -62,7 +62,7 @@ def compile_big_int_sqlite(type_, compiler, **kw):
62
62
63
63
if _SQLALCHEMY_INSTALLED :
64
64
import sqlalchemy
65
- return isinstance (con , sqlalchemy .engine .Engine )
65
+ return isinstance (con , sqlalchemy .engine .Connectable )
66
66
else :
67
67
return False
68
68
@@ -328,7 +328,7 @@ def read_sql_table(table_name, con, schema=None, index_col=None,
328
328
read_sql
329
329
330
330
"""
331
- if not _is_sqlalchemy_engine (con ):
331
+ if not _is_sqlalchemy_connectable (con ):
332
332
raise NotImplementedError ("read_sql_table only supported for "
333
333
"SQLAlchemy engines." )
334
334
import sqlalchemy
@@ -592,7 +592,7 @@ def pandasSQL_builder(con, flavor=None, schema=None, meta=None,
592
592
"""
593
593
# When support for DBAPI connections is removed,
594
594
# is_cursor should not be necessary.
595
- if _is_sqlalchemy_engine (con ):
595
+ if _is_sqlalchemy_connectable (con ):
596
596
return SQLDatabase (con , schema = schema , meta = meta )
597
597
else :
598
598
if flavor == 'mysql' :
@@ -637,7 +637,8 @@ def exists(self):
637
637
638
638
def sql_schema (self ):
639
639
from sqlalchemy .schema import CreateTable
640
- return str (CreateTable (self .table ).compile (self .pd_sql .engine ))
640
+ engine = self .pd_sql .connectable .engine
641
+ return str (CreateTable (self .table ).compile (engine ))
641
642
642
643
def _execute_create (self ):
643
644
# Inserting table into database, add to MetaData object
@@ -993,7 +994,7 @@ class SQLDatabase(PandasSQL):
993
994
994
995
Parameters
995
996
----------
996
- engine : SQLAlchemy engine
997
+ engine : SQLAlchemy Connectable (Engine or Connection)
997
998
Engine to connect with the database. Using SQLAlchemy makes it
998
999
possible to use any DB supported by that library.
999
1000
schema : string, default None
@@ -1007,19 +1008,35 @@ class SQLDatabase(PandasSQL):
1007
1008
"""
1008
1009
1009
1010
def __init__ (self , engine , schema = None , meta = None ):
1010
- self .engine = engine
1011
+ self .connectable = engine
1011
1012
if not meta :
1012
1013
from sqlalchemy .schema import MetaData
1013
- meta = MetaData (self .engine , schema = schema )
1014
+ meta = MetaData (self .connectable , schema = schema )
1014
1015
1015
1016
self .meta = meta
1016
1017
1018
+ class _RunTransaction (object ):
1019
+ def __init__ (self , connectable ):
1020
+ tx = connectable .begin ()
1021
+ if hasattr (tx , 'execute' ):
1022
+ self .connectable = tx
1023
+ else :
1024
+ self .connectable = connectable
1025
+
1026
+ def __enter__ (self , * args , ** kwargs ):
1027
+ self .connectable
1028
+
1029
+ @contextmanager
1017
1030
def run_transaction (self ):
1018
- return self .engine .begin ()
1031
+ with self .connectable .begin () as tx :
1032
+ if hasattr (tx , 'execute' ):
1033
+ yield tx
1034
+ else :
1035
+ yield self .connectable
1019
1036
1020
1037
def execute (self , * args , ** kwargs ):
1021
1038
"""Simple passthrough to SQLAlchemy engine"""
1022
- return self .engine .execute (* args , ** kwargs )
1039
+ return self .connectable .execute (* args , ** kwargs )
1023
1040
1024
1041
def read_table (self , table_name , index_col = None , coerce_float = True ,
1025
1042
parse_dates = None , columns = None , schema = None ,
@@ -1187,7 +1204,8 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
1187
1204
table .create ()
1188
1205
table .insert (chunksize )
1189
1206
# check for potentially case sensitivity issues (GH7815)
1190
- if name not in self .engine .table_names (schema = schema or self .meta .schema ):
1207
+ engine = self .connectable .engine
1208
+ if name not in engine .table_names (schema = schema or self .meta .schema ):
1191
1209
warnings .warn ("The provided table name '{0}' is not found exactly "
1192
1210
"as such in the database after writing the table, "
1193
1211
"possibly due to case sensitivity issues. Consider "
@@ -1198,7 +1216,8 @@ def tables(self):
1198
1216
return self .meta .tables
1199
1217
1200
1218
def has_table (self , name , schema = None ):
1201
- return self .engine .has_table (name , schema or self .meta .schema )
1219
+ engine = self .connectable .engine
1220
+ return engine .has_table (name , schema or self .meta .schema )
1202
1221
1203
1222
def get_table (self , table_name , schema = None ):
1204
1223
schema = schema or self .meta .schema
@@ -1217,7 +1236,8 @@ def get_table(self, table_name, schema=None):
1217
1236
1218
1237
def drop_table (self , table_name , schema = None ):
1219
1238
schema = schema or self .meta .schema
1220
- if self .engine .has_table (table_name , schema ):
1239
+ engine = self .connectable .engine
1240
+ if engine .has_table (table_name , schema ):
1221
1241
self .meta .reflect (only = [table_name ], schema = schema )
1222
1242
self .get_table (table_name , schema ).drop ()
1223
1243
self .meta .clear ()
0 commit comments