Skip to content

Commit 5fd3d9a

Browse files
author
Thomas Grainger
committed
support both sqlalchemy engines and connections Fixes pandas-dev#7877
1 parent 1c3449d commit 5fd3d9a

File tree

2 files changed

+90
-17
lines changed

2 files changed

+90
-17
lines changed

pandas/io/sql.py

+33-13
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class DatabaseError(IOError):
3838
_SQLALCHEMY_INSTALLED = None
3939

4040

41-
def _is_sqlalchemy_engine(con):
41+
def _is_sqlalchemy_connectable(con):
4242
global _SQLALCHEMY_INSTALLED
4343
if _SQLALCHEMY_INSTALLED is None:
4444
try:
@@ -62,7 +62,7 @@ def compile_big_int_sqlite(type_, compiler, **kw):
6262

6363
if _SQLALCHEMY_INSTALLED:
6464
import sqlalchemy
65-
return isinstance(con, sqlalchemy.engine.Engine)
65+
return isinstance(con, sqlalchemy.engine.Connectable)
6666
else:
6767
return False
6868

@@ -328,7 +328,7 @@ def read_sql_table(table_name, con, schema=None, index_col=None,
328328
read_sql
329329
330330
"""
331-
if not _is_sqlalchemy_engine(con):
331+
if not _is_sqlalchemy_connectable(con):
332332
raise NotImplementedError("read_sql_table only supported for "
333333
"SQLAlchemy engines.")
334334
import sqlalchemy
@@ -592,7 +592,7 @@ def pandasSQL_builder(con, flavor=None, schema=None, meta=None,
592592
"""
593593
# When support for DBAPI connections is removed,
594594
# is_cursor should not be necessary.
595-
if _is_sqlalchemy_engine(con):
595+
if _is_sqlalchemy_connectable(con):
596596
return SQLDatabase(con, schema=schema, meta=meta)
597597
else:
598598
if flavor == 'mysql':
@@ -637,7 +637,8 @@ def exists(self):
637637

638638
def sql_schema(self):
639639
from sqlalchemy.schema import CreateTable
640-
return str(CreateTable(self.table).compile(self.pd_sql.engine))
640+
engine = self.pd_sql.connectable.engine
641+
return str(CreateTable(self.table).compile(engine))
641642

642643
def _execute_create(self):
643644
# Inserting table into database, add to MetaData object
@@ -993,7 +994,7 @@ class SQLDatabase(PandasSQL):
993994
994995
Parameters
995996
----------
996-
engine : SQLAlchemy engine
997+
engine : SQLAlchemy Connectable (Engine or Connection)
997998
Engine to connect with the database. Using SQLAlchemy makes it
998999
possible to use any DB supported by that library.
9991000
schema : string, default None
@@ -1007,19 +1008,35 @@ class SQLDatabase(PandasSQL):
10071008
"""
10081009

10091010
def __init__(self, engine, schema=None, meta=None):
1010-
self.engine = engine
1011+
self.connectable = engine
10111012
if not meta:
10121013
from sqlalchemy.schema import MetaData
1013-
meta = MetaData(self.engine, schema=schema)
1014+
meta = MetaData(self.connectable, schema=schema)
10141015

10151016
self.meta = meta
10161017

1018+
class _RunTransaction(object):
1019+
def __init__(self, connectable):
1020+
tx = connectable.begin()
1021+
if hasattr(tx, 'execute'):
1022+
self.connectable = tx
1023+
else:
1024+
self.connectable = connectable
1025+
1026+
def __enter__(self, *args, **kwargs):
1027+
self.connectable
1028+
1029+
@contextmanager
10171030
def run_transaction(self):
1018-
return self.engine.begin()
1031+
with self.connectable.begin() as tx:
1032+
if hasattr(tx, 'execute'):
1033+
yield tx
1034+
else:
1035+
yield self.connectable
10191036

10201037
def execute(self, *args, **kwargs):
10211038
"""Simple passthrough to SQLAlchemy engine"""
1022-
return self.engine.execute(*args, **kwargs)
1039+
return self.connectable.execute(*args, **kwargs)
10231040

10241041
def read_table(self, table_name, index_col=None, coerce_float=True,
10251042
parse_dates=None, columns=None, schema=None,
@@ -1187,7 +1204,8 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
11871204
table.create()
11881205
table.insert(chunksize)
11891206
# check for potentially case sensitivity issues (GH7815)
1190-
if name not in self.engine.table_names(schema=schema or self.meta.schema):
1207+
engine = self.connectable.engine
1208+
if name not in engine.table_names(schema=schema or self.meta.schema):
11911209
warnings.warn("The provided table name '{0}' is not found exactly "
11921210
"as such in the database after writing the table, "
11931211
"possibly due to case sensitivity issues. Consider "
@@ -1198,7 +1216,8 @@ def tables(self):
11981216
return self.meta.tables
11991217

12001218
def has_table(self, name, schema=None):
1201-
return self.engine.has_table(name, schema or self.meta.schema)
1219+
engine = self.connectable.engine
1220+
return engine.has_table(name, schema or self.meta.schema)
12021221

12031222
def get_table(self, table_name, schema=None):
12041223
schema = schema or self.meta.schema
@@ -1217,7 +1236,8 @@ def get_table(self, table_name, schema=None):
12171236

12181237
def drop_table(self, table_name, schema=None):
12191238
schema = schema or self.meta.schema
1220-
if self.engine.has_table(table_name, schema):
1239+
engine = self.connectable.engine
1240+
if engine.has_table(table_name, schema):
12211241
self.meta.reflect(only=[table_name], schema=schema)
12221242
self.get_table(table_name, schema).drop()
12231243
self.meta.clear()

pandas/io/tests/test_sql.py

+57-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
- `TestSQLiteFallbackApi`: test the public API with a sqlite DBAPI connection
1010
- Tests for the different SQL flavors (flavor specific type conversions)
1111
- Tests for the sqlalchemy mode: `_TestSQLAlchemy` is the base class with
12-
common methods, the different tested flavors (sqlite3, MySQL, PostgreSQL)
12+
common methods, `_TestSQLAlchemyConn` tests the API with a SQLAlchemy
13+
Connection object. The different tested flavors (sqlite3, MySQL, PostgreSQL)
1314
derive from the base class
1415
- Tests for the fallback mode (`TestSQLiteFallback` and `TestMySQLLegacy`)
1516
@@ -854,6 +855,30 @@ def test_sqlalchemy_type_mapping(self):
854855
self.assertTrue(isinstance(table.table.c['time'].type, sqltypes.DateTime))
855856

856857

858+
class _EngineToConnMixin(object):
859+
"""
860+
A mixin that causes setup_connect to create a conn rather than an engine.
861+
"""
862+
863+
def setup_connect(self):
864+
try:
865+
engine = self.connect() # connect lies, it returns an engine
866+
conn = engine.connect()
867+
self.addCleanup(conn.close)
868+
tx = conn.begin()
869+
self.addCleanup(tx.rollback)
870+
self.pandasSQL = sql.SQLDatabase(conn)
871+
self.conn = conn
872+
except sqlalchemy.exc.OperationalError:
873+
raise nose.SkipTest("Can't connect to {0} server".format(
874+
self.flavor
875+
))
876+
877+
878+
class TestSQLApiConn(_EngineToConnMixin, TestSQLApi):
879+
pass
880+
881+
857882
class TestSQLiteFallbackApi(_TestSQLApi):
858883
"""
859884
Test the public sqlite connection fallback API
@@ -1347,8 +1372,11 @@ def test_double_precision(self):
13471372
self.assertTrue(isinstance(col_dict['i64'].type, sqltypes.BigInteger))
13481373

13491374

1375+
class _TestSQLAlchemyConn(_EngineToConnMixin, _TestSQLAlchemy):
1376+
pass
1377+
13501378

1351-
class TestSQLiteAlchemy(_TestSQLAlchemy):
1379+
class _TestSQLiteAlchemy(object):
13521380
"""
13531381
Test the sqlalchemy backend against an in-memory sqlite database.
13541382
@@ -1404,7 +1432,7 @@ def test_bigint_warning(self):
14041432
self.assertEqual(len(w), 0, "Warning triggered for other table")
14051433

14061434

1407-
class TestMySQLAlchemy(_TestSQLAlchemy):
1435+
class _TestMySQLAlchemy(object):
14081436
"""
14091437
Test the sqlalchemy backend against an MySQL database.
14101438
@@ -1478,7 +1506,7 @@ def test_read_procedure(self):
14781506
tm.assert_frame_equal(df, res2)
14791507

14801508

1481-
class TestPostgreSQLAlchemy(_TestSQLAlchemy):
1509+
class _TestPostgreSQLAlchemy(object):
14821510
"""
14831511
Test the sqlalchemy backend against an PostgreSQL database.
14841512
@@ -1574,6 +1602,31 @@ def test_datetime_with_time_zone(self):
15741602
# "2000-06-01 00:00:00-07:00" should convert to "2000-06-01 07:00:00"
15751603
self.assertEqual(df.DateColWithTz[1], Timestamp('2000-06-01 07:00:00'))
15761604

1605+
1606+
class TestMySQLAlchemy(_TestMySQLAlchemy, _TestSQLAlchemy):
1607+
pass
1608+
1609+
1610+
class TestMySQLAlchemyConn(_TestMySQLAlchemy, _TestSQLAlchemyConn):
1611+
pass
1612+
1613+
1614+
class TestPostgreSQLAlchemy(_TestPostgreSQLAlchemy, _TestSQLAlchemy):
1615+
pass
1616+
1617+
1618+
class TestPostgreSQLAlchemyConn(_TestPostgreSQLAlchemy, _TestSQLAlchemyConn):
1619+
pass
1620+
1621+
1622+
class TestSQLiteAlchemy(_TestSQLiteAlchemy, _TestSQLAlchemy):
1623+
pass
1624+
1625+
1626+
class TestSQLiteAlchemyConn(_TestSQLiteAlchemy, _TestSQLAlchemyConn):
1627+
pass
1628+
1629+
15771630
#------------------------------------------------------------------------------
15781631
#--- Test Sqlite / MySQL fallback
15791632

0 commit comments

Comments
 (0)