Skip to content

BUG: Dynamically created table names allow SQL injection #8986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 26, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.16.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Enhancements
- ``tseries.frequencies.to_offset()`` now accepts ``Timedelta`` as input (:issue:`9064`)

- ``Timedelta`` will now accept nanoseconds keyword in constructor (:issue:`9273`)
- SQL code now safely escapes table and column names (:issue:`8986`)

Performance
~~~~~~~~~~~
Expand Down
100 changes: 71 additions & 29 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,18 +1239,58 @@ def _create_sql_schema(self, frame, table_name, keys=None):
}


def _get_unicode_name(name):
try:
uname = name.encode("utf-8", "strict").decode("utf-8")
except UnicodeError:
raise ValueError("Cannot convert identifier to UTF-8: '%s'" % name)
return uname

def _get_valid_mysql_name(name):
    """Validate ``name`` as a MySQL identifier and return it in backticks.

    Only the characters MySQL permits in *unquoted* identifiers are
    accepted, so the returned value can never contain a backtick.
    See http://dev.mysql.com/doc/refman/5.0/en/identifiers.html

    Parameters
    ----------
    name : string
        Table or column name to validate.

    Returns
    -------
    string
        ``name`` wrapped in backticks.

    Raises
    ------
    ValueError
        If the name is empty, contains a character outside the accepted
        set, or consists solely of the digits 0-9. ``_get_unicode_name``
        raises the same exception type when the name is not valid UTF-8.
    """
    uname = _get_unicode_name(name)
    if not uname:
        raise ValueError("Empty table or column name specified")

    # NOTE(review): ',' is not in MySQL's unquoted-identifier set; it is
    # kept here only because the previous character class accepted it and
    # it is harmless inside backticks. Confirm whether it should go.
    basic = r'[0-9a-zA-Z$_,]'
    for c in uname:
        # Extended characters U+0080 .. U+FFFF are allowed, bounds included.
        if not (re.match(basic, c) or 0x80 <= ord(c) <= 0xFFFF):
            raise ValueError("Invalid MySQL identifier '%s'" % uname)

    # re.search, not re.match: a name may *begin* with a digit as long as
    # at least one non-digit character appears somewhere in it.
    if not re.search(r'[^0-9]', uname):
        raise ValueError('MySQL identifier cannot be entirely numeric')

    return '`' + uname + '`'


def _get_valid_sqlite_name(name):
    """Return ``name`` as a double-quoted SQLite identifier.

    The name is passed through ``_get_unicode_name``, rejected if it is
    empty or contains a NUL character, and then every ``"`` is doubled
    before the whole string is wrapped in double quotes.
    See http://stackoverflow.com/questions/6514274/how-do-you-escape-strings-for-sqlite-table-column-names-in-python

    Raises
    ------
    ValueError
        If the name is empty or contains a NUL character.
    """
    identifier = _get_unicode_name(name)

    if len(identifier) == 0:
        raise ValueError("Empty table or column name specified")
    if "\x00" in identifier:
        raise ValueError('SQLite identifier cannot contain NULs')

    escaped = identifier.replace('"', '""')
    return '"' + escaped + '"'


# SQL enquote and wildcard symbols
_SQL_SYMB = {
'mysql': {
'br_l': '`',
'br_r': '`',
'wld': '%s'
},
'sqlite': {
'br_l': '[',
'br_r': ']',
'wld': '?'
}
_SQL_WILDCARD = {
'mysql': '%s',
'sqlite': '?'
}

# Validate and return escaped identifier
_SQL_GET_IDENTIFIER = {
'mysql': _get_valid_mysql_name,
'sqlite': _get_valid_sqlite_name,
}


Expand All @@ -1276,18 +1316,17 @@ def _execute_create(self):
def insert_statement(self):
names = list(map(str, self.frame.columns))
flv = self.pd_sql.flavor
br_l = _SQL_SYMB[flv]['br_l'] # left val quote char
br_r = _SQL_SYMB[flv]['br_r'] # right val quote char
wld = _SQL_SYMB[flv]['wld'] # wildcard char
wld = _SQL_WILDCARD[flv] # wildcard char
escape = _SQL_GET_IDENTIFIER[flv]

if self.index is not None:
[names.insert(0, idx) for idx in self.index[::-1]]

bracketed_names = [br_l + column + br_r for column in names]
bracketed_names = [escape(column) for column in names]
col_names = ','.join(bracketed_names)
wildcards = ','.join([wld] * len(names))
insert_statement = 'INSERT INTO %s (%s) VALUES (%s)' % (
self.name, col_names, wildcards)
escape(self.name), col_names, wildcards)
return insert_statement

def _execute_insert(self, conn, keys, data_iter):
Expand All @@ -1309,29 +1348,28 @@ def _create_table_setup(self):
warnings.warn(_SAFE_NAMES_WARNING)

flv = self.pd_sql.flavor
escape = _SQL_GET_IDENTIFIER[flv]

br_l = _SQL_SYMB[flv]['br_l'] # left val quote char
br_r = _SQL_SYMB[flv]['br_r'] # right val quote char
create_tbl_stmts = [escape(cname) + ' ' + ctype
for cname, ctype, _ in column_names_and_types]

create_tbl_stmts = [(br_l + '%s' + br_r + ' %s') % (cname, col_type)
for cname, col_type, _ in column_names_and_types]
if self.keys is not None and len(self.keys):
cnames_br = ",".join([br_l + c + br_r for c in self.keys])
cnames_br = ",".join([escape(c) for c in self.keys])
create_tbl_stmts.append(
"CONSTRAINT {tbl}_pk PRIMARY KEY ({cnames_br})".format(
tbl=self.name, cnames_br=cnames_br))

create_stmts = ["CREATE TABLE " + self.name + " (\n" +
create_stmts = ["CREATE TABLE " + escape(self.name) + " (\n" +
',\n '.join(create_tbl_stmts) + "\n)"]

ix_cols = [cname for cname, _, is_index in column_names_and_types
if is_index]
if len(ix_cols):
cnames = "_".join(ix_cols)
cnames_br = ",".join([br_l + c + br_r for c in ix_cols])
cnames_br = ",".join([escape(c) for c in ix_cols])
create_stmts.append(
"CREATE INDEX ix_{tbl}_{cnames} ON {tbl} ({cnames_br})".format(
tbl=self.name, cnames=cnames, cnames_br=cnames_br))
"CREATE INDEX " + escape("ix_"+self.name+"_"+cnames) +
"ON " + escape(self.name) + " (" + cnames_br + ")")

return create_stmts

Expand Down Expand Up @@ -1505,19 +1543,23 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
table.insert(chunksize)

def has_table(self, name, schema=None):
escape = _SQL_GET_IDENTIFIER[self.flavor]
esc_name = escape(name)
wld = _SQL_WILDCARD[self.flavor]
flavor_map = {
'sqlite': ("SELECT name FROM sqlite_master "
"WHERE type='table' AND name='%s';") % name,
'mysql': "SHOW TABLES LIKE '%s'" % name}
"WHERE type='table' AND name=%s;") % wld,
'mysql': "SHOW TABLES LIKE %s" % wld}
query = flavor_map.get(self.flavor)

return len(self.execute(query).fetchall()) > 0
return len(self.execute(query, [name,]).fetchall()) > 0

def get_table(self, table_name, schema=None):
return None # not supported in fallback mode

def drop_table(self, name, schema=None):
drop_sql = "DROP TABLE %s" % name
escape = _SQL_GET_IDENTIFIER[self.flavor]
drop_sql = "DROP TABLE %s" % escape(name)
self.execute(drop_sql)

def _create_sql_schema(self, frame, table_name, keys=None):
Expand Down
35 changes: 33 additions & 2 deletions pandas/io/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ def test_uquery(self):
def _get_sqlite_column_type(self, schema, column):

for col in schema.split('\n'):
if col.split()[0].strip('[]') == column:
if col.split()[0].strip('""') == column:
return col.split()[1]
raise ValueError('Column %s not found' % (column))

Expand Down Expand Up @@ -1630,6 +1630,24 @@ def test_notnull_dtype(self):
self.assertEqual(self._get_sqlite_column_type(tbl, 'Int'), 'INTEGER')
self.assertEqual(self._get_sqlite_column_type(tbl, 'Float'), 'REAL')

def test_illegal_names(self):
    """Check that odd table and column names are written successfully."""
    # For sqlite, these should work fine
    df = DataFrame([[1, 2], [3, 4]], columns=['a', 'b'])

    # Raise error on blank
    self.assertRaises(ValueError, df.to_sql, "", self.conn,
                      flavor=self.flavor)

    weird_names = ['test_weird_name]', 'test_weird_name[',
                   'test_weird_name`', 'test_weird_name"',
                   'test_weird_name\'']
    for ndx, weird_name in enumerate(weird_names):
        # Weird *table* name.
        df.to_sql(weird_name, self.conn, flavor=self.flavor)
        self.assertTrue(sql.table_exists(weird_name, self.conn))

        # Weird *column* name: write df2 (not df) so that the column
        # escaping path is actually exercised.
        df2 = DataFrame([[1, 2], [3, 4]], columns=['a', weird_name])
        c_tbl = 'test_weird_col_name%d' % ndx
        df2.to_sql(c_tbl, self.conn, flavor=self.flavor)
        self.assertTrue(sql.table_exists(c_tbl, self.conn))


class TestMySQLLegacy(TestSQLiteFallback):
"""
Expand Down Expand Up @@ -1721,6 +1739,19 @@ def test_to_sql_save_index(self):
def test_to_sql_save_index(self):
self._to_sql_save_index()

def test_illegal_names(self):
    """Check that illegal table and column names raise ValueError."""
    # For MySQL, these should raise ValueError
    df = DataFrame([[1, 2], [3, 4]], columns=['a', 'b'])
    illegal_names = ['test_illegal_name]', 'test_illegal_name[',
                     'test_illegal_name`', 'test_illegal_name"',
                     'test_illegal_name\'', '']
    for ndx, illegal_name in enumerate(illegal_names):
        # Illegal *table* name.
        self.assertRaises(ValueError, df.to_sql, illegal_name, self.conn,
                          flavor=self.flavor, index=False)

        # Illegal *column* name. Use a distinct table per iteration so a
        # table wrongly created by an earlier iteration cannot satisfy
        # this assertion through a "table already exists" error.
        df2 = DataFrame([[1, 2], [3, 4]], columns=['a', illegal_name])
        c_tbl = 'test_illegal_col_name%d' % ndx
        self.assertRaises(ValueError, df2.to_sql, c_tbl,
                          self.conn, flavor=self.flavor, index=False)


#------------------------------------------------------------------------------
#--- Old tests from 0.13.1 (before refactor using sqlalchemy)
Expand Down Expand Up @@ -1817,7 +1848,7 @@ def test_schema(self):
frame = tm.makeTimeDataFrame()
create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],)
lines = create_sql.splitlines()
self.assertTrue('PRIMARY KEY ([A],[B])' in create_sql)
self.assertTrue('PRIMARY KEY ("A","B")' in create_sql)
cur = self.db.cursor()
cur.execute(create_sql)

Expand Down