diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 3d0acb09ca0cc..5acacdb110958 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -5,7 +5,7 @@ from __future__ import print_function from datetime import datetime, date -from pandas.compat import range, lzip, map, zip +from pandas.compat import range, lzip, map, zip, raise_with_traceback import pandas.compat as compat import numpy as np import traceback @@ -18,11 +18,22 @@ from pandas.core.api import DataFrame, isnull from pandas.io import sql_legacy + + +class SQLAlchemyRequired(ImportError): + pass + +class LegacyMySQLConnection(Exception): + pass + +class DatabaseError(IOError): + pass + + #------------------------------------------------------------------------------ # Helper execution function - -def execute(sql, con, retry=True, cur=None, params=None): +def execute(sql, con=None, retry=True, cur=None, params=None, engine=None): """ Execute the given SQL query using the provided connection object. @@ -44,6 +55,13 @@ def execute(sql, con, retry=True, cur=None, params=None): ------- Cursor object """ + if engine is not None: + try: + return engine.execute(sql, params=params) + except Exception as e: + ex = DatabaseError("Execution failed with: %s" % e) + raise_with_traceback(ex) + try: if cur is None: cur = con.cursor() @@ -53,17 +71,18 @@ def execute(sql, con, retry=True, cur=None, params=None): else: cur.execute(sql, params) return cur - except Exception: + except Exception as e: try: con.rollback() except Exception: # pragma: no cover - pass + ex = DatabaseError("Execution failed on sql: %s\n%s\nunable to rollback" % (sql, e)) + raise_with_traceback(ex) - print('Error on sql %s' % sql) - raise + ex = DatabaseError("Execution failed on sql: %s" % sql) + raise_with_traceback(ex) - -def _safe_fetch(cur): +def _safe_fetch(cur=None): + '''ensures result of fetchall is a list''' try: result = cur.fetchall() if not isinstance(result, list): @@ -74,8 +93,7 @@ def _safe_fetch(cur): if excName == 'OperationalError': return [] - -def tquery(sql, con=None, cur=None, retry=True): +def tquery(sql, con=None, retry=True, cur=None, engine=None, params=None): """ Returns list of tuples corresponding to each row in given sql query. @@ -87,12 +105,32 @@ def tquery(sql, con=None, cur=None, retry=True): sql: string SQL query to be executed con: SQLConnection or DB API 2.0-compliant connection + retry : bool cur: DB API 2.0 cursor Provide a specific connection or a specific cursor if you are executing a lot of sequential statements and want to commit outside. """ - cur = execute(sql, con, cur=cur) + if params is None: + params = [] + if engine: + result = execute(sql, *params, engine=engine) + return result.fetchall() # is this tuples? + else: + result = _cur_tquery(sql, con=con, retry=retry, cur=cur, params=params) + + # This makes into tuples? 
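+    # e.g. a one-column result such as [(1,), (2,)] is flattened below into
+    # [1, 2]: lzip(*result) gives [(1, 2)] and [0] picks out that single tuple.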
+ if result and len(result[0]) == 1: + # python 3 compat + result = list(lzip(*result)[0]) + elif result is None: # pragma: no cover + result = [] + return result + + +def _cur_tquery(sql, con=None, retry=True, cur=None, engine=None, params=None): + + cur = execute(sql, con, cur=cur, params=params) result = _safe_fetch(cur) if con is not None: @@ -110,23 +148,28 @@ def tquery(sql, con=None, cur=None, retry=True): if retry: return tquery(sql, con=con, retry=False) - if result and len(result[0]) == 1: - # python 3 compat - result = list(lzip(*result)[0]) - elif result is None: # pragma: no cover - result = [] - return result -def uquery(sql, con=None, cur=None, retry=True, params=None): +def uquery(sql, con=None, cur=None, retry=True, params=None, engine=None): """ Does the same thing as tquery, but instead of returning results, it returns the number of rows affected. Good for update queries. """ - cur = execute(sql, con, cur=cur, retry=retry, params=params) + if params is None: + params = [] + + if engine: + result = execute(sql, *params, engine=engine) + return result.rowcount + + else: + return _cur_uquery(sql, con=con, cur=cur, retry=retry, params=params) - result = cur.rowcount + +def _cur_uquery(sql, con=None, cur=None, retry=True, params=None, engine=None): + cur = execute(sql, con, cur=cur, retry=retry, params=params) + row_count = cur.rowcount try: con.commit() except Exception as e: @@ -138,13 +181,8 @@ def uquery(sql, con=None, cur=None, retry=True, params=None): if retry: print ('Looks like your connection failed, reconnecting...') return uquery(sql, con, retry=False) - return result + return row_count -class SQLAlchemyRequired(Exception): - pass - -class LegacyMySQLConnection(Exception): - pass def get_connection(con, dialect, driver, username, password, host, port, database): @@ -181,14 +219,14 @@ def get_connection(con, dialect, driver, username, password, return engine.connect() if hasattr(con, 'cursor') and callable(con.cursor): # This looks like some Connection object from a driver module. - raise NotImplementedError, \ + raise NotImplementedError( """To ensure robust support of varied SQL dialects, pandas - only supports database connections from SQLAlchemy. (Legacy - support for MySQLdb connections are available but buggy.)""" + only support database connections from SQLAlchemy. See + documentation.""") else: - raise ValueError, \ + raise ValueError( """con must be a string, a Connection to a sqlite Database, - or a SQLAlchemy Connection or Engine object.""" + or a SQLAlchemy Connection or Engine object.""") def _alchemy_connect_sqlite(path): @@ -204,9 +242,7 @@ def _build_url(dialect, driver, username, password, host, port, database): required_params = [dialect, username, password, host, database] for p in required_params: if not isinstance(p, basestring): - raise ValueError, \ - "Insufficient information to connect to a database;" \ - "see docstring." + raise ValueError("Insufficient information to connect to a database; see docstring.") url = dialect if driver is not None: url += "+%s" % driver @@ -218,7 +254,7 @@ def _build_url(dialect, driver, username, password, host, port, database): def read_sql(sql, con=None, index_col=None, flavor=None, driver=None, username=None, password=None, host=None, port=None, - database=None, coerce_float=True, params=None): + database=None, coerce_float=True, params=None, engine=None): """ Returns a DataFrame corresponding to the result set of the query string. 
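For orientation, a minimal sketch of how the engine-backed path added to read_sql/read_frame and write_frame might be exercised; the in-memory SQLite engine and the DataFrame `df` are illustrative assumptions that mirror the tests further down:

    import sqlalchemy
    import pandas.io.sql as sql

    engine = sqlalchemy.create_engine('sqlite:///:memory:')       # any SQLAlchemy engine
    sql.write_frame(df, 'test', engine=engine)                     # df: an existing DataFrame
    frame = sql.read_frame('SELECT * FROM test', engine=engine)    # round-trips via the engine
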
@@ -250,24 +286,33 @@ def read_sql(sql, con=None, index_col=None, flavor=None, driver=None, decimal.Decimal) to floating point, useful for SQL result sets params: list or tuple, optional List of parameters to pass to execute method. - """ - dialect = flavor - try: - connection = get_connection(con, dialect, driver, username, password, - host, port, database) - except LegacyMySQLConnection: - warnings.warn("For more robust support, connect using " \ - "SQLAlchemy. See documentation.") - return sql_legacy.read_frame(sql, con, index_col, coerce_float, params) + engine : SQLAlchemy engine, optional + """ if params is None: params = [] - cursor = connection.execute(sql, *params) - result = _safe_fetch(cursor) - columns = [col_desc[0] for col_desc in cursor.description] - cursor.close() - result = DataFrame.from_records(result, columns=columns) + if engine: + result = engine.execute(sql, *params) + data = result.fetchall() + columns = result.keys() + + else: + dialect = flavor + try: + connection = get_connection(con, dialect, driver, username, password, + host, port, database) + except LegacyMySQLConnection: + warnings.warn("For more robust support, connect using " \ + "SQLAlchemy. See documentation.") + return sql_legacy.read_frame(sql, con, index_col, coerce_float, params) + + cursor = connection.execute(sql, *params) + data = _safe_fetch(cursor) + columns = [col_desc[0] for col_desc in cursor.description] + cursor.close() + + result = DataFrame.from_records(data, columns=columns) if index_col is not None: result = result.set_index(index_col) @@ -278,7 +323,7 @@ def read_sql(sql, con=None, index_col=None, flavor=None, driver=None, read_frame = read_sql -def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): +def write_frame(frame, name, con=None, flavor='sqlite', if_exists='fail', engine=None, **kwargs): """ Write records stored in a DataFrame to a SQL database. @@ -301,70 +346,108 @@ def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): if kwargs['append']: if_exists = 'append' else: - if_exists = 'fail' - exists = table_exists(name, con, flavor) - if if_exists == 'fail' and exists: - raise ValueError("Table '%s' already exists." % name) - - #create or drop-recreate if necessary - create = None - if exists and if_exists == 'replace': - create = "DROP TABLE %s" % name - elif not exists: - create = get_schema(frame, name, flavor) + if_exists='fail' + + if engine: + exists = engine.has_table(name) + else: + exists = table_exists(name, con, flavor) + + create = None #create or drop-recreate if necessary + if exists: + if if_exists == 'fail': + raise ValueError("Table '%s' already exists." % name) + elif if_exists == 'replace': + if engine: + _engine_drop_table(name) + else: + create = "DROP TABLE %s" % name + else: + if engine: + _engine_create_table(frame, name, engine=engine) + else: + create = get_schema(frame, name, flavor) if create is not None: cur = con.cursor() cur.execute(create) cur.close() - cur = con.cursor() - # Replace spaces in DataFrame column names with _. - safe_names = [s.replace(' ', '_').strip() for s in frame.columns] - flavor_picker = {'sqlite': _write_sqlite, - 'mysql': _write_mysql} - - func = flavor_picker.get(flavor, None) - if func is None: - raise NotImplementedError - func(frame, name, safe_names, cur) - cur.close() - con.commit() + if engine: + _engine_write(frame, name, engine) + else: + cur = con.cursor() + # Replace spaces in DataFrame column names with _. 
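+        # e.g. a column named 'Adj Close' becomes 'Adj_Close'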
+ safe_names = [s.replace(' ', '_').strip() for s in frame.columns] + flavor_picker = {'sqlite' : _cur_write_sqlite, + 'mysql' : _cur_write_mysql} + + func = flavor_picker.get(flavor, None) + if func is None: + raise NotImplementedError + func(frame, name, safe_names, cur) + cur.close() + con.commit() -def _write_sqlite(frame, table, names, cur): +def _cur_write_sqlite(frame, table, names, cur): bracketed_names = ['[' + column + ']' for column in names] col_names = ','.join(bracketed_names) wildcards = ','.join(['?'] * len(names)) insert_query = 'INSERT INTO %s (%s) VALUES (%s)' % ( table, col_names, wildcards) # pandas types are badly handled if there is only 1 column ( Issue #3628 ) - if not len(frame.columns) == 1: + if len(frame.columns) != 1: data = [tuple(x) for x in frame.values] else: data = [tuple(x) for x in frame.values.tolist()] cur.executemany(insert_query, data) - -def _write_mysql(frame, table, names, cur): +def _cur_write_mysql(frame, table, names, cur): bracketed_names = ['`' + column + '`' for column in names] col_names = ','.join(bracketed_names) wildcards = ','.join([r'%s'] * len(names)) insert_query = "INSERT INTO %s (%s) VALUES (%s)" % ( table, col_names, wildcards) - data = [tuple(x) for x in frame.values] + # pandas types are badly handled if there is only 1 column ( Issue #3628 ) + if len(frame.columns) != 1: + data = [tuple(x) for x in frame.values] + else: + data = [tuple(x) for x in frame.values.tolist()] cur.executemany(insert_query, data) +def _engine_write(frame, table_name, engine): + table = _engine_get_table(table_name, engine) + ins = table.insert() + # TODO: do this in one pass + # engine.execute(ins, *(t[1:] for t in frame.itertuples())) # t[1:] doesn't include index + # engine.execute(ins, *[tuple(x) for x in frame.values]) + + # TODO this should be done globally first (or work out how to pass np dtypes to sql) + def maybe_asscalar(i): + try: + return np.asscalar(i) + except AttributeError: + return i + + for t in frame.iterrows(): + engine.execute(ins, **dict((k, maybe_asscalar(v)) for k, v in t[1].iteritems())) + # TODO more efficient, I'm *sure* this was just working with tuples + -def table_exists(name, con, flavor): - flavor_map = { - 'sqlite': ("SELECT name FROM sqlite_master " - "WHERE type='table' AND name='%s';") % name, - 'mysql': "SHOW TABLES LIKE '%s'" % name} - query = flavor_map.get(flavor, None) - if query is None: - raise NotImplementedError - return len(tquery(query, con)) > 0 +def table_exists(name, con=None, flavor=None, engine=None): + if engine: + return engine.has_table(name) + + else: + flavor_map = { + 'sqlite': ("SELECT name FROM sqlite_master " + "WHERE type='table' AND name='%s';") % name, + 'mysql' : "SHOW TABLES LIKE '%s'" % name} + query = flavor_map.get(flavor, None) + if query is None: + raise NotImplementedError + return len(tquery(query, con)) > 0 def get_sqltype(pytype, flavor): @@ -435,3 +518,102 @@ def sequence2dict(seq): for k, v in zip(range(1, 1 + len(seq)), seq): d[str(k)] = v return d + + +def _engine_drop_table(table_name, engine): + if engine.has_table(table_name): + table = _engine_get_table(table_name, engine=engine) + table.drop() + +def _engine_lookup_type(dtype): + from sqlalchemy import Table, Column, INT, FLOAT, TEXT, BOOLEAN + + pytype = dtype.type + + if issubclass(pytype, np.floating): + return FLOAT + + if issubclass(pytype, np.integer): + #TODO: Refine integer size. 
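+        # illustrative caveat: the generic sqlalchemy INT maps to a 32-bit
+        # INTEGER on some backends (e.g. MySQL), so int64 values beyond
+        # 2**31 - 1 may need BIGINT once the TODO above is addressed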
+        return INT
+
+    from sqlalchemy import DATETIME, DATE  # needed for the datetime checks below
+
+    if issubclass(pytype, np.datetime64) or pytype is datetime:
+        # Caution: np.datetime64 is also a subclass of np.number.
+        return DATETIME
+
+    if pytype is date:
+        return DATE
+
+    if issubclass(pytype, np.bool_):
+        return BOOLEAN
+
+    return TEXT
+
+def _engine_create_table(frame, table_name, engine, keys=None, meta=None):
+    from sqlalchemy import Table, Column
+    if keys is None:
+        keys = []
+    if not meta:
+        from sqlalchemy.schema import MetaData
+        meta = MetaData(engine)
+        meta.reflect(engine)
+
+    safe_columns = [s.replace(' ', '_').strip() for s in frame.dtypes.index]  # may not be safe enough...
+    column_types = map(_engine_lookup_type, frame.dtypes)
+
+    columns = [(col_name, col_sqltype, col_name in keys)
+               for col_name, col_sqltype in zip(safe_columns, column_types)]
+    # build the Column objects, marking any key columns as primary keys
+    columns = [Column(name, typ, primary_key=pk) for name, typ, pk in columns]
+
+    table = Table(table_name, meta, *columns)
+
+    table.create()
+
+def _engine_get_table(table_name, engine, meta=None):
+    if engine.has_table(table_name):
+        if not meta:
+            from sqlalchemy.schema import MetaData
+            meta = MetaData(engine)
+            meta.reflect(engine)
+        return meta.tables[table_name]
+    else:
+        return None
+
+def _engine_read_sql(sql, engine, params=None, index_col=None):
+
+    if params is None:
+        params = []
+
+    try:
+        result = engine.execute(sql, *params)
+    except Exception as e:
+        ex = DatabaseError("Execution failed on sql: %s\n%s" % (sql, e))
+        raise_with_traceback(ex)
+    data = result.fetchall()
+    columns = result.keys()
+
+    df = DataFrame.from_records(data, columns=columns)
+    if index_col is not None:
+        df.set_index(index_col, inplace=True)
+    return df
+
+def _engine_read_table_name(table_name, engine, meta=None, index_col=None):
+    table = _engine_get_table(table_name, engine=engine, meta=meta)
+
+    if table is not None:
+        sql_select = table.select()
+        return _engine_read_sql(sql_select, engine=engine, index_col=index_col)
+    else:
+        raise ValueError("Table %s not found with %s." % (table_name, engine))
+
+def _engine_write_frame(frame, name, engine, if_exists='fail'):
+
+    exists = engine.has_table(name)
+    if exists:
+        if if_exists == 'fail':
+            raise ValueError("Table '%s' already exists."
% name) + elif if_exists == 'replace': + _engine_drop_table(name) + else: + _engine_create_table(frame, name, engine=engine) + + _engine_write(frame, name, engine) diff --git a/pandas/io/sql_legacy.py b/pandas/io/sql_legacy.py index 11b139b620175..91cb2ec77af08 100644 --- a/pandas/io/sql_legacy.py +++ b/pandas/io/sql_legacy.py @@ -91,7 +91,7 @@ def tquery(sql, con=None, cur=None, retry=True): try: cur.close() con.commit() - except Exception, e: + except Exception as e: excName = e.__class__.__name__ if excName == 'OperationalError': # pragma: no cover print ('Failed to commit, may need to restart interpreter') @@ -121,7 +121,7 @@ def uquery(sql, con=None, cur=None, retry=True, params=None): result = cur.rowcount try: con.commit() - except Exception, e: + except Exception as e: excName = e.__class__.__name__ if excName != 'OperationalError': raise diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py index c3461f1df8de5..dc3815d68e2f3 100644 --- a/pandas/io/tests/test_sql.py +++ b/pandas/io/tests/test_sql.py @@ -14,17 +14,265 @@ from pandas.compat import StringIO, range, lrange import pandas.compat as compat + import pandas.io.sql as sql +from pandas.io.sql import DatabaseError import pandas.util.testing as tm -from pandas import Series, Index, DataFrame +from pandas import Series, Index, DataFrame, isnull from datetime import datetime import sqlalchemy +import sqlite3 # try to import other db modules in their test classes + +from sqlalchemy import Table, Column, INT, FLOAT, TEXT + + +class TestSQLAlchemy(unittest.TestCase): + + def set_flavor_engine(self): + # override for other db modules + self.engine = sqlalchemy.create_engine('sqlite:///:memory:') + + def setUp(self): + # this is overriden for other db modules + self.set_flavor_engine() + + # shared for all db modules + self.meta = sqlalchemy.schema.MetaData(self.engine) + self.drop_table('test') # should already be done ? + self.meta.reflect(self.engine) # not sure if this is different + + self.frame = tm.makeTimeDataFrame() + + def drop_table(self, table_name): + sql._engine_drop_table(table_name, engine=self.engine) + + def create_table(self, frame, table_name, keys=None): + return sql._engine_create_table(frame, table_name, keys=None, engine=self.engine) + + def get_table(self, table_name): + return sql._engine_get_table(table_name, self.engine) + + def tquery(self, fmt_sql, params=None, retry=False): + sql.tquery(fmt_sql, engine=self.engine, params=params, retry=retry) + + def read_frame(self, fmt_sql=None): + return sql.read_frame(fmt_sql, engine=self.engine) + + def _check_roundtrip(self, frame): + self.drop_table('test') + sql._engine_write_frame(self.frame, 'test', self.engine) + result = sql._engine_read_table_name('test', engine=self.engine) + + # HACK! + result.index = self.frame.index + + tm.assert_frame_equal(result, self.frame) + + self.frame['txt'] = ['a'] * len(self.frame) + frame2 = self.frame.copy() + frame2['Idx'] = Index(range(len(frame2))) + 10 + + self.drop_table('test_table2') + sql._engine_write_frame(frame2, 'test_table2', self.engine) + result = sql._engine_read_table_name('test_table2', engine=self.engine, index_col='Idx') + + self.assertRaises(DatabaseError, self.tquery, + 'insert into blah values (1)') + + self.assertRaises(DatabaseError, self.tquery, + 'insert into blah values (1)', + retry=True) + + + def test_basic(self): + self._check_roundtrip(self.frame) + + # not sure what intention of this was? 
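+    # presumably the legacy test checked that NaN values survive a write/read
+    # round trip as SQL NULLs; a hypothetical sketch using the engine helpers
+    # (the table name 'test_nan' is made up):
+    #
+    #     frame = self.frame.copy()
+    #     frame.ix[0, 0] = np.nan
+    #     sql._engine_write_frame(frame, 'test_nan', self.engine)
+    #     result = sql._engine_read_table_name('test_nan', engine=self.engine)
+    #     self.assert_(isnull(result.ix[0, 0]))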
+ def test_na_roundtrip(self): + pass + + def test_write_row_by_row(self): + self.frame.ix[0, 0] = np.nan + self.create_table(self.frame, 'test') + + test_table = self.get_table('test') + + ins = test_table.insert() # INSERT INTO test VALUES (%s, %s, %s, %s) + for idx, row in self.frame.iterrows(): + values = tuple(row) + sql.execute(ins.values(values), engine=self.engine) + + select_test = test_table.select() # SELECT * FROM test + + result = self.read_frame(select_test) + + result.index = self.frame.index + tm.assert_frame_equal(result, self.frame) + + def test_execute(self): + # drop_sql = "DROP TABLE IF EXISTS test" # should already be done + self.create_table(self.frame, 'test') + + test_table = self.get_table('test') + + ins = test_table.insert() # INSERT INTO test VALUES (%s, %s, %s, %s) + + row = self.frame.ix[0] + self.engine.execute(ins, **row) + + select_test = test_table.select() # SELECT * FROM test + result = self.read_frame(select_test) + result.index = self.frame.index[:1] + tm.assert_frame_equal(result, self.frame[:1]) + + def test_execute_fail(self): + """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + from sqlalchemy import Table, Column, TEXT, REAL + test_table = Table('test', self.meta, + Column('a', TEXT), Column('b', TEXT), Column('c', REAL)) + test_table.create() + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', engine=self.engine) + sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', engine=self.engine) + + self.assertRaises(DatabaseError, sql.execute, + 'INSERT INTO test VALUES("foo", "bar", 7)', + self.engine) + + def test_tquery(self): + self.drop_table('test_table') + sql._engine_write_frame(self.frame, 'test_table', self.engine) + result = sql.tquery("select A from test_table", engine=self.engine) + expected = self.frame.A + result = DataFrame(result, self.frame.index, columns=['A'])['A'] + tm.assert_series_equal(result, expected) + + self.assertRaises(DatabaseError, sql.tquery, + 'select * from blah', engine=self.engine) + + self.assertRaises(DatabaseError, sql.tquery, + 'select * from blah', con=self.engine, retry=True) + + def test_uquery(self): + self.drop_table('test_table') + sql._engine_write_frame(self.frame, 'test_table', self.engine) + + ins = sql._engine_get_table('test_table', self.engine).insert() + params = (2.314, -123.1, 1.234, 2.3) + self.assertEqual(sql.uquery(ins, params, engine=self.engine), 1) + + self.assertRaises(DatabaseError, sql.uquery, + 'insert into blah values (1)', engine=self.engine) + + self.assertRaises(DatabaseError, sql.tquery, + 'insert into blah values (1)', engine=self.engine, retry=True) + + + def test_onecolumn_of_integer(self): + 'GH 3628, a column_of_integers dataframe should transfer well to sql' + mono_df = DataFrame([1 , 2], columns=['c0']) + sql._engine_write_frame(mono_df, 'mono_df', self.engine) + # computing the sum via sql + select = sql._engine_get_table('mono_df', self.engine).select() + the_sum = sum([my_c0[0] for my_c0 in self.engine.execute(select)]) + # it should not fail, and gives 3 ( Issue #3628 ) + self.assertEqual(the_sum , 3) + + result = sql._engine_read_table_name('mono_df', engine=self.engine) + tm.assert_frame_equal(result, mono_df) + + def test_keyword_as_column_names(self): + df = DataFrame({'From':np.ones(5)}) + sql.write_frame(df, engine=self.engine, name='testkeywords', + if_exists='replace', flavor='mysql') + + + # Not needed with engines, but add into con/cur tests later + + # def test_execute_closed_connection(self): + # 
create_sql = """ + # CREATE TABLE test + # ( + # a TEXT, + # b TEXT, + # c REAL, + # PRIMARY KEY (a, b) + # ); + # """ + # cur = self.db.cursor() + # cur.execute(create_sql) + + # sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + # self.db.close() + # try: + # sys.stdout = StringIO() + # self.assertRaises(Exception, sql.tquery, "select * from test", + # con=self.db) + # finally: + # sys.stdout = sys.__stdout__ + + # def test_schema(self): + # create_sql = self.create_table(self.frame, 'test')[1] + # lines = create_sql.splitlines() + # for l in lines: + # tokens = l.split(' ') + # if len(tokens) == 2 and tokens[0] == 'A': + # self.assert_(tokens[1] == 'DATETIME') + # self.drop_table('test') + # create_sql = self.create_table(frame, 'test', keys=['A', 'B'])[1] + # self.assert_('PRIMARY KEY (A,B)' in create_sql) + + +class TestSQLA_pymysql(TestSQLAlchemy): + def set_flavor_engine(self): + # if can't import should skip all tests + try: + import pymysql + except ImportError: + raise nose.SkipTest("pymysql was not installed") + + try: + self.engine = sqlalchemy.create_engine("mysql+pymysql://root:@localhost/pandas_nosetest") + except pymysql.Error, e: + raise nose.SkipTest( + "Cannot connect to database. " + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") + except pymysql.ProgrammingError, e: + raise nose.SkipTest( + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") + +class TestSQLA_MySQLdb(TestSQLAlchemy): + def set_flavor_engine(self): + # if can't import should skip all tests + try: + import MySQLdb + except ImportError: + raise nose.SkipTest("MySQLdb was not installed") -if __name__ == '__main__': - # unittest.main() - # nose.runmodule(argv=[__file__,'-vvs','-x', '--pdb-failure'], - # exit=False) - nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], - exit=False) + try: + self.engine = sqlalchemy.create_engine("mysql+mysqldb://root:@localhost/pandas_nosetest") + except MySQLdb.Error: + raise nose.SkipTest( + "Cannot connect to database. " + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") + except MySQLdb.ProgrammingError: + raise nose.SkipTest( + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. 
") \ No newline at end of file diff --git a/pandas/io/tests/test_sql_legacy.py b/pandas/io/tests/test_sql_legacy.py index 69620146c22cd..3c6e992097d30 100644 --- a/pandas/io/tests/test_sql_legacy.py +++ b/pandas/io/tests/test_sql_legacy.py @@ -1,5 +1,5 @@ from __future__ import with_statement -from pandas.util.py3compat import StringIO +from pandas.compat import StringIO import unittest import sqlite3 import sys @@ -12,8 +12,11 @@ from pandas.core.datetools import format as date_format from pandas.core.api import DataFrame, isnull +from pandas.compat import StringIO, range, lrange +import pandas.compat as compat import pandas.io.sql as sql +from pandas.io.sql import DatabaseError import pandas.util.testing as tm from pandas import Series, Index, DataFrame from datetime import datetime @@ -193,10 +196,10 @@ def test_tquery(self): try: sys.stdout = StringIO() - self.assertRaises(sqlite3.OperationalError, sql.tquery, + self.assertRaises(DatabaseError, sql.tquery, 'select * from blah', con=self.db) - self.assertRaises(sqlite3.OperationalError, sql.tquery, + self.assertRaises(DatabaseError, sql.tquery, 'select * from blah', con=self.db, retry=True) finally: sys.stdout = sys.__stdout__ @@ -210,10 +213,10 @@ def test_uquery(self): try: sys.stdout = StringIO() - self.assertRaises(sqlite3.OperationalError, sql.tquery, + self.assertRaises(DatabaseError, sql.tquery, 'insert into blah values (1)', con=self.db) - self.assertRaises(sqlite3.OperationalError, sql.tquery, + self.assertRaises(DatabaseError, sql.tquery, 'insert into blah values (1)', con=self.db, retry=True) finally: @@ -445,10 +448,10 @@ def test_tquery(self): try: sys.stdout = StringIO() - self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, + self.assertRaises(DatabaseError, sql.tquery, 'select * from blah', con=self.db) - self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, + self.assertRaises(DatabaseError, sql.tquery, 'select * from blah', con=self.db, retry=True) finally: sys.stdout = sys.__stdout__ @@ -469,10 +472,10 @@ def test_uquery(self): try: sys.stdout = StringIO() - self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, + self.assertRaises(DatabaseError, sql.tquery, 'insert into blah values (1)', con=self.db) - self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, + self.assertRaises(DatabaseError, sql.tquery, 'insert into blah values (1)', con=self.db, retry=True) finally: @@ -483,7 +486,7 @@ def test_keyword_as_column_names(self): ''' _skip_if_no_MySQLdb() df = DataFrame({'From':np.ones(5)}) - sql.write_frame(df, con = self.db, name = 'testkeywords', + sql.write_frame(df, name='testkeywords', con=self.db, if_exists='replace', flavor='mysql') if __name__ == '__main__':