Skip to content

Commit 0a8e2a5

Browse files
author
Thomas Grainger
committed
support both sqlalchemy engines and connections Fixes pandas-dev#7877
1 parent 5852e72 commit 0a8e2a5

File tree

2 files changed

+103
-30
lines changed

2 files changed

+103
-30
lines changed

pandas/io/sql.py

+29-13
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class DatabaseError(IOError):
3838
_SQLALCHEMY_INSTALLED = None
3939

4040

41-
def _is_sqlalchemy_engine(con):
41+
def _is_sqlalchemy_connectable(con):
4242
global _SQLALCHEMY_INSTALLED
4343
if _SQLALCHEMY_INSTALLED is None:
4444
try:
@@ -62,7 +62,7 @@ def compile_big_int_sqlite(type_, compiler, **kw):
6262

6363
if _SQLALCHEMY_INSTALLED:
6464
import sqlalchemy
65-
return isinstance(con, sqlalchemy.engine.Engine)
65+
return isinstance(con, sqlalchemy.engine.Connectable)
6666
else:
6767
return False
6868

@@ -328,7 +328,7 @@ def read_sql_table(table_name, con, schema=None, index_col=None,
328328
read_sql
329329
330330
"""
331-
if not _is_sqlalchemy_engine(con):
331+
if not _is_sqlalchemy_connectable(con):
332332
raise NotImplementedError("read_sql_table only supported for "
333333
"SQLAlchemy engines.")
334334
import sqlalchemy
@@ -592,7 +592,7 @@ def pandasSQL_builder(con, flavor=None, schema=None, meta=None,
592592
"""
593593
# When support for DBAPI connections is removed,
594594
# is_cursor should not be necessary.
595-
if _is_sqlalchemy_engine(con):
595+
if _is_sqlalchemy_connectable(con):
596596
return SQLDatabase(con, schema=schema, meta=meta)
597597
else:
598598
if flavor == 'mysql':
@@ -637,7 +637,8 @@ def exists(self):
637637

638638
def sql_schema(self):
639639
from sqlalchemy.schema import CreateTable
640-
return str(CreateTable(self.table).compile(self.pd_sql.engine))
640+
engine = self.pd_sql.connectable.engine
641+
return str(CreateTable(self.table).compile(engine))
641642

642643
def _execute_create(self):
643644
# Inserting table into database, add to MetaData object
@@ -993,7 +994,7 @@ class SQLDatabase(PandasSQL):
993994
994995
Parameters
995996
----------
996-
engine : SQLAlchemy engine
997+
engine : SQLAlchemy Connectable (Engine or Connection)
997998
Engine to connect with the database. Using SQLAlchemy makes it
998999
possible to use any DB supported by that library.
9991000
schema : string, default None
@@ -1007,19 +1008,24 @@ class SQLDatabase(PandasSQL):
10071008
"""
10081009

10091010
def __init__(self, engine, schema=None, meta=None):
1010-
self.engine = engine
1011+
self.connectable = engine
10111012
if not meta:
10121013
from sqlalchemy.schema import MetaData
1013-
meta = MetaData(self.engine, schema=schema)
1014+
meta = MetaData(self.connectable, schema=schema)
10141015

10151016
self.meta = meta
10161017

1018+
@contextmanager
10171019
def run_transaction(self):
1018-
return self.engine.begin()
1020+
with self.connectable.begin() as tx:
1021+
if hasattr(tx, 'execute'):
1022+
yield tx
1023+
else:
1024+
yield self.connectable
10191025

10201026
def execute(self, *args, **kwargs):
10211027
"""Simple passthrough to SQLAlchemy engine"""
1022-
return self.engine.execute(*args, **kwargs)
1028+
return self.connectable.execute(*args, **kwargs)
10231029

10241030
def read_table(self, table_name, index_col=None, coerce_float=True,
10251031
parse_dates=None, columns=None, schema=None,
@@ -1187,7 +1193,13 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
11871193
table.create()
11881194
table.insert(chunksize)
11891195
# check for potentially case sensitivity issues (GH7815)
1190-
if name not in self.engine.table_names(schema=schema or self.meta.schema):
1196+
engine = self.connectable.engine
1197+
with self.connectable.connect() as conn:
1198+
table_names = engine.table_names(
1199+
schema=schema or self.meta.schema,
1200+
connection=conn,
1201+
)
1202+
if name not in table_names:
11911203
warnings.warn("The provided table name '{0}' is not found exactly "
11921204
"as such in the database after writing the table, "
11931205
"possibly due to case sensitivity issues. Consider "
@@ -1198,7 +1210,11 @@ def tables(self):
11981210
return self.meta.tables
11991211

12001212
def has_table(self, name, schema=None):
1201-
return self.engine.has_table(name, schema or self.meta.schema)
1213+
return self.connectable.run_callable(
1214+
self.connectable.dialect.has_table,
1215+
name,
1216+
schema or self.meta.schema,
1217+
)
12021218

12031219
def get_table(self, table_name, schema=None):
12041220
schema = schema or self.meta.schema
@@ -1217,7 +1233,7 @@ def get_table(self, table_name, schema=None):
12171233

12181234
def drop_table(self, table_name, schema=None):
12191235
schema = schema or self.meta.schema
1220-
if self.engine.has_table(table_name, schema):
1236+
if self.has_table(table_name, schema):
12211237
self.meta.reflect(only=[table_name], schema=schema)
12221238
self.get_table(table_name, schema).drop()
12231239
self.meta.clear()

pandas/io/tests/test_sql.py

+74-17
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
- `TestSQLiteFallbackApi`: test the public API with a sqlite DBAPI connection
1010
- Tests for the different SQL flavors (flavor specific type conversions)
1111
- Tests for the sqlalchemy mode: `_TestSQLAlchemy` is the base class with
12-
common methods, the different tested flavors (sqlite3, MySQL, PostgreSQL)
12+
common methods, `_TestSQLAlchemyConn` tests the API with a SQLAlchemy
13+
Connection object. The different tested flavors (sqlite3, MySQL, PostgreSQL)
1314
derive from the base class
1415
- Tests for the fallback mode (`TestSQLiteFallback` and `TestMySQLLegacy`)
1516
@@ -854,6 +855,31 @@ def test_sqlalchemy_type_mapping(self):
854855
self.assertTrue(isinstance(table.table.c['time'].type, sqltypes.DateTime))
855856

856857

858+
class _EngineToConnMixin(object):
859+
"""
860+
A mixin that causes setup_connect to create a conn rather than an engine.
861+
"""
862+
863+
def setUp(self):
864+
super(_EngineToConnMixin, self).setUp()
865+
engine = self.conn
866+
conn = engine.connect()
867+
self.__tx = conn.begin()
868+
self.pandasSQL = sql.SQLDatabase(conn)
869+
self.__engine = engine
870+
self.conn = conn
871+
872+
def tearDown(self):
873+
self.__tx.rollback()
874+
self.conn.close()
875+
self.conn = self.__engine
876+
self.pandasSQL = sql.SQLDatabase(self.__engine)
877+
878+
879+
class TestSQLApiConn(_EngineToConnMixin, TestSQLApi):
880+
pass
881+
882+
857883
class TestSQLiteFallbackApi(_TestSQLApi):
858884
"""
859885
Test the public sqlite connection fallback API
@@ -990,9 +1016,6 @@ def setup_connect(self):
9901016
except sqlalchemy.exc.OperationalError:
9911017
raise nose.SkipTest("Can't connect to {0} server".format(self.flavor))
9921018

993-
def tearDown(self):
994-
raise NotImplementedError()
995-
9961019
def test_aread_sql(self):
9971020
self._read_sql_iris()
9981021

@@ -1347,8 +1370,12 @@ def test_double_precision(self):
13471370
self.assertTrue(isinstance(col_dict['i64'].type, sqltypes.BigInteger))
13481371

13491372

1373+
class _TestSQLAlchemyConn(_EngineToConnMixin, _TestSQLAlchemy):
1374+
def test_transactions(self):
1375+
raise nose.SkipTest("Nested transactions rollbacks don't work with Pandas")
1376+
13501377

1351-
class TestSQLiteAlchemy(_TestSQLAlchemy):
1378+
class _TestSQLiteAlchemy(object):
13521379
"""
13531380
Test the sqlalchemy backend against an in-memory sqlite database.
13541381
@@ -1365,8 +1392,8 @@ def setup_driver(cls):
13651392
cls.driver = None
13661393

13671394
def tearDown(self):
1395+
super(_TestSQLiteAlchemy, self).tearDown()
13681396
# in memory so tables should not be removed explicitly
1369-
pass
13701397

13711398
def test_default_type_conversion(self):
13721399
df = sql.read_sql_table("types_test_data", self.conn)
@@ -1404,7 +1431,7 @@ def test_bigint_warning(self):
14041431
self.assertEqual(len(w), 0, "Warning triggered for other table")
14051432

14061433

1407-
class TestMySQLAlchemy(_TestSQLAlchemy):
1434+
class _TestMySQLAlchemy(object):
14081435
"""
14091436
Test the sqlalchemy backend against an MySQL database.
14101437
@@ -1425,6 +1452,7 @@ def setup_driver(cls):
14251452
raise nose.SkipTest('pymysql not installed')
14261453

14271454
def tearDown(self):
1455+
super(_TestMySQLAlchemy, self).tearDown()
14281456
c = self.conn.execute('SHOW TABLES')
14291457
for table in c.fetchall():
14301458
self.conn.execute('DROP TABLE %s' % table[0])
@@ -1478,7 +1506,7 @@ def test_read_procedure(self):
14781506
tm.assert_frame_equal(df, res2)
14791507

14801508

1481-
class TestPostgreSQLAlchemy(_TestSQLAlchemy):
1509+
class _TestPostgreSQLAlchemy(object):
14821510
"""
14831511
Test the sqlalchemy backend against an PostgreSQL database.
14841512
@@ -1499,6 +1527,7 @@ def setup_driver(cls):
14991527
raise nose.SkipTest('psycopg2 not installed')
15001528

15011529
def tearDown(self):
1530+
super(_TestPostgreSQLAlchemy, self).tearDown()
15021531
c = self.conn.execute(
15031532
"SELECT table_name FROM information_schema.tables"
15041533
" WHERE table_schema = 'public'")
@@ -1550,15 +1579,18 @@ def test_schema_support(self):
15501579

15511580
## specifying schema in user-provided meta
15521581

1553-
engine2 = self.connect()
1554-
meta = sqlalchemy.MetaData(engine2, schema='other')
1555-
pdsql = sql.SQLDatabase(engine2, meta=meta)
1556-
pdsql.to_sql(df, 'test_schema_other2', index=False)
1557-
pdsql.to_sql(df, 'test_schema_other2', index=False, if_exists='replace')
1558-
pdsql.to_sql(df, 'test_schema_other2', index=False, if_exists='append')
1559-
res1 = sql.read_sql_table('test_schema_other2', self.conn, schema='other')
1560-
res2 = pdsql.read_table('test_schema_other2')
1561-
tm.assert_frame_equal(res1, res2)
1582+
# The schema won't be applied on another Connection
1583+
# because of transactional schemas
1584+
if isinstance(self.conn, sqlalchemy.engine.Engine):
1585+
engine2 = self.connect()
1586+
meta = sqlalchemy.MetaData(engine2, schema='other')
1587+
pdsql = sql.SQLDatabase(engine2, meta=meta)
1588+
pdsql.to_sql(df, 'test_schema_other2', index=False)
1589+
pdsql.to_sql(df, 'test_schema_other2', index=False, if_exists='replace')
1590+
pdsql.to_sql(df, 'test_schema_other2', index=False, if_exists='append')
1591+
res1 = sql.read_sql_table('test_schema_other2', self.conn, schema='other')
1592+
res2 = pdsql.read_table('test_schema_other2')
1593+
tm.assert_frame_equal(res1, res2)
15621594

15631595
def test_datetime_with_time_zone(self):
15641596
# Test to see if we read the date column with timezones that
@@ -1574,6 +1606,31 @@ def test_datetime_with_time_zone(self):
15741606
# "2000-06-01 00:00:00-07:00" should convert to "2000-06-01 07:00:00"
15751607
self.assertEqual(df.DateColWithTz[1], Timestamp('2000-06-01 07:00:00'))
15761608

1609+
1610+
class TestMySQLAlchemy(_TestMySQLAlchemy, _TestSQLAlchemy):
1611+
pass
1612+
1613+
1614+
class TestMySQLAlchemyConn(_TestMySQLAlchemy, _TestSQLAlchemyConn):
1615+
pass
1616+
1617+
1618+
class TestPostgreSQLAlchemy(_TestPostgreSQLAlchemy, _TestSQLAlchemy):
1619+
pass
1620+
1621+
1622+
class TestPostgreSQLAlchemyConn(_TestPostgreSQLAlchemy, _TestSQLAlchemyConn):
1623+
pass
1624+
1625+
1626+
class TestSQLiteAlchemy(_TestSQLiteAlchemy, _TestSQLAlchemy):
1627+
pass
1628+
1629+
1630+
class TestSQLiteAlchemyConn(_TestSQLiteAlchemy, _TestSQLAlchemyConn):
1631+
pass
1632+
1633+
15771634
#------------------------------------------------------------------------------
15781635
#--- Test Sqlite / MySQL fallback
15791636

0 commit comments

Comments
 (0)