Skip to content

Commit 4d37747

Browse files
committed
Separate legacy code into new file, and fallback so that all old tests pass.
1 parent 7179732 commit 4d37747

File tree

3 files changed

+350
-12
lines changed

3 files changed

+350
-12
lines changed

pandas/io/sql.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from pandas.core.datetools import format as date_format
1414
from pandas.core.api import DataFrame, isnull
15+
from pandas.io import sql_legacy
1516

1617
#------------------------------------------------------------------------------
1718
# Helper execution function
@@ -138,6 +139,9 @@ def uquery(sql, con=None, cur=None, retry=True, params=None):
138139
class SQLAlchemyRequired(Exception):
139140
pass
140141

142+
class LegacyMySQLConnection(Exception):
143+
pass
144+
141145
def get_connection(con, dialect, driver, username, password,
142146
host, port, database):
143147
if isinstance(con, basestring):
@@ -148,6 +152,14 @@ def get_connection(con, dialect, driver, username, password,
148152
return sqlite3.connect(con)
149153
if isinstance(con, sqlite3.Connection):
150154
return con
155+
try:
156+
import MySQLdb
157+
except ImportError:
158+
# If we don't have MySQLdb, this can't be a MySQLdb connection.
159+
pass
160+
else:
161+
if isinstance(con, MySQLdb.connection):
162+
raise LegacyMySQLConnection
151163
# If we reach here, SQLAlchemy will be needed.
152164
try:
153165
import sqlalchemy
@@ -165,17 +177,10 @@ def get_connection(con, dialect, driver, username, password,
165177
return engine.connect()
166178
if hasattr(con, 'cursor') and callable(con.cursor):
167179
# This looks like some Connection object from a driver module.
168-
try:
169-
import MySQLdb
170-
warnings.warn("For more robust support, connect using " \
171-
"SQLAlchemy. See documentation.")
172-
return conn.cursor() # behaves like a sqlalchemy Connection
173-
except ImportError:
174-
pass
175180
raise NotImplementedError, \
176181
"""To ensure robust support of varied SQL dialects, pandas
177-
only support database connections from SQLAlchemy. See
178-
documentation."""
182+
only supports database connections from SQLAlchemy. (Legacy
183+
support for MySQLdb connections are available but buggy.)"""
179184
else:
180185
raise ValueError, \
181186
"""con must be a string, a Connection to a sqlite Database,
@@ -243,8 +248,14 @@ def read_sql(sql, con=None, index_col=None, flavor=None, driver=None,
243248
List of parameters to pass to execute method.
244249
"""
245250
dialect = flavor
246-
connection = get_connection(con, dialect, driver, username, password,
247-
host, port, database)
251+
try:
252+
connection = get_connection(con, dialect, driver, username, password,
253+
host, port, database)
254+
except LegacyMySQLConnection:
255+
warnings.warn("For more robust support, connect using " \
256+
"SQLAlchemy. See documentation.")
257+
return sql_legacy.read_frame(sql, con, index_col, coerce_float, params)
258+
248259
if params is None:
249260
params = []
250261
cursor = connection.execute(sql, *params)

pandas/io/sql_legacy.py

+325
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
"""
2+
Collection of query wrappers / abstractions to both facilitate data
3+
retrieval and to reduce dependency on DB-specific API.
4+
"""
5+
from datetime import datetime, date
6+
7+
import numpy as np
8+
import traceback
9+
10+
from pandas.core.datetools import format as date_format
11+
from pandas.core.api import DataFrame, isnull
12+
13+
#------------------------------------------------------------------------------
14+
# Helper execution function
15+
16+
17+
def execute(sql, con, retry=True, cur=None, params=None):
18+
"""
19+
Execute the given SQL query using the provided connection object.
20+
21+
Parameters
22+
----------
23+
sql: string
24+
Query to be executed
25+
con: database connection instance
26+
Database connection. Must implement PEP249 (Database API v2.0).
27+
retry: bool
28+
Not currently implemented
29+
cur: database cursor, optional
30+
Must implement PEP249 (Datbase API v2.0). If cursor is not provided,
31+
one will be obtained from the database connection.
32+
params: list or tuple, optional
33+
List of parameters to pass to execute method.
34+
35+
Returns
36+
-------
37+
Cursor object
38+
"""
39+
try:
40+
if cur is None:
41+
cur = con.cursor()
42+
43+
if params is None:
44+
cur.execute(sql)
45+
else:
46+
cur.execute(sql, params)
47+
return cur
48+
except Exception:
49+
try:
50+
con.rollback()
51+
except Exception: # pragma: no cover
52+
pass
53+
54+
print ('Error on sql %s' % sql)
55+
raise
56+
57+
58+
def _safe_fetch(cur):
59+
try:
60+
result = cur.fetchall()
61+
if not isinstance(result, list):
62+
result = list(result)
63+
return result
64+
except Exception, e: # pragma: no cover
65+
excName = e.__class__.__name__
66+
if excName == 'OperationalError':
67+
return []
68+
69+
70+
def tquery(sql, con=None, cur=None, retry=True):
71+
"""
72+
Returns list of tuples corresponding to each row in given sql
73+
query.
74+
75+
If only one column selected, then plain list is returned.
76+
77+
Parameters
78+
----------
79+
sql: string
80+
SQL query to be executed
81+
con: SQLConnection or DB API 2.0-compliant connection
82+
cur: DB API 2.0 cursor
83+
84+
Provide a specific connection or a specific cursor if you are executing a
85+
lot of sequential statements and want to commit outside.
86+
"""
87+
cur = execute(sql, con, cur=cur)
88+
result = _safe_fetch(cur)
89+
90+
if con is not None:
91+
try:
92+
cur.close()
93+
con.commit()
94+
except Exception, e:
95+
excName = e.__class__.__name__
96+
if excName == 'OperationalError': # pragma: no cover
97+
print ('Failed to commit, may need to restart interpreter')
98+
else:
99+
raise
100+
101+
traceback.print_exc()
102+
if retry:
103+
return tquery(sql, con=con, retry=False)
104+
105+
if result and len(result[0]) == 1:
106+
# python 3 compat
107+
result = list(list(zip(*result))[0])
108+
elif result is None: # pragma: no cover
109+
result = []
110+
111+
return result
112+
113+
114+
def uquery(sql, con=None, cur=None, retry=True, params=None):
115+
"""
116+
Does the same thing as tquery, but instead of returning results, it
117+
returns the number of rows affected. Good for update queries.
118+
"""
119+
cur = execute(sql, con, cur=cur, retry=retry, params=params)
120+
121+
result = cur.rowcount
122+
try:
123+
con.commit()
124+
except Exception, e:
125+
excName = e.__class__.__name__
126+
if excName != 'OperationalError':
127+
raise
128+
129+
traceback.print_exc()
130+
if retry:
131+
print ('Looks like your connection failed, reconnecting...')
132+
return uquery(sql, con, retry=False)
133+
return result
134+
135+
136+
def read_frame(sql, con, index_col=None, coerce_float=True, params=None):
137+
"""
138+
Returns a DataFrame corresponding to the result set of the query
139+
string.
140+
141+
Optionally provide an index_col parameter to use one of the
142+
columns as the index. Otherwise will be 0 to len(results) - 1.
143+
144+
Parameters
145+
----------
146+
sql: string
147+
SQL query to be executed
148+
con: DB connection object, optional
149+
index_col: string, optional
150+
column name to use for the returned DataFrame object.
151+
coerce_float : boolean, default True
152+
Attempt to convert values to non-string, non-numeric objects (like
153+
decimal.Decimal) to floating point, useful for SQL result sets
154+
params: list or tuple, optional
155+
List of parameters to pass to execute method.
156+
"""
157+
cur = execute(sql, con, params=params)
158+
rows = _safe_fetch(cur)
159+
columns = [col_desc[0] for col_desc in cur.description]
160+
161+
cur.close()
162+
con.commit()
163+
164+
result = DataFrame.from_records(rows, columns=columns,
165+
coerce_float=coerce_float)
166+
167+
if index_col is not None:
168+
result = result.set_index(index_col)
169+
170+
return result
171+
172+
frame_query = read_frame
173+
read_sql = read_frame
174+
175+
def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs):
176+
"""
177+
Write records stored in a DataFrame to a SQL database.
178+
179+
Parameters
180+
----------
181+
frame: DataFrame
182+
name: name of SQL table
183+
con: an open SQL database connection object
184+
flavor: {'sqlite', 'mysql', 'oracle'}, default 'sqlite'
185+
if_exists: {'fail', 'replace', 'append'}, default 'fail'
186+
fail: If table exists, do nothing.
187+
replace: If table exists, drop it, recreate it, and insert data.
188+
append: If table exists, insert data. Create if does not exist.
189+
"""
190+
191+
if 'append' in kwargs:
192+
import warnings
193+
warnings.warn("append is deprecated, use if_exists instead",
194+
FutureWarning)
195+
if kwargs['append']:
196+
if_exists='append'
197+
else:
198+
if_exists='fail'
199+
exists = table_exists(name, con, flavor)
200+
if if_exists == 'fail' and exists:
201+
raise ValueError, "Table '%s' already exists." % name
202+
203+
#create or drop-recreate if necessary
204+
create = None
205+
if exists and if_exists == 'replace':
206+
create = "DROP TABLE %s" % name
207+
elif not exists:
208+
create = get_schema(frame, name, flavor)
209+
210+
if create is not None:
211+
cur = con.cursor()
212+
cur.execute(create)
213+
cur.close()
214+
215+
cur = con.cursor()
216+
# Replace spaces in DataFrame column names with _.
217+
safe_names = [s.replace(' ', '_').strip() for s in frame.columns]
218+
flavor_picker = {'sqlite' : _write_sqlite,
219+
'mysql' : _write_mysql}
220+
221+
func = flavor_picker.get(flavor, None)
222+
if func is None:
223+
raise NotImplementedError
224+
func(frame, name, safe_names, cur)
225+
cur.close()
226+
con.commit()
227+
228+
def _write_sqlite(frame, table, names, cur):
229+
bracketed_names = ['[' + column + ']' for column in names]
230+
col_names = ','.join(bracketed_names)
231+
wildcards = ','.join(['?'] * len(names))
232+
insert_query = 'INSERT INTO %s (%s) VALUES (%s)' % (
233+
table, col_names, wildcards)
234+
# pandas types are badly handled if there is only 1 column ( Issue #3628 )
235+
if not len(frame.columns )==1 :
236+
data = [tuple(x) for x in frame.values]
237+
else :
238+
data = [tuple(x) for x in frame.values.tolist()]
239+
cur.executemany(insert_query, data)
240+
241+
def _write_mysql(frame, table, names, cur):
242+
bracketed_names = ['`' + column + '`' for column in names]
243+
col_names = ','.join(bracketed_names)
244+
wildcards = ','.join([r'%s'] * len(names))
245+
insert_query = "INSERT INTO %s (%s) VALUES (%s)" % (
246+
table, col_names, wildcards)
247+
data = [tuple(x) for x in frame.values]
248+
cur.executemany(insert_query, data)
249+
250+
def table_exists(name, con, flavor):
251+
flavor_map = {
252+
'sqlite': ("SELECT name FROM sqlite_master "
253+
"WHERE type='table' AND name='%s';") % name,
254+
'mysql' : "SHOW TABLES LIKE '%s'" % name}
255+
query = flavor_map.get(flavor, None)
256+
if query is None:
257+
raise NotImplementedError
258+
return len(tquery(query, con)) > 0
259+
260+
def get_sqltype(pytype, flavor):
261+
sqltype = {'mysql': 'VARCHAR (63)',
262+
'sqlite': 'TEXT'}
263+
264+
if issubclass(pytype, np.floating):
265+
sqltype['mysql'] = 'FLOAT'
266+
sqltype['sqlite'] = 'REAL'
267+
268+
if issubclass(pytype, np.integer):
269+
#TODO: Refine integer size.
270+
sqltype['mysql'] = 'BIGINT'
271+
sqltype['sqlite'] = 'INTEGER'
272+
273+
if issubclass(pytype, np.datetime64) or pytype is datetime:
274+
# Caution: np.datetime64 is also a subclass of np.number.
275+
sqltype['mysql'] = 'DATETIME'
276+
sqltype['sqlite'] = 'TIMESTAMP'
277+
278+
if pytype is datetime.date:
279+
sqltype['mysql'] = 'DATE'
280+
sqltype['sqlite'] = 'TIMESTAMP'
281+
282+
if issubclass(pytype, np.bool_):
283+
sqltype['sqlite'] = 'INTEGER'
284+
285+
return sqltype[flavor]
286+
287+
def get_schema(frame, name, flavor, keys=None):
288+
"Return a CREATE TABLE statement to suit the contents of a DataFrame."
289+
lookup_type = lambda dtype: get_sqltype(dtype.type, flavor)
290+
# Replace spaces in DataFrame column names with _.
291+
safe_columns = [s.replace(' ', '_').strip() for s in frame.dtypes.index]
292+
column_types = zip(safe_columns, map(lookup_type, frame.dtypes))
293+
if flavor == 'sqlite':
294+
columns = ',\n '.join('[%s] %s' % x for x in column_types)
295+
else:
296+
columns = ',\n '.join('`%s` %s' % x for x in column_types)
297+
298+
keystr = ''
299+
if keys is not None:
300+
if isinstance(keys, basestring):
301+
keys = (keys,)
302+
keystr = ', PRIMARY KEY (%s)' % ','.join(keys)
303+
template = """CREATE TABLE %(name)s (
304+
%(columns)s
305+
%(keystr)s
306+
);"""
307+
create_statement = template % {'name': name, 'columns': columns,
308+
'keystr': keystr}
309+
return create_statement
310+
311+
def sequence2dict(seq):
312+
"""Helper function for cx_Oracle.
313+
314+
For each element in the sequence, creates a dictionary item equal
315+
to the element and keyed by the position of the item in the list.
316+
>>> sequence2dict(("Matt", 1))
317+
{'1': 'Matt', '2': 1}
318+
319+
Source:
320+
http://www.gingerandjohn.com/archives/2004/02/26/cx_oracle-executemany-example/
321+
"""
322+
d = {}
323+
for k,v in zip(range(1, 1 + len(seq)), seq):
324+
d[str(k)] = v
325+
return d

0 commit comments

Comments
 (0)