Skip to content

Commit 0222d39

Browse files
danielballan authored and jreback committed
ENH #4163 Use SQLAlchemy for DB abstraction
TST Import sqlalchemy on Travis. DOC add docstrings to read sql ENH read_sql connects via Connection, Engine, file path, or :memory: string CLN Separate legacy code into new file, and fallback so that all old tests pass. TST to use sqlalchemy syntax in tests CLN sql into classes, legacy passes FIX few engine vs con calls CLN pep8 cleanup add postgres support for pandas.io.sql.get_schema WIP: cleanup of sql io module - imported correct SQLALCHEMY type, delete redundant PandasSQLWithCon TODO: renamed _engine_read_table, need to think of a better name. TODO: clean up get_connection function ENH: cleanup of SQL io TODO: check that legacy mode works TODO: run tests correctly enabled coerce_float option Cleanup and bug-fixing mainly on legacy mode sql. IMPORTANT - changed legacy to require connection rather than cursor. This is still not yet finalized. TODO: tests and doc Added Test coverage for basic functionality using in-memory SQLite database Simplified API by automatically distinguishing between engine and connection. Added warnings
1 parent e7e5621 commit 0222d39

File tree

3 files changed

+980
-0
lines changed

3 files changed

+980
-0
lines changed

pandas/io/sql_legacy.py

+332
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
"""
2+
Collection of query wrappers / abstractions to both facilitate data
3+
retrieval and to reduce dependency on DB-specific API.
4+
"""
5+
from datetime import datetime, date
6+
7+
import numpy as np
8+
import traceback
9+
10+
from pandas.core.datetools import format as date_format
11+
from pandas.core.api import DataFrame, isnull
12+
13+
#------------------------------------------------------------------------------
14+
# Helper execution function
15+
16+
17+
def execute(sql, con, retry=True, cur=None, params=None):
    """
    Run a SQL statement on a cursor and hand the cursor back.

    Parameters
    ----------
    sql: string
        Query to be executed
    con: database connection instance
        Database connection. Must implement PEP249 (Database API v2.0).
    retry: bool
        Accepted for interface compatibility; not used by this function.
    cur: database cursor, optional
        Must implement PEP249 (Database API v2.0). If cursor is not provided,
        one will be obtained from the database connection.
    params: list or tuple, optional
        List of parameters to pass to execute method.

    Returns
    -------
    Cursor object

    Raises
    ------
    Exception
        Whatever the cursor raised is re-raised after a rollback has been
        attempted on ``con`` and the failing statement has been printed.
    """
    try:
        cursor = con.cursor() if cur is None else cur
        # Only forward ``params`` when the caller actually supplied them.
        call_args = (sql,) if params is None else (sql, params)
        cursor.execute(*call_args)
        return cursor
    except Exception:
        # Best-effort rollback: a failure here must not mask the real error.
        try:
            con.rollback()
        except Exception:  # pragma: no cover
            pass

        print('Error on sql %s' % sql)
        raise
56+
57+
58+
def _safe_fetch(cur):
    """
    Fetch all rows from a cursor and return them as a list.

    Parameters
    ----------
    cur: database cursor
        Cursor whose ``fetchall`` method is called.

    Returns
    -------
    list or None
        The fetched rows, converted to a list when ``fetchall`` returned
        another kind of sequence. An empty list is returned when fetching
        raised an exception whose class is named ``OperationalError``.
        None is returned when any other exception was raised.
    """
    try:
        result = cur.fetchall()
        if not isinstance(result, list):
            result = list(result)
        return result
    except Exception as e:  # pragma: no cover
        # The exception is matched by class name rather than by type.
        if e.__class__.__name__ == 'OperationalError':
            return []
        # Legacy behaviour: any other error is swallowed and None is
        # returned, made explicit here instead of falling off the end.
        return None
68+
69+
70+
def tquery(sql, con=None, cur=None, retry=True):
    """
    Run a query and return its rows as a list of tuples.

    When the query selects exactly one column, a flat list of that column's
    values is returned instead of a list of 1-tuples.

    Parameters
    ----------
    sql: string
        SQL query to be executed
    con: SQLConnection or DB API 2.0-compliant connection
    cur: DB API 2.0 cursor
    retry: bool, default True
        If committing fails with an exception whose class is named
        ``OperationalError``, run the query once more.

    Provide a specific connection or a specific cursor if you are executing a
    lot of sequential statements and want to commit outside.
    """
    cur = execute(sql, con, cur=cur)
    rows = _safe_fetch(cur)

    if con is not None:
        try:
            cur.close()
            con.commit()
        except Exception as err:
            # Only OperationalError (matched by class name) is recoverable.
            if err.__class__.__name__ != 'OperationalError':
                raise

            print('Failed to commit, may need to restart interpreter')  # pragma: no cover
            traceback.print_exc()
            if retry:
                return tquery(sql, con=con, retry=False)

    if rows is None:  # pragma: no cover
        return []
    if rows and len(rows[0]) == 1:
        # python 3 compat: zip() must be materialised before indexing
        return list(list(zip(*rows))[0])
    return rows
112+
113+
114+
def uquery(sql, con=None, cur=None, retry=True, params=None):
    """
    Execute a statement, commit, and return the number of rows affected.

    Does the same thing as tquery, but instead of returning results, it
    returns the cursor's ``rowcount``. Good for update queries.

    Parameters
    ----------
    sql: string
        SQL statement to be executed
    con: DB API 2.0-compliant connection
        Connection on which the statement is committed.
    cur: DB API 2.0 cursor, optional
        Cursor to execute on; one is obtained from ``con`` when omitted.
    retry: bool, default True
        If committing fails with an exception whose class is named
        ``OperationalError``, run the statement once more.
    params: list or tuple, optional
        List of parameters to pass to execute method.

    Returns
    -------
    The ``rowcount`` attribute of the cursor used for the statement.
    """
    cur = execute(sql, con, cur=cur, retry=retry, params=params)

    result = cur.rowcount
    try:
        con.commit()
    except Exception as e:
        # Only OperationalError (matched by class name) triggers a retry.
        if e.__class__.__name__ != 'OperationalError':
            raise

        traceback.print_exc()
        if retry:
            print('Looks like your connection failed, reconnecting...')
            # Forward params so a parameterised statement is retried with
            # the same bound values instead of being re-run without them.
            return uquery(sql, con, retry=False, params=params)
    return result
134+
135+
136+
def read_frame(sql, con, index_col=None, coerce_float=True, params=None):
    """
    Run a query and return its result set as a DataFrame.

    Optionally provide an index_col parameter to use one of the
    columns as the index. Otherwise will be 0 to len(results) - 1.

    Parameters
    ----------
    sql: string
        SQL query to be executed
    con: DB connection object
    index_col: string, optional
        column name to use as the index of the returned DataFrame object.
    coerce_float : boolean, default True
        Attempt to convert values of non-string, non-numeric objects (like
        decimal.Decimal) to floating point, useful for SQL result sets
    params: list or tuple, optional
        List of parameters to pass to execute method.
    """
    cursor = execute(sql, con, params=params)
    records = _safe_fetch(cursor)
    # The first item of each cursor.description entry is the column name.
    column_names = [desc[0] for desc in cursor.description]

    cursor.close()
    con.commit()

    frame = DataFrame.from_records(records, columns=column_names,
                                   coerce_float=coerce_float)
    if index_col is None:
        return frame
    return frame.set_index(index_col)


# Alternative names under which read_frame is exposed.
frame_query = read_frame
read_sql = read_frame
174+
175+
176+
def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs):
    """
    Write records stored in a DataFrame to a SQL database.

    Parameters
    ----------
    frame: DataFrame
    name: name of SQL table
    con: an open SQL database connection object
    flavor: {'sqlite', 'mysql'}, default 'sqlite'
    if_exists: {'fail', 'replace', 'append'}, default 'fail'
        fail: If table exists, raise ValueError.
        replace: If table exists, drop it, recreate it, and insert data.
        append: If table exists, insert data. Create if does not exist.
    append: bool, deprecated
        Legacy keyword; True maps to if_exists='append' and False to
        if_exists='fail'. Emits a FutureWarning.

    Raises
    ------
    ValueError
        If the table already exists and ``if_exists`` is 'fail'.
    NotImplementedError
        If ``flavor`` is not supported.
    """
    if 'append' in kwargs:
        import warnings
        warnings.warn("append is deprecated, use if_exists instead",
                      FutureWarning)
        if kwargs['append']:
            if_exists = 'append'
        else:
            if_exists = 'fail'

    exists = table_exists(name, con, flavor)
    if if_exists == 'fail' and exists:
        raise ValueError("Table '%s' already exists." % name)

    # Schema statements to run before inserting. 'replace' needs the table
    # recreated after the drop, otherwise the insert below has no target.
    statements = []
    if exists and if_exists == 'replace':
        statements.append("DROP TABLE %s" % name)
        statements.append(get_schema(frame, name, flavor))
    elif not exists:
        statements.append(get_schema(frame, name, flavor))

    flavor_picker = {'sqlite': _write_sqlite,
                     'mysql': _write_mysql}
    func = flavor_picker.get(flavor, None)
    if func is None:
        raise NotImplementedError

    cur = con.cursor()
    try:
        # One statement per execute() call.
        for statement in statements:
            cur.execute(statement)

        # Replace spaces in DataFrame column names with _.
        safe_names = [s.replace(' ', '_').strip() for s in frame.columns]
        func(frame, name, safe_names, cur)
    finally:
        # Release the cursor even when a statement fails.
        cur.close()
    con.commit()
228+
229+
230+
def _write_sqlite(frame, table, names, cur):
    """
    Insert every row of ``frame`` into ``table`` using SQLite syntax.

    Column names are wrapped in square brackets and values are bound with
    ``?`` placeholders; all rows are sent in one ``executemany`` call.

    Parameters
    ----------
    frame: DataFrame
        Data to insert.
    table: string
        Name of the target table.
    names: list of string
        Column names to use in the INSERT statement.
    cur: database cursor
        Cursor on which ``executemany`` is called.
    """
    quoted_columns = ','.join('[' + name + ']' for name in names)
    placeholders = ','.join(['?'] * len(names))
    statement = 'INSERT INTO %s (%s) VALUES (%s)' % (
        table, quoted_columns, placeholders)

    # pandas types are badly handled if there is only 1 column ( Issue #3628 )
    if len(frame.columns) == 1:
        source_rows = frame.values.tolist()
    else:
        source_rows = frame.values
    cur.executemany(statement, [tuple(row) for row in source_rows])
242+
243+
244+
def _write_mysql(frame, table, names, cur):
    """
    Insert every row of ``frame`` into ``table`` using MySQL syntax.

    Column names are wrapped in backticks and values are bound with ``%s``
    placeholders; all rows are sent in one ``executemany`` call.

    Parameters
    ----------
    frame: DataFrame
        Data to insert.
    table: string
        Name of the target table.
    names: list of string
        Column names to use in the INSERT statement.
    cur: database cursor
        Cursor on which ``executemany`` is called.
    """
    quoted_columns = ','.join('`' + name + '`' for name in names)
    placeholders = ','.join([r'%s'] * len(names))
    statement = "INSERT INTO %s (%s) VALUES (%s)" % (
        table, quoted_columns, placeholders)
    cur.executemany(statement, [tuple(row) for row in frame.values])
252+
253+
254+
def table_exists(name, con, flavor):
    """
    Report whether a table called ``name`` exists in the database.

    Parameters
    ----------
    name: string
        Table name to look for.
    con: database connection object
        Connection the lookup query is run on.
    flavor: {'sqlite', 'mysql'}
        Selects which lookup query is used.

    Returns
    -------
    bool
        True if the lookup query returned at least one row.

    Raises
    ------
    NotImplementedError
        If ``flavor`` has no lookup query defined.
    """
    # NOTE(review): ``name`` is interpolated straight into the SQL text;
    # confirm callers never pass untrusted table names.
    query_templates = {
        'sqlite': ("SELECT name FROM sqlite_master "
                   "WHERE type='table' AND name='%s';"),
        'mysql': "SHOW TABLES LIKE '%s'",
    }
    template = query_templates.get(flavor, None)
    if template is None:
        raise NotImplementedError
    matches = tquery(template % name, con)
    return len(matches) > 0
263+
264+
265+
def get_sqltype(pytype, flavor):
    """
    Map a Python/NumPy scalar type to a SQL column type name.

    Parameters
    ----------
    pytype: type
        Scalar type to translate, e.g. ``np.float64`` or ``datetime``.
    flavor: {'mysql', 'sqlite'}
        SQL dialect whose type name is returned.

    Returns
    -------
    string
        The SQL type name. Types matching none of the checks below map to
        'VARCHAR (63)' for mysql and 'TEXT' for sqlite.

    Raises
    ------
    KeyError
        If ``flavor`` is not 'mysql' or 'sqlite'.
    """
    sqltype = {'mysql': 'VARCHAR (63)',
               'sqlite': 'TEXT'}

    if issubclass(pytype, np.floating):
        sqltype['mysql'] = 'FLOAT'
        sqltype['sqlite'] = 'REAL'

    if issubclass(pytype, np.integer):
        # TODO: Refine integer size.
        sqltype['mysql'] = 'BIGINT'
        sqltype['sqlite'] = 'INTEGER'

    # The datetime checks run after the numeric ones so they override any
    # earlier match (original caution: np.datetime64 is also a subclass of
    # np.number).
    if issubclass(pytype, np.datetime64) or pytype is datetime:
        sqltype['mysql'] = 'DATETIME'
        sqltype['sqlite'] = 'TIMESTAMP'

    # Compare against the imported ``date`` class. ``datetime.date`` would be
    # the method of the imported ``datetime`` class and never match a type.
    if pytype is date:
        sqltype['mysql'] = 'DATE'
        sqltype['sqlite'] = 'TIMESTAMP'

    if issubclass(pytype, np.bool_):
        # Only the sqlite mapping is overridden for booleans.
        sqltype['sqlite'] = 'INTEGER'

    return sqltype[flavor]
291+
292+
293+
def get_schema(frame, name, flavor, keys=None):
    """
    Return a CREATE TABLE statement to suit the contents of a DataFrame.

    Parameters
    ----------
    frame: DataFrame
        Frame whose column names and dtypes define the table columns.
    name: string
        Name of the table to create.
    flavor: string
        SQL dialect; 'sqlite' quotes column names with square brackets,
        any other value quotes them with backticks.
    keys: string or sequence of string, optional
        Column name(s) to declare as the primary key.

    Returns
    -------
    string
        The CREATE TABLE statement.
    """
    # ``basestring`` only exists on Python 2; fall back to ``str`` so a
    # single string key is still recognised on Python 3.
    try:
        string_types = basestring
    except NameError:
        string_types = str

    lookup_type = lambda dtype: get_sqltype(dtype.type, flavor)
    # Replace spaces in DataFrame column names with _.
    safe_columns = [s.replace(' ', '_').strip() for s in frame.dtypes.index]
    column_types = zip(safe_columns, map(lookup_type, frame.dtypes))
    if flavor == 'sqlite':
        columns = ',\n '.join('[%s] %s' % x for x in column_types)
    else:
        columns = ',\n '.join('`%s` %s' % x for x in column_types)

    keystr = ''
    if keys is not None:
        # A lone string is a single key, not a sequence of characters.
        if isinstance(keys, string_types):
            keys = (keys,)
        keystr = ', PRIMARY KEY (%s)' % ','.join(keys)
    template = ("CREATE TABLE %(name)s (\n"
                "%(columns)s\n"
                "%(keystr)s\n"
                ");")
    create_statement = template % {'name': name, 'columns': columns,
                                   'keystr': keystr}
    return create_statement
316+
317+
318+
def sequence2dict(seq):
    """Helper function for cx_Oracle.

    Builds a dictionary holding every element of the sequence, keyed by the
    element's 1-based position rendered as a string.

    >>> sequence2dict(("Matt", 1))
    {'1': 'Matt', '2': 1}

    Source:
    http://www.gingerandjohn.com/archives/2004/02/26/cx_oracle-executemany-example/
    """
    positions = [str(index) for index in range(1, len(seq) + 1)]
    return dict(zip(positions, seq))

0 commit comments

Comments
 (0)