Skip to content

Commit 7179732

Browse files
committed
read_sql connects via Connection, Engine, file path, or :memory: string
1 parent 03355c4 commit 7179732

File tree

1 file changed

+96
-18
lines changed

1 file changed

+96
-18
lines changed

pandas/io/sql.py

+96-18
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import numpy as np
88
import traceback
99

10+
import sqlite3
11+
import warnings
12+
1013
from pandas.core.datetools import format as date_format
1114
from pandas.core.api import DataFrame, isnull
1215

@@ -132,10 +135,81 @@ def uquery(sql, con=None, cur=None, retry=True, params=None):
132135
return uquery(sql, con, retry=False)
133136
return result
134137

138+
class SQLAlchemyRequired(Exception):
139+
pass
135140

136-
def read_sql(sql, con=None, index_col=None,
137-
user=None, passwd=None, host=None, db=None, flavor=None,
138-
coerce_float=True, params=None):
141+
def get_connection(con, dialect, driver, username, password,
142+
host, port, database):
143+
if isinstance(con, basestring):
144+
try:
145+
import sqlalchemy
146+
return _alchemy_connect_sqlite(con)
147+
except:
148+
return sqlite3.connect(con)
149+
if isinstance(con, sqlite3.Connection):
150+
return con
151+
# If we reach here, SQLAlchemy will be needed.
152+
try:
153+
import sqlalchemy
154+
except ImportError:
155+
raise SQLAlchemyRequired
156+
if isinstance(con, sqlalchemy.engine.Engine):
157+
return con.connect()
158+
if isinstance(con, sqlalchemy.engine.Connection):
159+
return con
160+
if con is None:
161+
url_params = (dialect, driver, username, \
162+
password, host, port, database)
163+
url = _build_url(*url_params)
164+
engine = sqlalchemy.create_engine(url)
165+
return engine.connect()
166+
if hasattr(con, 'cursor') and callable(con.cursor):
167+
# This looks like some Connection object from a driver module.
168+
try:
169+
import MySQLdb
170+
warnings.warn("For more robust support, connect using " \
171+
"SQLAlchemy. See documentation.")
172+
return conn.cursor() # behaves like a sqlalchemy Connection
173+
except ImportError:
174+
pass
175+
raise NotImplementedError, \
176+
"""To ensure robust support of varied SQL dialects, pandas
177+
only support database connections from SQLAlchemy. See
178+
documentation."""
179+
else:
180+
raise ValueError, \
181+
"""con must be a string, a Connection to a sqlite Database,
182+
or a SQLAlchemy Connection or Engine object."""
183+
184+
185+
def _alchemy_connect_sqlite(path):
186+
if path == ':memory:':
187+
return create_engine('sqlite://').connect()
188+
else:
189+
return create_engine('sqlite:///%s' % path).connect()
190+
191+
def _build_url(dialect, driver, username, password, host, port, database):
192+
# Create an Engine and from that a Connection.
193+
# We use a string instead of sqlalchemy.engine.url.URL because
194+
# we do not necessarily know the driver; we know the dialect.
195+
required_params = [dialect, username, password, host, database]
196+
for p in required_params:
197+
if not isinstance(p, basestring):
198+
raise ValueError, \
199+
"Insufficient information to connect to a database;" \
200+
"see docstring."
201+
url = dialect
202+
if driver is not None:
203+
url += "+%s" % driver
204+
url += "://%s:%s@%s" % (username, password, host)
205+
if port is not None:
206+
url += ":%d" % port
207+
url += "/%s" % database
208+
return url
209+
210+
def read_sql(sql, con=None, index_col=None, flavor=None, driver=None,
211+
username=None, password=None, host=None, port=None,
212+
database=None, coerce_float=True, params=None):
139213
"""
140214
Returns a DataFrame corresponding to the result set of the query
141215
string.
@@ -147,34 +221,38 @@ def read_sql(sql, con=None, index_col=None,
147221
----------
148222
sql: string
149223
SQL query to be executed
150-
con : Connection object, SQLAlchemy Engine object, or a filepath (sqlite
151-
only). Alternatively, specify a user, passwd, host, and db below.
224+
con : Connection object, SQLAlchemy Engine object, a filepath string
225+
(sqlite only) or the string ':memory:' (sqlite only). Alternatively,
226+
specify a user, passwd, host, and db below.
152227
index_col: string, optional
153228
column name to use for the returned DataFrame object.
154-
user: username for database authentication
229+
flavor : string specifying the flavor of SQL to use
230+
driver : string specifying SQL driver (e.g., MySQLdb), optional
231+
username: username for database authentication
155232
only needed if a Connection, Engine, or filepath are not given
156-
passwd: password for database authentication
233+
password: password for database authentication
157234
only needed if a Connection, Engine, or filepath are not given
158235
host: host for database connection
159236
only needed if a Connection, Engine, or filepath are not given
160-
db: database name
237+
database: database name
161238
only needed if a Connection, Engine, or filepath are not given
162-
flavor : string specifying the flavor of SQL to use
163239
coerce_float : boolean, default True
164240
Attempt to convert values to non-string, non-numeric objects (like
165241
decimal.Decimal) to floating point, useful for SQL result sets
166242
params: list or tuple, optional
167243
List of parameters to pass to execute method.
168244
"""
169-
cur = execute(sql, con, params=params)
170-
rows = _safe_fetch(cur)
171-
columns = [col_desc[0] for col_desc in cur.description]
172-
173-
cur.close()
174-
con.commit()
175-
176-
result = DataFrame.from_records(rows, columns=columns,
177-
coerce_float=coerce_float)
245+
dialect = flavor
246+
connection = get_connection(con, dialect, driver, username, password,
247+
host, port, database)
248+
if params is None:
249+
params = []
250+
cursor = connection.execute(sql, *params)
251+
result = _safe_fetch(cursor)
252+
columns = [col_desc[0] for col_desc in cursor.description]
253+
cursor.close()
254+
255+
result = DataFrame.from_records(result, columns=columns)
178256

179257
if index_col is not None:
180258
result = result.set_index(index_col)

0 commit comments

Comments
 (0)