@@ -38,7 +38,7 @@ class DatabaseError(IOError):
38
38
_SQLALCHEMY_INSTALLED = None
39
39
40
40
41
- def _is_sqlalchemy_engine (con ):
41
+ def _is_sqlalchemy_connectable (con ):
42
42
global _SQLALCHEMY_INSTALLED
43
43
if _SQLALCHEMY_INSTALLED is None :
44
44
try :
@@ -61,8 +61,8 @@ def compile_big_int_sqlite(type_, compiler, **kw):
61
61
_SQLALCHEMY_INSTALLED = False
62
62
63
63
if _SQLALCHEMY_INSTALLED :
64
- import sqlalchemy
65
- return isinstance (con , sqlalchemy . engine . Engine )
64
+ from sqlalchemy . engine import interfaces
65
+ return isinstance (con , interfaces . Connectable )
66
66
else :
67
67
return False
68
68
@@ -328,7 +328,7 @@ def read_sql_table(table_name, con, schema=None, index_col=None,
328
328
read_sql
329
329
330
330
"""
331
- if not _is_sqlalchemy_engine (con ):
331
+ if not _is_sqlalchemy_connectable (con ):
332
332
raise NotImplementedError ("read_sql_table only supported for "
333
333
"SQLAlchemy engines." )
334
334
import sqlalchemy
@@ -592,7 +592,7 @@ def pandasSQL_builder(con, flavor=None, schema=None, meta=None,
592
592
"""
593
593
# When support for DBAPI connections is removed,
594
594
# is_cursor should not be necessary.
595
- if _is_sqlalchemy_engine (con ):
595
+ if _is_sqlalchemy_connectable (con ):
596
596
return SQLDatabase (con , schema = schema , meta = meta )
597
597
else :
598
598
if flavor == 'mysql' :
@@ -993,7 +993,7 @@ class SQLDatabase(PandasSQL):
993
993
994
994
Parameters
995
995
----------
996
- engine : SQLAlchemy engine
996
+ engine : SQLAlchemy engine or connection
997
997
Engine to connect with the database. Using SQLAlchemy makes it
998
998
possible to use any DB supported by that library.
999
999
schema : string, default None
@@ -1014,8 +1014,28 @@ def __init__(self, engine, schema=None, meta=None):
1014
1014
1015
1015
self .meta = meta
1016
1016
1017
+
1018
+ class SQLAlchemyTransaction (object ):
1019
+ """
1020
+ Context manager for sql alchemy transactions
1021
+ that returns the relevent connectable.
1022
+ """
1023
+ def __init__ (connectable ):
1024
+ self .connectable = connectable
1025
+
1026
+ def __enter__ (self ):
1027
+ self .tx = self .connectable .begin ()
1028
+ if _is_sqlalchemy_connectable (self .tx ):
1029
+ return self .tx
1030
+ else :
1031
+ return self .connectable
1032
+
1033
+ def __exit__ (self , * args , ** kwargs ):
1034
+ return self .tx .__exit__ (* args , ** kwargs )
1035
+
1036
+
1017
1037
def run_transaction (self ):
1018
- return self .engine . begin ( )
1038
+ return self .SQLAlchemyTransaction ( self . engine )
1019
1039
1020
1040
def execute (self , * args , ** kwargs ):
1021
1041
"""Simple passthrough to SQLAlchemy engine"""
0 commit comments