From e71283260bc78a5b4b6087929ca70b18cef03d47 Mon Sep 17 00:00:00 2001 From: Dan Allan Date: Thu, 11 Jul 2013 10:01:37 -0400 Subject: [PATCH 1/4] TST: Import sqlalchemy on Travis. --- ci/requirements-2.6.txt | 1 + ci/requirements-2.7.txt | 1 + ci/requirements-2.7_LOCALE.txt | 1 + ci/requirements-3.2.txt | 1 + ci/requirements-3.3.txt | 1 + pandas/io/tests/test_legacy_sql.py | 492 +++++++++++++++++++++++++++++ pandas/io/tests/test_sql.py | 473 +-------------------------- 7 files changed, 498 insertions(+), 472 deletions(-) create mode 100644 pandas/io/tests/test_legacy_sql.py diff --git a/ci/requirements-2.6.txt b/ci/requirements-2.6.txt index ac77449b2df02..a61aeaefa8a26 100644 --- a/ci/requirements-2.6.txt +++ b/ci/requirements-2.6.txt @@ -4,3 +4,4 @@ python-dateutil==2.1 pytz==2013b http://www.crummy.com/software/BeautifulSoup/bs4/download/4.2/beautifulsoup4-4.2.0.tar.gz html5lib==1.0b2 +sqlalchemy==0.8 diff --git a/ci/requirements-2.7.txt b/ci/requirements-2.7.txt index 6a94d48ad7a5f..ebaaef80b1527 100644 --- a/ci/requirements-2.7.txt +++ b/ci/requirements-2.7.txt @@ -16,3 +16,4 @@ scikits.timeseries==0.91.3 MySQL-python==1.2.4 scipy==0.10.0 beautifulsoup4==4.2.1 +sqlalchemy==0.8 diff --git a/ci/requirements-2.7_LOCALE.txt b/ci/requirements-2.7_LOCALE.txt index 70c398816f23c..e7eecc8433094 100644 --- a/ci/requirements-2.7_LOCALE.txt +++ b/ci/requirements-2.7_LOCALE.txt @@ -14,3 +14,4 @@ html5lib==1.0b2 lxml==3.2.1 scipy==0.10.0 beautifulsoup4==4.2.1 +sqlalchemy==0.8 diff --git a/ci/requirements-3.2.txt b/ci/requirements-3.2.txt index e907a2fa828f1..9572288d79cb3 100644 --- a/ci/requirements-3.2.txt +++ b/ci/requirements-3.2.txt @@ -11,3 +11,4 @@ patsy==0.1.0 lxml==3.2.1 scipy==0.12.0 beautifulsoup4==4.2.1 +sqlalchemy==0.8 diff --git a/ci/requirements-3.3.txt b/ci/requirements-3.3.txt index eb1e725d98040..1a1c98db06054 100644 --- a/ci/requirements-3.3.txt +++ b/ci/requirements-3.3.txt @@ -12,3 +12,4 @@ patsy==0.1.0 lxml==3.2.1 scipy==0.12.0 beautifulsoup4==4.2.1 +sqlalchemy==0.8 diff --git a/pandas/io/tests/test_legacy_sql.py b/pandas/io/tests/test_legacy_sql.py new file mode 100644 index 0000000000000..5b23bf173ec4e --- /dev/null +++ b/pandas/io/tests/test_legacy_sql.py @@ -0,0 +1,492 @@ +from __future__ import with_statement +from pandas.util.py3compat import StringIO +import unittest +import sqlite3 +import sys + +import warnings + +import nose + +import numpy as np + +from pandas.core.datetools import format as date_format +from pandas.core.api import DataFrame, isnull + +import pandas.io.sql as sql +import pandas.util.testing as tm +from pandas import Series, Index, DataFrame +from datetime import datetime + +_formatters = { + datetime: lambda dt: "'%s'" % date_format(dt), + str: lambda x: "'%s'" % x, + np.str_: lambda x: "'%s'" % x, + unicode: lambda x: "'%s'" % x, + float: lambda x: "%.8f" % x, + int: lambda x: "%s" % x, + type(None): lambda x: "NULL", + np.float64: lambda x: "%.10f" % x, + bool: lambda x: "'%s'" % x, +} + +def format_query(sql, *args): + """ + + """ + processed_args = [] + for arg in args: + if isinstance(arg, float) and isnull(arg): + arg = None + + formatter = _formatters[type(arg)] + processed_args.append(formatter(arg)) + + return sql % tuple(processed_args) + +def _skip_if_no_MySQLdb(): + try: + import MySQLdb + except ImportError: + raise nose.SkipTest('MySQLdb not installed, skipping') + +class TestSQLite(unittest.TestCase): + + def setUp(self): + self.db = sqlite3.connect(':memory:') + + def test_basic(self): + frame = tm.makeTimeDataFrame() + 
self._check_roundtrip(frame) + + def test_write_row_by_row(self): + frame = tm.makeTimeDataFrame() + frame.ix[0, 0] = np.nan + create_sql = sql.get_schema(frame, 'test', 'sqlite') + cur = self.db.cursor() + cur.execute(create_sql) + + cur = self.db.cursor() + + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + for idx, row in frame.iterrows(): + fmt_sql = format_query(ins, *row) + sql.tquery(fmt_sql, cur=cur) + + self.db.commit() + + result = sql.read_frame("select * from test", con=self.db) + result.index = frame.index + tm.assert_frame_equal(result, frame) + + def test_execute(self): + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'sqlite') + cur = self.db.cursor() + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (?, ?, ?, ?)" + + row = frame.ix[0] + sql.execute(ins, self.db, params=tuple(row)) + self.db.commit() + + result = sql.read_frame("select * from test", self.db) + result.index = frame.index[:1] + tm.assert_frame_equal(result, frame[:1]) + + def test_schema(self): + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'sqlite') + lines = create_sql.splitlines() + for l in lines: + tokens = l.split(' ') + if len(tokens) == 2 and tokens[0] == 'A': + self.assert_(tokens[1] == 'DATETIME') + + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],) + lines = create_sql.splitlines() + self.assert_('PRIMARY KEY (A,B)' in create_sql) + cur = self.db.cursor() + cur.execute(create_sql) + + def test_execute_fail(self): + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + cur = self.db.cursor() + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) + + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.execute, + 'INSERT INTO test VALUES("foo", "bar", 7)', + self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_execute_closed_connection(self): + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + cur = self.db.cursor() + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + self.db.close() + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.tquery, "select * from test", + con=self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_na_roundtrip(self): + pass + + def _check_roundtrip(self, frame): + sql.write_frame(frame, name='test_table', con=self.db) + result = sql.read_frame("select * from test_table", self.db) + + # HACK! Change this once indexes are handled properly. 
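+        # read_frame comes back with a default integer index, so restore the
+        # original index before comparing against the source frame.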
+ result.index = frame.index + + expected = frame + tm.assert_frame_equal(result, expected) + + frame['txt'] = ['a'] * len(frame) + frame2 = frame.copy() + frame2['Idx'] = Index(range(len(frame2))) + 10 + sql.write_frame(frame2, name='test_table2', con=self.db) + result = sql.read_frame("select * from test_table2", self.db, + index_col='Idx') + expected = frame.copy() + expected.index = Index(range(len(frame2))) + 10 + expected.index.name = 'Idx' + print expected.index.names + print result.index.names + tm.assert_frame_equal(expected, result) + + def test_tquery(self): + frame = tm.makeTimeDataFrame() + sql.write_frame(frame, name='test_table', con=self.db) + result = sql.tquery("select A from test_table", self.db) + expected = frame.A + result = Series(result, frame.index) + tm.assert_series_equal(result, expected) + + try: + sys.stdout = StringIO() + self.assertRaises(sqlite3.OperationalError, sql.tquery, + 'select * from blah', con=self.db) + + self.assertRaises(sqlite3.OperationalError, sql.tquery, + 'select * from blah', con=self.db, retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_uquery(self): + frame = tm.makeTimeDataFrame() + sql.write_frame(frame, name='test_table', con=self.db) + stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' + self.assertEqual(sql.uquery(stmt, con=self.db), 1) + + try: + sys.stdout = StringIO() + + self.assertRaises(sqlite3.OperationalError, sql.tquery, + 'insert into blah values (1)', con=self.db) + + self.assertRaises(sqlite3.OperationalError, sql.tquery, + 'insert into blah values (1)', con=self.db, + retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_keyword_as_column_names(self): + ''' + ''' + df = DataFrame({'From':np.ones(5)}) + sql.write_frame(df, con = self.db, name = 'testkeywords') + + def test_onecolumn_of_integer(self): + ''' + GH 3628 + a column_of_integers dataframe should transfer well to sql + ''' + mono_df=DataFrame([1 , 2], columns=['c0']) + sql.write_frame(mono_df, con = self.db, name = 'mono_df') + # computing the sum via sql + con_x=self.db + the_sum=sum([my_c0[0] for my_c0 in con_x.execute("select * from mono_df")]) + # it should not fail, and gives 3 ( Issue #3628 ) + self.assertEqual(the_sum , 3) + + result = sql.read_frame("select * from mono_df",con_x) + tm.assert_frame_equal(result,mono_df) + + +class TestMySQL(unittest.TestCase): + + def setUp(self): + _skip_if_no_MySQLdb() + import MySQLdb + try: + # Try Travis defaults. + # No real user should allow root access with a blank password. + self.db = MySQLdb.connect(host='localhost', user='root', passwd='', + db='pandas_nosetest') + except: + pass + else: + return + try: + self.db = MySQLdb.connect(read_default_group='pandas') + except MySQLdb.ProgrammingError, e: + raise nose.SkipTest( + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") + except MySQLdb.Error, e: + raise nose.SkipTest( + "Cannot connect to database. " + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. 
") + + def test_basic(self): + _skip_if_no_MySQLdb() + frame = tm.makeTimeDataFrame() + self._check_roundtrip(frame) + + def test_write_row_by_row(self): + _skip_if_no_MySQLdb() + frame = tm.makeTimeDataFrame() + frame.ix[0, 0] = np.nan + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = sql.get_schema(frame, 'test', 'mysql') + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + for idx, row in frame.iterrows(): + fmt_sql = format_query(ins, *row) + sql.tquery(fmt_sql, cur=cur) + + self.db.commit() + + result = sql.read_frame("select * from test", con=self.db) + result.index = frame.index + tm.assert_frame_equal(result, frame) + + def test_execute(self): + _skip_if_no_MySQLdb() + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = sql.get_schema(frame, 'test', 'mysql') + cur = self.db.cursor() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unknown table.*") + cur.execute(drop_sql) + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + + row = frame.ix[0] + sql.execute(ins, self.db, params=tuple(row)) + self.db.commit() + + result = sql.read_frame("select * from test", self.db) + result.index = frame.index[:1] + tm.assert_frame_equal(result, frame[:1]) + + def test_schema(self): + _skip_if_no_MySQLdb() + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'mysql') + lines = create_sql.splitlines() + for l in lines: + tokens = l.split(' ') + if len(tokens) == 2 and tokens[0] == 'A': + self.assert_(tokens[1] == 'DATETIME') + + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = sql.get_schema(frame, 'test', 'mysql', keys=['A', 'B'],) + lines = create_sql.splitlines() + self.assert_('PRIMARY KEY (A,B)' in create_sql) + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + + def test_execute_fail(self): + _skip_if_no_MySQLdb() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a(5), b(5)) + ); + """ + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) + + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.execute, + 'INSERT INTO test VALUES("foo", "bar", 7)', + self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_execute_closed_connection(self): + _skip_if_no_MySQLdb() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a(5), b(5)) + ); + """ + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + self.db.close() + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.tquery, "select * from test", + con=self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_na_roundtrip(self): + _skip_if_no_MySQLdb() + pass + + def _check_roundtrip(self, frame): + _skip_if_no_MySQLdb() + drop_sql = "DROP TABLE IF EXISTS test_table" + cur = self.db.cursor() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unknown table.*") + cur.execute(drop_sql) + sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + result = sql.read_frame("select * from test_table", self.db) + + # HACK! 
Change this once indexes are handled properly. + result.index = frame.index + result.index.name = frame.index.name + + expected = frame + tm.assert_frame_equal(result, expected) + + frame['txt'] = ['a'] * len(frame) + frame2 = frame.copy() + index = Index(range(len(frame2))) + 10 + frame2['Idx'] = index + drop_sql = "DROP TABLE IF EXISTS test_table2" + cur = self.db.cursor() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unknown table.*") + cur.execute(drop_sql) + sql.write_frame(frame2, name='test_table2', con=self.db, flavor='mysql') + result = sql.read_frame("select * from test_table2", self.db, + index_col='Idx') + expected = frame.copy() + + # HACK! Change this once indexes are handled properly. + expected.index = index + expected.index.names = result.index.names + tm.assert_frame_equal(expected, result) + + def test_tquery(self): + try: + import MySQLdb + except ImportError: + raise nose.SkipTest + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test_table" + cur = self.db.cursor() + cur.execute(drop_sql) + sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + result = sql.tquery("select A from test_table", self.db) + expected = frame.A + result = Series(result, frame.index) + tm.assert_series_equal(result, expected) + + try: + sys.stdout = StringIO() + self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, + 'select * from blah', con=self.db) + + self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, + 'select * from blah', con=self.db, retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_uquery(self): + try: + import MySQLdb + except ImportError: + raise nose.SkipTest + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test_table" + cur = self.db.cursor() + cur.execute(drop_sql) + sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' + self.assertEqual(sql.uquery(stmt, con=self.db), 1) + + try: + sys.stdout = StringIO() + + self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, + 'insert into blah values (1)', con=self.db) + + self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, + 'insert into blah values (1)', con=self.db, + retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_keyword_as_column_names(self): + ''' + ''' + _skip_if_no_MySQLdb() + df = DataFrame({'From':np.ones(5)}) + sql.write_frame(df, con = self.db, name = 'testkeywords', + if_exists='replace', flavor='mysql') + + +if __name__ == '__main__': + # unittest.main() + # nose.runmodule(argv=[__file__,'-vvs','-x', '--pdb-failure'], + # exit=False) + nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], + exit=False) diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py index 5b23bf173ec4e..7edc7e124a417 100644 --- a/pandas/io/tests/test_sql.py +++ b/pandas/io/tests/test_sql.py @@ -18,475 +18,4 @@ from pandas import Series, Index, DataFrame from datetime import datetime -_formatters = { - datetime: lambda dt: "'%s'" % date_format(dt), - str: lambda x: "'%s'" % x, - np.str_: lambda x: "'%s'" % x, - unicode: lambda x: "'%s'" % x, - float: lambda x: "%.8f" % x, - int: lambda x: "%s" % x, - type(None): lambda x: "NULL", - np.float64: lambda x: "%.10f" % x, - bool: lambda x: "'%s'" % x, -} - -def format_query(sql, *args): - """ - - """ - processed_args = [] - for arg in args: - if isinstance(arg, float) and isnull(arg): - arg = None - - formatter = _formatters[type(arg)] - 
processed_args.append(formatter(arg)) - - return sql % tuple(processed_args) - -def _skip_if_no_MySQLdb(): - try: - import MySQLdb - except ImportError: - raise nose.SkipTest('MySQLdb not installed, skipping') - -class TestSQLite(unittest.TestCase): - - def setUp(self): - self.db = sqlite3.connect(':memory:') - - def test_basic(self): - frame = tm.makeTimeDataFrame() - self._check_roundtrip(frame) - - def test_write_row_by_row(self): - frame = tm.makeTimeDataFrame() - frame.ix[0, 0] = np.nan - create_sql = sql.get_schema(frame, 'test', 'sqlite') - cur = self.db.cursor() - cur.execute(create_sql) - - cur = self.db.cursor() - - ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" - for idx, row in frame.iterrows(): - fmt_sql = format_query(ins, *row) - sql.tquery(fmt_sql, cur=cur) - - self.db.commit() - - result = sql.read_frame("select * from test", con=self.db) - result.index = frame.index - tm.assert_frame_equal(result, frame) - - def test_execute(self): - frame = tm.makeTimeDataFrame() - create_sql = sql.get_schema(frame, 'test', 'sqlite') - cur = self.db.cursor() - cur.execute(create_sql) - ins = "INSERT INTO test VALUES (?, ?, ?, ?)" - - row = frame.ix[0] - sql.execute(ins, self.db, params=tuple(row)) - self.db.commit() - - result = sql.read_frame("select * from test", self.db) - result.index = frame.index[:1] - tm.assert_frame_equal(result, frame[:1]) - - def test_schema(self): - frame = tm.makeTimeDataFrame() - create_sql = sql.get_schema(frame, 'test', 'sqlite') - lines = create_sql.splitlines() - for l in lines: - tokens = l.split(' ') - if len(tokens) == 2 and tokens[0] == 'A': - self.assert_(tokens[1] == 'DATETIME') - - frame = tm.makeTimeDataFrame() - create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],) - lines = create_sql.splitlines() - self.assert_('PRIMARY KEY (A,B)' in create_sql) - cur = self.db.cursor() - cur.execute(create_sql) - - def test_execute_fail(self): - create_sql = """ - CREATE TABLE test - ( - a TEXT, - b TEXT, - c REAL, - PRIMARY KEY (a, b) - ); - """ - cur = self.db.cursor() - cur.execute(create_sql) - - sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) - sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) - - try: - sys.stdout = StringIO() - self.assertRaises(Exception, sql.execute, - 'INSERT INTO test VALUES("foo", "bar", 7)', - self.db) - finally: - sys.stdout = sys.__stdout__ - - def test_execute_closed_connection(self): - create_sql = """ - CREATE TABLE test - ( - a TEXT, - b TEXT, - c REAL, - PRIMARY KEY (a, b) - ); - """ - cur = self.db.cursor() - cur.execute(create_sql) - - sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) - self.db.close() - try: - sys.stdout = StringIO() - self.assertRaises(Exception, sql.tquery, "select * from test", - con=self.db) - finally: - sys.stdout = sys.__stdout__ - - def test_na_roundtrip(self): - pass - - def _check_roundtrip(self, frame): - sql.write_frame(frame, name='test_table', con=self.db) - result = sql.read_frame("select * from test_table", self.db) - - # HACK! Change this once indexes are handled properly. 
- result.index = frame.index - - expected = frame - tm.assert_frame_equal(result, expected) - - frame['txt'] = ['a'] * len(frame) - frame2 = frame.copy() - frame2['Idx'] = Index(range(len(frame2))) + 10 - sql.write_frame(frame2, name='test_table2', con=self.db) - result = sql.read_frame("select * from test_table2", self.db, - index_col='Idx') - expected = frame.copy() - expected.index = Index(range(len(frame2))) + 10 - expected.index.name = 'Idx' - print expected.index.names - print result.index.names - tm.assert_frame_equal(expected, result) - - def test_tquery(self): - frame = tm.makeTimeDataFrame() - sql.write_frame(frame, name='test_table', con=self.db) - result = sql.tquery("select A from test_table", self.db) - expected = frame.A - result = Series(result, frame.index) - tm.assert_series_equal(result, expected) - - try: - sys.stdout = StringIO() - self.assertRaises(sqlite3.OperationalError, sql.tquery, - 'select * from blah', con=self.db) - - self.assertRaises(sqlite3.OperationalError, sql.tquery, - 'select * from blah', con=self.db, retry=True) - finally: - sys.stdout = sys.__stdout__ - - def test_uquery(self): - frame = tm.makeTimeDataFrame() - sql.write_frame(frame, name='test_table', con=self.db) - stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' - self.assertEqual(sql.uquery(stmt, con=self.db), 1) - - try: - sys.stdout = StringIO() - - self.assertRaises(sqlite3.OperationalError, sql.tquery, - 'insert into blah values (1)', con=self.db) - - self.assertRaises(sqlite3.OperationalError, sql.tquery, - 'insert into blah values (1)', con=self.db, - retry=True) - finally: - sys.stdout = sys.__stdout__ - - def test_keyword_as_column_names(self): - ''' - ''' - df = DataFrame({'From':np.ones(5)}) - sql.write_frame(df, con = self.db, name = 'testkeywords') - - def test_onecolumn_of_integer(self): - ''' - GH 3628 - a column_of_integers dataframe should transfer well to sql - ''' - mono_df=DataFrame([1 , 2], columns=['c0']) - sql.write_frame(mono_df, con = self.db, name = 'mono_df') - # computing the sum via sql - con_x=self.db - the_sum=sum([my_c0[0] for my_c0 in con_x.execute("select * from mono_df")]) - # it should not fail, and gives 3 ( Issue #3628 ) - self.assertEqual(the_sum , 3) - - result = sql.read_frame("select * from mono_df",con_x) - tm.assert_frame_equal(result,mono_df) - - -class TestMySQL(unittest.TestCase): - - def setUp(self): - _skip_if_no_MySQLdb() - import MySQLdb - try: - # Try Travis defaults. - # No real user should allow root access with a blank password. - self.db = MySQLdb.connect(host='localhost', user='root', passwd='', - db='pandas_nosetest') - except: - pass - else: - return - try: - self.db = MySQLdb.connect(read_default_group='pandas') - except MySQLdb.ProgrammingError, e: - raise nose.SkipTest( - "Create a group of connection parameters under the heading " - "[pandas] in your system's mysql default file, " - "typically located at ~/.my.cnf or /etc/.my.cnf. ") - except MySQLdb.Error, e: - raise nose.SkipTest( - "Cannot connect to database. " - "Create a group of connection parameters under the heading " - "[pandas] in your system's mysql default file, " - "typically located at ~/.my.cnf or /etc/.my.cnf. 
") - - def test_basic(self): - _skip_if_no_MySQLdb() - frame = tm.makeTimeDataFrame() - self._check_roundtrip(frame) - - def test_write_row_by_row(self): - _skip_if_no_MySQLdb() - frame = tm.makeTimeDataFrame() - frame.ix[0, 0] = np.nan - drop_sql = "DROP TABLE IF EXISTS test" - create_sql = sql.get_schema(frame, 'test', 'mysql') - cur = self.db.cursor() - cur.execute(drop_sql) - cur.execute(create_sql) - ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" - for idx, row in frame.iterrows(): - fmt_sql = format_query(ins, *row) - sql.tquery(fmt_sql, cur=cur) - - self.db.commit() - - result = sql.read_frame("select * from test", con=self.db) - result.index = frame.index - tm.assert_frame_equal(result, frame) - - def test_execute(self): - _skip_if_no_MySQLdb() - frame = tm.makeTimeDataFrame() - drop_sql = "DROP TABLE IF EXISTS test" - create_sql = sql.get_schema(frame, 'test', 'mysql') - cur = self.db.cursor() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "Unknown table.*") - cur.execute(drop_sql) - cur.execute(create_sql) - ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" - - row = frame.ix[0] - sql.execute(ins, self.db, params=tuple(row)) - self.db.commit() - - result = sql.read_frame("select * from test", self.db) - result.index = frame.index[:1] - tm.assert_frame_equal(result, frame[:1]) - - def test_schema(self): - _skip_if_no_MySQLdb() - frame = tm.makeTimeDataFrame() - create_sql = sql.get_schema(frame, 'test', 'mysql') - lines = create_sql.splitlines() - for l in lines: - tokens = l.split(' ') - if len(tokens) == 2 and tokens[0] == 'A': - self.assert_(tokens[1] == 'DATETIME') - - frame = tm.makeTimeDataFrame() - drop_sql = "DROP TABLE IF EXISTS test" - create_sql = sql.get_schema(frame, 'test', 'mysql', keys=['A', 'B'],) - lines = create_sql.splitlines() - self.assert_('PRIMARY KEY (A,B)' in create_sql) - cur = self.db.cursor() - cur.execute(drop_sql) - cur.execute(create_sql) - - def test_execute_fail(self): - _skip_if_no_MySQLdb() - drop_sql = "DROP TABLE IF EXISTS test" - create_sql = """ - CREATE TABLE test - ( - a TEXT, - b TEXT, - c REAL, - PRIMARY KEY (a(5), b(5)) - ); - """ - cur = self.db.cursor() - cur.execute(drop_sql) - cur.execute(create_sql) - - sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) - sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) - - try: - sys.stdout = StringIO() - self.assertRaises(Exception, sql.execute, - 'INSERT INTO test VALUES("foo", "bar", 7)', - self.db) - finally: - sys.stdout = sys.__stdout__ - - def test_execute_closed_connection(self): - _skip_if_no_MySQLdb() - drop_sql = "DROP TABLE IF EXISTS test" - create_sql = """ - CREATE TABLE test - ( - a TEXT, - b TEXT, - c REAL, - PRIMARY KEY (a(5), b(5)) - ); - """ - cur = self.db.cursor() - cur.execute(drop_sql) - cur.execute(create_sql) - - sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) - self.db.close() - try: - sys.stdout = StringIO() - self.assertRaises(Exception, sql.tquery, "select * from test", - con=self.db) - finally: - sys.stdout = sys.__stdout__ - - def test_na_roundtrip(self): - _skip_if_no_MySQLdb() - pass - - def _check_roundtrip(self, frame): - _skip_if_no_MySQLdb() - drop_sql = "DROP TABLE IF EXISTS test_table" - cur = self.db.cursor() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "Unknown table.*") - cur.execute(drop_sql) - sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') - result = sql.read_frame("select * from test_table", self.db) - - # HACK! 
Change this once indexes are handled properly. - result.index = frame.index - result.index.name = frame.index.name - - expected = frame - tm.assert_frame_equal(result, expected) - - frame['txt'] = ['a'] * len(frame) - frame2 = frame.copy() - index = Index(range(len(frame2))) + 10 - frame2['Idx'] = index - drop_sql = "DROP TABLE IF EXISTS test_table2" - cur = self.db.cursor() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "Unknown table.*") - cur.execute(drop_sql) - sql.write_frame(frame2, name='test_table2', con=self.db, flavor='mysql') - result = sql.read_frame("select * from test_table2", self.db, - index_col='Idx') - expected = frame.copy() - - # HACK! Change this once indexes are handled properly. - expected.index = index - expected.index.names = result.index.names - tm.assert_frame_equal(expected, result) - - def test_tquery(self): - try: - import MySQLdb - except ImportError: - raise nose.SkipTest - frame = tm.makeTimeDataFrame() - drop_sql = "DROP TABLE IF EXISTS test_table" - cur = self.db.cursor() - cur.execute(drop_sql) - sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') - result = sql.tquery("select A from test_table", self.db) - expected = frame.A - result = Series(result, frame.index) - tm.assert_series_equal(result, expected) - - try: - sys.stdout = StringIO() - self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, - 'select * from blah', con=self.db) - - self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, - 'select * from blah', con=self.db, retry=True) - finally: - sys.stdout = sys.__stdout__ - - def test_uquery(self): - try: - import MySQLdb - except ImportError: - raise nose.SkipTest - frame = tm.makeTimeDataFrame() - drop_sql = "DROP TABLE IF EXISTS test_table" - cur = self.db.cursor() - cur.execute(drop_sql) - sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') - stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' - self.assertEqual(sql.uquery(stmt, con=self.db), 1) - - try: - sys.stdout = StringIO() - - self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, - 'insert into blah values (1)', con=self.db) - - self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, - 'insert into blah values (1)', con=self.db, - retry=True) - finally: - sys.stdout = sys.__stdout__ - - def test_keyword_as_column_names(self): - ''' - ''' - _skip_if_no_MySQLdb() - df = DataFrame({'From':np.ones(5)}) - sql.write_frame(df, con = self.db, name = 'testkeywords', - if_exists='replace', flavor='mysql') - - -if __name__ == '__main__': - # unittest.main() - # nose.runmodule(argv=[__file__,'-vvs','-x', '--pdb-failure'], - # exit=False) - nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], - exit=False) +import sqlalchemy From 03355c4ac1d0b12f1bc4dc455f1b013bb0114186 Mon Sep 17 00:00:00 2001 From: Dan Allan Date: Thu, 11 Jul 2013 11:21:28 -0400 Subject: [PATCH 2/4] read_sql docstring edits --- pandas/io/sql.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 11b139b620175..7bd334fb58d94 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -133,7 +133,9 @@ def uquery(sql, con=None, cur=None, retry=True, params=None): return result -def read_frame(sql, con, index_col=None, coerce_float=True, params=None): +def read_sql(sql, con=None, index_col=None, + user=None, passwd=None, host=None, db=None, flavor=None, + coerce_float=True, params=None): """ Returns a DataFrame corresponding to the result set of the query string. 
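For orientation, here is a minimal usage sketch of the widened entry point at this stage of the series; it is not part of the patch. Only the `con`-based call is functional here (the credential keywords are documented but not yet wired up, and the next commit renames them), and the table name `demo` is invented for illustration.

```python
import sqlite3
import pandas.io.sql as sql

# Hand read_sql an open DB-API connection, exactly as read_frame accepted
# before the rename; only the function name and extra keywords are new here.
conn = sqlite3.connect(':memory:')
conn.execute("CREATE TABLE demo (idx INTEGER, val REAL)")
conn.execute("INSERT INTO demo VALUES (0, 1.5)")
conn.commit()

frame = sql.read_sql("SELECT * FROM demo", con=conn, index_col='idx')

# The new keywords (user, passwd, host, db, flavor) are documented in this
# commit but only become functional -- under the names username/password/
# host/port/database -- in the following commit, so they are not used above.
```

Note that `read_frame` and `frame_query` remain as aliases of `read_sql`, so existing callers are unaffected by the rename.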
@@ -145,9 +147,19 @@ def read_frame(sql, con, index_col=None, coerce_float=True, params=None): ---------- sql: string SQL query to be executed - con: DB connection object, optional + con : Connection object, SQLAlchemy Engine object, or a filepath (sqlite + only). Alternatively, specify a user, passwd, host, and db below. index_col: string, optional column name to use for the returned DataFrame object. + user: username for database authentication + only needed if a Connection, Engine, or filepath are not given + passwd: password for database authentication + only needed if a Connection, Engine, or filepath are not given + host: host for database connection + only needed if a Connection, Engine, or filepath are not given + db: database name + only needed if a Connection, Engine, or filepath are not given + flavor : string specifying the flavor of SQL to use coerce_float : boolean, default True Attempt to convert values to non-string, non-numeric objects (like decimal.Decimal) to floating point, useful for SQL result sets @@ -169,8 +181,8 @@ def read_frame(sql, con, index_col=None, coerce_float=True, params=None): return result -frame_query = read_frame -read_sql = read_frame +frame_query = read_sql +read_frame = read_sql def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): """ From 71797324668b94082c960b95c16c56eb92c401f5 Mon Sep 17 00:00:00 2001 From: Dan Allan Date: Thu, 11 Jul 2013 14:06:21 -0400 Subject: [PATCH 3/4] read_sql connects via Connection, Engine, file path, or :memory: string --- pandas/io/sql.py | 114 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 96 insertions(+), 18 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 7bd334fb58d94..7d3f07c520d5c 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -7,6 +7,9 @@ import numpy as np import traceback +import sqlite3 +import warnings + from pandas.core.datetools import format as date_format from pandas.core.api import DataFrame, isnull @@ -132,10 +135,81 @@ def uquery(sql, con=None, cur=None, retry=True, params=None): return uquery(sql, con, retry=False) return result +class SQLAlchemyRequired(Exception): + pass -def read_sql(sql, con=None, index_col=None, - user=None, passwd=None, host=None, db=None, flavor=None, - coerce_float=True, params=None): +def get_connection(con, dialect, driver, username, password, + host, port, database): + if isinstance(con, basestring): + try: + import sqlalchemy + return _alchemy_connect_sqlite(con) + except: + return sqlite3.connect(con) + if isinstance(con, sqlite3.Connection): + return con + # If we reach here, SQLAlchemy will be needed. + try: + import sqlalchemy + except ImportError: + raise SQLAlchemyRequired + if isinstance(con, sqlalchemy.engine.Engine): + return con.connect() + if isinstance(con, sqlalchemy.engine.Connection): + return con + if con is None: + url_params = (dialect, driver, username, \ + password, host, port, database) + url = _build_url(*url_params) + engine = sqlalchemy.create_engine(url) + return engine.connect() + if hasattr(con, 'cursor') and callable(con.cursor): + # This looks like some Connection object from a driver module. + try: + import MySQLdb + warnings.warn("For more robust support, connect using " \ + "SQLAlchemy. See documentation.") + return conn.cursor() # behaves like a sqlalchemy Connection + except ImportError: + pass + raise NotImplementedError, \ + """To ensure robust support of varied SQL dialects, pandas + only support database connections from SQLAlchemy. 
See + documentation.""" + else: + raise ValueError, \ + """con must be a string, a Connection to a sqlite Database, + or a SQLAlchemy Connection or Engine object.""" + + +def _alchemy_connect_sqlite(path): + if path == ':memory:': + return create_engine('sqlite://').connect() + else: + return create_engine('sqlite:///%s' % path).connect() + +def _build_url(dialect, driver, username, password, host, port, database): + # Create an Engine and from that a Connection. + # We use a string instead of sqlalchemy.engine.url.URL because + # we do not necessarily know the driver; we know the dialect. + required_params = [dialect, username, password, host, database] + for p in required_params: + if not isinstance(p, basestring): + raise ValueError, \ + "Insufficient information to connect to a database;" \ + "see docstring." + url = dialect + if driver is not None: + url += "+%s" % driver + url += "://%s:%s@%s" % (username, password, host) + if port is not None: + url += ":%d" % port + url += "/%s" % database + return url + +def read_sql(sql, con=None, index_col=None, flavor=None, driver=None, + username=None, password=None, host=None, port=None, + database=None, coerce_float=True, params=None): """ Returns a DataFrame corresponding to the result set of the query string. @@ -147,34 +221,38 @@ def read_sql(sql, con=None, index_col=None, ---------- sql: string SQL query to be executed - con : Connection object, SQLAlchemy Engine object, or a filepath (sqlite - only). Alternatively, specify a user, passwd, host, and db below. + con : Connection object, SQLAlchemy Engine object, a filepath string + (sqlite only) or the string ':memory:' (sqlite only). Alternatively, + specify a user, passwd, host, and db below. index_col: string, optional column name to use for the returned DataFrame object. - user: username for database authentication + flavor : string specifying the flavor of SQL to use + driver : string specifying SQL driver (e.g., MySQLdb), optional + username: username for database authentication only needed if a Connection, Engine, or filepath are not given - passwd: password for database authentication + password: password for database authentication only needed if a Connection, Engine, or filepath are not given host: host for database connection only needed if a Connection, Engine, or filepath are not given - db: database name + database: database name only needed if a Connection, Engine, or filepath are not given - flavor : string specifying the flavor of SQL to use coerce_float : boolean, default True Attempt to convert values to non-string, non-numeric objects (like decimal.Decimal) to floating point, useful for SQL result sets params: list or tuple, optional List of parameters to pass to execute method. 
""" - cur = execute(sql, con, params=params) - rows = _safe_fetch(cur) - columns = [col_desc[0] for col_desc in cur.description] - - cur.close() - con.commit() - - result = DataFrame.from_records(rows, columns=columns, - coerce_float=coerce_float) + dialect = flavor + connection = get_connection(con, dialect, driver, username, password, + host, port, database) + if params is None: + params = [] + cursor = connection.execute(sql, *params) + result = _safe_fetch(cursor) + columns = [col_desc[0] for col_desc in cursor.description] + cursor.close() + + result = DataFrame.from_records(result, columns=columns) if index_col is not None: result = result.set_index(index_col) From 4d3774757df53fb67b8bf040a65e848728c14616 Mon Sep 17 00:00:00 2001 From: Dan Allan Date: Mon, 22 Jul 2013 17:48:15 -0400 Subject: [PATCH 4/4] Separate legacy code into new file, and fallback so that all old tests pass. --- pandas/io/sql.py | 33 +- pandas/io/sql_legacy.py | 325 ++++++++++++++++++ ...{test_legacy_sql.py => test_sql_legacy.py} | 4 +- 3 files changed, 350 insertions(+), 12 deletions(-) create mode 100644 pandas/io/sql_legacy.py rename pandas/io/tests/{test_legacy_sql.py => test_sql_legacy.py} (99%) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 7d3f07c520d5c..0673910e3bdde 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -12,6 +12,7 @@ from pandas.core.datetools import format as date_format from pandas.core.api import DataFrame, isnull +from pandas.io import sql_legacy #------------------------------------------------------------------------------ # Helper execution function @@ -138,6 +139,9 @@ def uquery(sql, con=None, cur=None, retry=True, params=None): class SQLAlchemyRequired(Exception): pass +class LegacyMySQLConnection(Exception): + pass + def get_connection(con, dialect, driver, username, password, host, port, database): if isinstance(con, basestring): @@ -148,6 +152,14 @@ def get_connection(con, dialect, driver, username, password, return sqlite3.connect(con) if isinstance(con, sqlite3.Connection): return con + try: + import MySQLdb + except ImportError: + # If we don't have MySQLdb, this can't be a MySQLdb connection. + pass + else: + if isinstance(con, MySQLdb.connection): + raise LegacyMySQLConnection # If we reach here, SQLAlchemy will be needed. try: import sqlalchemy @@ -165,17 +177,10 @@ def get_connection(con, dialect, driver, username, password, return engine.connect() if hasattr(con, 'cursor') and callable(con.cursor): # This looks like some Connection object from a driver module. - try: - import MySQLdb - warnings.warn("For more robust support, connect using " \ - "SQLAlchemy. See documentation.") - return conn.cursor() # behaves like a sqlalchemy Connection - except ImportError: - pass raise NotImplementedError, \ """To ensure robust support of varied SQL dialects, pandas - only support database connections from SQLAlchemy. See - documentation.""" + only supports database connections from SQLAlchemy. (Legacy + support for MySQLdb connections are available but buggy.)""" else: raise ValueError, \ """con must be a string, a Connection to a sqlite Database, @@ -243,8 +248,14 @@ def read_sql(sql, con=None, index_col=None, flavor=None, driver=None, List of parameters to pass to execute method. 
""" dialect = flavor - connection = get_connection(con, dialect, driver, username, password, - host, port, database) + try: + connection = get_connection(con, dialect, driver, username, password, + host, port, database) + except LegacyMySQLConnection: + warnings.warn("For more robust support, connect using " \ + "SQLAlchemy. See documentation.") + return sql_legacy.read_frame(sql, con, index_col, coerce_float, params) + if params is None: params = [] cursor = connection.execute(sql, *params) diff --git a/pandas/io/sql_legacy.py b/pandas/io/sql_legacy.py new file mode 100644 index 0000000000000..11b139b620175 --- /dev/null +++ b/pandas/io/sql_legacy.py @@ -0,0 +1,325 @@ +""" +Collection of query wrappers / abstractions to both facilitate data +retrieval and to reduce dependency on DB-specific API. +""" +from datetime import datetime, date + +import numpy as np +import traceback + +from pandas.core.datetools import format as date_format +from pandas.core.api import DataFrame, isnull + +#------------------------------------------------------------------------------ +# Helper execution function + + +def execute(sql, con, retry=True, cur=None, params=None): + """ + Execute the given SQL query using the provided connection object. + + Parameters + ---------- + sql: string + Query to be executed + con: database connection instance + Database connection. Must implement PEP249 (Database API v2.0). + retry: bool + Not currently implemented + cur: database cursor, optional + Must implement PEP249 (Datbase API v2.0). If cursor is not provided, + one will be obtained from the database connection. + params: list or tuple, optional + List of parameters to pass to execute method. + + Returns + ------- + Cursor object + """ + try: + if cur is None: + cur = con.cursor() + + if params is None: + cur.execute(sql) + else: + cur.execute(sql, params) + return cur + except Exception: + try: + con.rollback() + except Exception: # pragma: no cover + pass + + print ('Error on sql %s' % sql) + raise + + +def _safe_fetch(cur): + try: + result = cur.fetchall() + if not isinstance(result, list): + result = list(result) + return result + except Exception, e: # pragma: no cover + excName = e.__class__.__name__ + if excName == 'OperationalError': + return [] + + +def tquery(sql, con=None, cur=None, retry=True): + """ + Returns list of tuples corresponding to each row in given sql + query. + + If only one column selected, then plain list is returned. + + Parameters + ---------- + sql: string + SQL query to be executed + con: SQLConnection or DB API 2.0-compliant connection + cur: DB API 2.0 cursor + + Provide a specific connection or a specific cursor if you are executing a + lot of sequential statements and want to commit outside. + """ + cur = execute(sql, con, cur=cur) + result = _safe_fetch(cur) + + if con is not None: + try: + cur.close() + con.commit() + except Exception, e: + excName = e.__class__.__name__ + if excName == 'OperationalError': # pragma: no cover + print ('Failed to commit, may need to restart interpreter') + else: + raise + + traceback.print_exc() + if retry: + return tquery(sql, con=con, retry=False) + + if result and len(result[0]) == 1: + # python 3 compat + result = list(list(zip(*result))[0]) + elif result is None: # pragma: no cover + result = [] + + return result + + +def uquery(sql, con=None, cur=None, retry=True, params=None): + """ + Does the same thing as tquery, but instead of returning results, it + returns the number of rows affected. Good for update queries. 
+ """ + cur = execute(sql, con, cur=cur, retry=retry, params=params) + + result = cur.rowcount + try: + con.commit() + except Exception, e: + excName = e.__class__.__name__ + if excName != 'OperationalError': + raise + + traceback.print_exc() + if retry: + print ('Looks like your connection failed, reconnecting...') + return uquery(sql, con, retry=False) + return result + + +def read_frame(sql, con, index_col=None, coerce_float=True, params=None): + """ + Returns a DataFrame corresponding to the result set of the query + string. + + Optionally provide an index_col parameter to use one of the + columns as the index. Otherwise will be 0 to len(results) - 1. + + Parameters + ---------- + sql: string + SQL query to be executed + con: DB connection object, optional + index_col: string, optional + column name to use for the returned DataFrame object. + coerce_float : boolean, default True + Attempt to convert values to non-string, non-numeric objects (like + decimal.Decimal) to floating point, useful for SQL result sets + params: list or tuple, optional + List of parameters to pass to execute method. + """ + cur = execute(sql, con, params=params) + rows = _safe_fetch(cur) + columns = [col_desc[0] for col_desc in cur.description] + + cur.close() + con.commit() + + result = DataFrame.from_records(rows, columns=columns, + coerce_float=coerce_float) + + if index_col is not None: + result = result.set_index(index_col) + + return result + +frame_query = read_frame +read_sql = read_frame + +def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): + """ + Write records stored in a DataFrame to a SQL database. + + Parameters + ---------- + frame: DataFrame + name: name of SQL table + con: an open SQL database connection object + flavor: {'sqlite', 'mysql', 'oracle'}, default 'sqlite' + if_exists: {'fail', 'replace', 'append'}, default 'fail' + fail: If table exists, do nothing. + replace: If table exists, drop it, recreate it, and insert data. + append: If table exists, insert data. Create if does not exist. + """ + + if 'append' in kwargs: + import warnings + warnings.warn("append is deprecated, use if_exists instead", + FutureWarning) + if kwargs['append']: + if_exists='append' + else: + if_exists='fail' + exists = table_exists(name, con, flavor) + if if_exists == 'fail' and exists: + raise ValueError, "Table '%s' already exists." % name + + #create or drop-recreate if necessary + create = None + if exists and if_exists == 'replace': + create = "DROP TABLE %s" % name + elif not exists: + create = get_schema(frame, name, flavor) + + if create is not None: + cur = con.cursor() + cur.execute(create) + cur.close() + + cur = con.cursor() + # Replace spaces in DataFrame column names with _. 
+ safe_names = [s.replace(' ', '_').strip() for s in frame.columns] + flavor_picker = {'sqlite' : _write_sqlite, + 'mysql' : _write_mysql} + + func = flavor_picker.get(flavor, None) + if func is None: + raise NotImplementedError + func(frame, name, safe_names, cur) + cur.close() + con.commit() + +def _write_sqlite(frame, table, names, cur): + bracketed_names = ['[' + column + ']' for column in names] + col_names = ','.join(bracketed_names) + wildcards = ','.join(['?'] * len(names)) + insert_query = 'INSERT INTO %s (%s) VALUES (%s)' % ( + table, col_names, wildcards) + # pandas types are badly handled if there is only 1 column ( Issue #3628 ) + if not len(frame.columns )==1 : + data = [tuple(x) for x in frame.values] + else : + data = [tuple(x) for x in frame.values.tolist()] + cur.executemany(insert_query, data) + +def _write_mysql(frame, table, names, cur): + bracketed_names = ['`' + column + '`' for column in names] + col_names = ','.join(bracketed_names) + wildcards = ','.join([r'%s'] * len(names)) + insert_query = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, col_names, wildcards) + data = [tuple(x) for x in frame.values] + cur.executemany(insert_query, data) + +def table_exists(name, con, flavor): + flavor_map = { + 'sqlite': ("SELECT name FROM sqlite_master " + "WHERE type='table' AND name='%s';") % name, + 'mysql' : "SHOW TABLES LIKE '%s'" % name} + query = flavor_map.get(flavor, None) + if query is None: + raise NotImplementedError + return len(tquery(query, con)) > 0 + +def get_sqltype(pytype, flavor): + sqltype = {'mysql': 'VARCHAR (63)', + 'sqlite': 'TEXT'} + + if issubclass(pytype, np.floating): + sqltype['mysql'] = 'FLOAT' + sqltype['sqlite'] = 'REAL' + + if issubclass(pytype, np.integer): + #TODO: Refine integer size. + sqltype['mysql'] = 'BIGINT' + sqltype['sqlite'] = 'INTEGER' + + if issubclass(pytype, np.datetime64) or pytype is datetime: + # Caution: np.datetime64 is also a subclass of np.number. + sqltype['mysql'] = 'DATETIME' + sqltype['sqlite'] = 'TIMESTAMP' + + if pytype is datetime.date: + sqltype['mysql'] = 'DATE' + sqltype['sqlite'] = 'TIMESTAMP' + + if issubclass(pytype, np.bool_): + sqltype['sqlite'] = 'INTEGER' + + return sqltype[flavor] + +def get_schema(frame, name, flavor, keys=None): + "Return a CREATE TABLE statement to suit the contents of a DataFrame." + lookup_type = lambda dtype: get_sqltype(dtype.type, flavor) + # Replace spaces in DataFrame column names with _. + safe_columns = [s.replace(' ', '_').strip() for s in frame.dtypes.index] + column_types = zip(safe_columns, map(lookup_type, frame.dtypes)) + if flavor == 'sqlite': + columns = ',\n '.join('[%s] %s' % x for x in column_types) + else: + columns = ',\n '.join('`%s` %s' % x for x in column_types) + + keystr = '' + if keys is not None: + if isinstance(keys, basestring): + keys = (keys,) + keystr = ', PRIMARY KEY (%s)' % ','.join(keys) + template = """CREATE TABLE %(name)s ( + %(columns)s + %(keystr)s + );""" + create_statement = template % {'name': name, 'columns': columns, + 'keystr': keystr} + return create_statement + +def sequence2dict(seq): + """Helper function for cx_Oracle. + + For each element in the sequence, creates a dictionary item equal + to the element and keyed by the position of the item in the list. 
+ >>> sequence2dict(("Matt", 1)) + {'1': 'Matt', '2': 1} + + Source: + http://www.gingerandjohn.com/archives/2004/02/26/cx_oracle-executemany-example/ + """ + d = {} + for k,v in zip(range(1, 1 + len(seq)), seq): + d[str(k)] = v + return d diff --git a/pandas/io/tests/test_legacy_sql.py b/pandas/io/tests/test_sql_legacy.py similarity index 99% rename from pandas/io/tests/test_legacy_sql.py rename to pandas/io/tests/test_sql_legacy.py index 5b23bf173ec4e..19cdd9b26cb54 100644 --- a/pandas/io/tests/test_legacy_sql.py +++ b/pandas/io/tests/test_sql_legacy.py @@ -272,7 +272,9 @@ def setUp(self): def test_basic(self): _skip_if_no_MySQLdb() frame = tm.makeTimeDataFrame() - self._check_roundtrip(frame) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "For more robust support.*") + self._check_roundtrip(frame) def test_write_row_by_row(self): _skip_if_no_MySQLdb()
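Taken together, the last two patches let `read_sql` accept a filepath or `':memory:'` string, an open `sqlite3.Connection`, a SQLAlchemy `Engine` or `Connection`, or, via the fallback, a legacy MySQLdb connection. The following is a rough usage sketch against the state of `pandas.io.sql` at the end of this series, not part of the patch itself; the table name `t` and the in-memory databases are illustrative, and the SQLAlchemy branch assumes sqlalchemy 0.8 is importable.

```python
import sqlite3
import pandas.io.sql as sql

# A filepath or ':memory:' string: get_connection() first tries to build a
# SQLAlchemy sqlite engine and silently falls back to the stdlib sqlite3
# module if that fails.
mem = sql.read_sql("SELECT 1 AS one", con=':memory:')

# An already-open sqlite3.Connection is passed through untouched.
conn = sqlite3.connect(':memory:')
conn.execute("CREATE TABLE t (a INTEGER, b REAL)")
conn.execute("INSERT INTO t VALUES (1, 2.5)")
conn.commit()
roundtrip = sql.read_sql("SELECT * FROM t", con=conn, index_col='a')

# A SQLAlchemy Engine is unwrapped with engine.connect(); a SQLAlchemy
# Connection is used directly.  Shown here via get_connection, whose eight
# positional arguments are (con, dialect, driver, username, password,
# host, port, database).
try:
    import sqlalchemy
    engine = sqlalchemy.create_engine('sqlite://')
    alchemy_conn = sql.get_connection(engine, None, None, None, None,
                                      None, None, None)
except ImportError:
    pass

# A legacy MySQLdb connection makes get_connection() raise
# LegacyMySQLConnection; read_sql catches it, warns, and delegates to
# pandas.io.sql_legacy.read_frame so the old tests keep passing.
```

Keeping the dispatch inside `get_connection` leaves `read_sql` itself flavor-agnostic, while `sql_legacy` preserves the pre-SQLAlchemy behavior behind a warning.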