Skip to content

Commit 5c7d352

Browse files
author
Thomas Grainger
committed
support both sqlalchemy engines and connections Fixes #7877
1 parent f6c7d89 commit 5c7d352

File tree

2 files changed

+73
-14
lines changed

2 files changed

+73
-14
lines changed

pandas/io/sql.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def compile_big_int_sqlite(type_, compiler, **kw):
6262

6363
if _SQLALCHEMY_INSTALLED:
6464
import sqlalchemy
65-
return isinstance(con, sqlalchemy.engine.Engine)
65+
return isinstance(con, sqlalchemy.engine.Connectable)
6666
else:
6767
return False
6868

@@ -637,7 +637,7 @@ def exists(self):
637637

638638
def sql_schema(self):
639639
from sqlalchemy.schema import CreateTable
640-
return str(CreateTable(self.table).compile(self.pd_sql.engine))
640+
return str(CreateTable(self.table).compile(self.pd_sql.connection))
641641

642642
def _execute_create(self):
643643
# Inserting table into database, add to MetaData object
@@ -1006,20 +1006,31 @@ class SQLDatabase(PandasSQL):
10061006
10071007
"""
10081008

1009-
def __init__(self, engine, schema=None, meta=None):
1010-
self.engine = engine
1009+
def __init__(self, connection, schema=None, meta=None):
1010+
import sqlalchemy.engine
1011+
if isinstance(connection, sqlalchemy.engine.Engine):
1012+
self.connection = connection.connect()
1013+
else:
1014+
self.connection = connection
10111015
if not meta:
10121016
from sqlalchemy.schema import MetaData
1013-
meta = MetaData(self.engine, schema=schema)
1017+
meta = MetaData(self.connection, schema=schema)
10141018

10151019
self.meta = meta
10161020

1021+
@contextmanager
10171022
def run_transaction(self):
1018-
return self.engine.begin()
1023+
trans = self.connection.begin()
1024+
try:
1025+
yield self.connection
1026+
trans.commit()
1027+
except:
1028+
trans.rollback()
1029+
raise
10191030

10201031
def execute(self, *args, **kwargs):
10211032
"""Simple passthrough to SQLAlchemy engine"""
1022-
return self.engine.execute(*args, **kwargs)
1033+
return self.connection.execute(*args, **kwargs)
10231034

10241035
def read_table(self, table_name, index_col=None, coerce_float=True,
10251036
parse_dates=None, columns=None, schema=None,
@@ -1187,7 +1198,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
11871198
table.create()
11881199
table.insert(chunksize)
11891200
# check for potentially case sensitivity issues (GH7815)
1190-
if name not in self.engine.table_names(schema=schema or self.meta.schema):
1201+
if name not in self.connection.engine.table_names(schema=schema or self.meta.schema, connection=self.connection):
11911202
warnings.warn("The provided table name '{0}' is not found exactly "
11921203
"as such in the database after writing the table, "
11931204
"possibly due to case sensitivity issues. Consider "
@@ -1198,7 +1209,7 @@ def tables(self):
11981209
return self.meta.tables
11991210

12001211
def has_table(self, name, schema=None):
1201-
return self.engine.has_table(name, schema or self.meta.schema)
1212+
return self.connection.engine.has_table(name, schema or self.meta.schema)
12021213

12031214
def get_table(self, table_name, schema=None):
12041215
schema = schema or self.meta.schema
@@ -1217,7 +1228,7 @@ def get_table(self, table_name, schema=None):
12171228

12181229
def drop_table(self, table_name, schema=None):
12191230
schema = schema or self.meta.schema
1220-
if self.engine.has_table(table_name, schema):
1231+
if self.connection.engine.has_table(table_name, schema):
12211232
self.meta.reflect(only=[table_name], schema=schema)
12221233
self.get_table(table_name, schema).drop()
12231234
self.meta.clear()

pandas/io/tests/test_sql.py

+52-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
- `TestSQLiteFallbackApi`: test the public API with a sqlite DBAPI connection
1010
- Tests for the different SQL flavors (flavor specific type conversions)
1111
- Tests for the sqlalchemy mode: `_TestSQLAlchemy` is the base class with
12-
common methods, the different tested flavors (sqlite3, MySQL, PostgreSQL)
12+
common methods, `_TestSQLAlchemyConn` tests the API with a SQLAlchemy
13+
Connection object. The different tested flavors (sqlite3, MySQL, PostgreSQL)
1314
derive from the base class
1415
- Tests for the fallback mode (`TestSQLiteFallback` and `TestMySQLLegacy`)
1516
@@ -854,6 +855,18 @@ def test_sqlalchemy_type_mapping(self):
854855
self.assertTrue(isinstance(table.table.c['time'].type, sqltypes.DateTime))
855856

856857

858+
class TestSQLApiConn(TestSQLApi):
859+
def setUp(self):
860+
super(TestSQLApiConn, self).setUp()
861+
conn = self.conn
862+
conn.connect()
863+
self.addCleanup(conn.close)
864+
tx = conn.begin()
865+
self.addCleanup(tx.rollback)
866+
867+
self.conn = conn
868+
869+
857870
class TestSQLiteFallbackApi(_TestSQLApi):
858871
"""
859872
Test the public sqlite connection fallback API
@@ -1347,8 +1360,18 @@ def test_double_precision(self):
13471360
self.assertTrue(isinstance(col_dict['i64'].type, sqltypes.BigInteger))
13481361

13491362

1363+
class _TestSQLAlchemyConn(_TestSQLAlchemy):
1364+
def setUp(self):
1365+
super(_TestSQLAlchemyConn, self).setUp()
1366+
conn = self.conn.connect()
1367+
self.addCleanup(conn.close)
1368+
tx = conn.begin()
1369+
self.addCleanup(tx.rollback)
1370+
1371+
self.conn = conn
1372+
13501373

1351-
class TestSQLiteAlchemy(_TestSQLAlchemy):
1374+
class _TestSQLiteAlchemy(object):
13521375
"""
13531376
Test the sqlalchemy backend against an in-memory sqlite database.
13541377
@@ -1404,7 +1427,7 @@ def test_bigint_warning(self):
14041427
self.assertEqual(len(w), 0, "Warning triggered for other table")
14051428

14061429

1407-
class TestMySQLAlchemy(_TestSQLAlchemy):
1430+
class _TestMySQLAlchemy(object):
14081431
"""
14091432
Test the sqlalchemy backend against an MySQL database.
14101433
@@ -1478,7 +1501,7 @@ def test_read_procedure(self):
14781501
tm.assert_frame_equal(df, res2)
14791502

14801503

1481-
class TestPostgreSQLAlchemy(_TestSQLAlchemy):
1504+
class _TestPostgreSQLAlchemy(object):
14821505
"""
14831506
Test the sqlalchemy backend against an PostgreSQL database.
14841507
@@ -1574,6 +1597,31 @@ def test_datetime_with_time_zone(self):
15741597
# "2000-06-01 00:00:00-07:00" should convert to "2000-06-01 07:00:00"
15751598
self.assertEqual(df.DateColWithTz[1], Timestamp('2000-06-01 07:00:00'))
15761599

1600+
1601+
class TestMySQLAlchemy(_TestMySQLAlchemy, _TestSQLAlchemy):
1602+
pass
1603+
1604+
1605+
class TestMySQLAlchemyConn(_TestMySQLAlchemy, _TestSQLAlchemyConn):
1606+
pass
1607+
1608+
1609+
class TestPostgreSQLAlchemy(_TestPostgreSQLAlchemy, _TestSQLAlchemy):
1610+
pass
1611+
1612+
1613+
class TestPostgreSQLAlchemyConn(_TestPostgreSQLAlchemy, _TestSQLAlchemyConn):
1614+
pass
1615+
1616+
1617+
class TestSQLiteAlchemy(_TestSQLiteAlchemy, _TestSQLAlchemy):
1618+
pass
1619+
1620+
1621+
class TestSQLiteAlchemyConn(_TestSQLiteAlchemy, _TestSQLAlchemyConn):
1622+
pass
1623+
1624+
15771625
#------------------------------------------------------------------------------
15781626
#--- Test Sqlite / MySQL fallback
15791627

0 commit comments

Comments
 (0)