Skip to content

Commit a774ee8

Browse files
Merge pull request #8986 from artemyk/sql_injection_fix
BUG: Dynamically created table names allow SQL injection
2 parents 327340b + 337b94d commit a774ee8

File tree

3 files changed

+105
-31
lines changed

3 files changed

+105
-31
lines changed

doc/source/whatsnew/v0.16.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ Enhancements
106106
- ``tseries.frequencies.to_offset()`` now accepts ``Timedelta`` as input (:issue:`9064`)
107107

108108
- ``Timedelta`` will now accept nanoseconds keyword in constructor (:issue:`9273`)
109+
- SQL code now safely escapes table and column names (:issue:`8986`)
109110

110111
- Added auto-complete for ``Series.str.<tab>``, ``Series.dt.<tab>`` and ``Series.cat.<tab>`` (:issue:`9322`)
111112

pandas/io/sql.py

+71-29
Original file line numberDiff line numberDiff line change
@@ -1253,18 +1253,58 @@ def _create_sql_schema(self, frame, table_name, keys=None, dtype=None):
12531253
}
12541254

12551255

1256+
def _get_unicode_name(name):
1257+
try:
1258+
uname = name.encode("utf-8", "strict").decode("utf-8")
1259+
except UnicodeError:
1260+
raise ValueError("Cannot convert identifier to UTF-8: '%s'" % name)
1261+
return uname
1262+
1263+
def _get_valid_mysql_name(name):
1264+
# Filter for unquoted identifiers
1265+
# See http://dev.mysql.com/doc/refman/5.0/en/identifiers.html
1266+
uname = _get_unicode_name(name)
1267+
if not len(uname):
1268+
raise ValueError("Empty table or column name specified")
1269+
1270+
basere = r'[0-9,a-z,A-Z$_]'
1271+
for c in uname:
1272+
if not re.match(basere, c):
1273+
if not (0x80 < ord(c) < 0xFFFF):
1274+
raise ValueError("Invalid MySQL identifier '%s'" % uname)
1275+
if not re.match(r'[^0-9]', uname):
1276+
raise ValueError('MySQL identifier cannot be entirely numeric')
1277+
1278+
return '`' + uname + '`'
1279+
1280+
1281+
def _get_valid_sqlite_name(name):
1282+
# See http://stackoverflow.com/questions/6514274/how-do-you-escape-strings-for-sqlite-table-column-names-in-python
1283+
# Ensure the string can be encoded as UTF-8.
1284+
# Ensure the string does not include any NUL characters.
1285+
# Replace all " with "".
1286+
# Wrap the entire thing in double quotes.
1287+
1288+
uname = _get_unicode_name(name)
1289+
if not len(uname):
1290+
raise ValueError("Empty table or column name specified")
1291+
1292+
nul_index = uname.find("\x00")
1293+
if nul_index >= 0:
1294+
raise ValueError('SQLite identifier cannot contain NULs')
1295+
return '"' + uname.replace('"', '""') + '"'
1296+
1297+
12561298
# SQL enquote and wildcard symbols
1257-
_SQL_SYMB = {
1258-
'mysql': {
1259-
'br_l': '`',
1260-
'br_r': '`',
1261-
'wld': '%s'
1262-
},
1263-
'sqlite': {
1264-
'br_l': '[',
1265-
'br_r': ']',
1266-
'wld': '?'
1267-
}
1299+
_SQL_WILDCARD = {
1300+
'mysql': '%s',
1301+
'sqlite': '?'
1302+
}
1303+
1304+
# Validate and return escaped identifier
1305+
_SQL_GET_IDENTIFIER = {
1306+
'mysql': _get_valid_mysql_name,
1307+
'sqlite': _get_valid_sqlite_name,
12681308
}
12691309

12701310

@@ -1290,18 +1330,17 @@ def _execute_create(self):
12901330
def insert_statement(self):
12911331
names = list(map(str, self.frame.columns))
12921332
flv = self.pd_sql.flavor
1293-
br_l = _SQL_SYMB[flv]['br_l'] # left val quote char
1294-
br_r = _SQL_SYMB[flv]['br_r'] # right val quote char
1295-
wld = _SQL_SYMB[flv]['wld'] # wildcard char
1333+
wld = _SQL_WILDCARD[flv] # wildcard char
1334+
escape = _SQL_GET_IDENTIFIER[flv]
12961335

12971336
if self.index is not None:
12981337
[names.insert(0, idx) for idx in self.index[::-1]]
12991338

1300-
bracketed_names = [br_l + column + br_r for column in names]
1339+
bracketed_names = [escape(column) for column in names]
13011340
col_names = ','.join(bracketed_names)
13021341
wildcards = ','.join([wld] * len(names))
13031342
insert_statement = 'INSERT INTO %s (%s) VALUES (%s)' % (
1304-
self.name, col_names, wildcards)
1343+
escape(self.name), col_names, wildcards)
13051344
return insert_statement
13061345

13071346
def _execute_insert(self, conn, keys, data_iter):
@@ -1323,29 +1362,28 @@ def _create_table_setup(self):
13231362
warnings.warn(_SAFE_NAMES_WARNING)
13241363

13251364
flv = self.pd_sql.flavor
1365+
escape = _SQL_GET_IDENTIFIER[flv]
13261366

1327-
br_l = _SQL_SYMB[flv]['br_l'] # left val quote char
1328-
br_r = _SQL_SYMB[flv]['br_r'] # right val quote char
1367+
create_tbl_stmts = [escape(cname) + ' ' + ctype
1368+
for cname, ctype, _ in column_names_and_types]
13291369

1330-
create_tbl_stmts = [(br_l + '%s' + br_r + ' %s') % (cname, col_type)
1331-
for cname, col_type, _ in column_names_and_types]
13321370
if self.keys is not None and len(self.keys):
1333-
cnames_br = ",".join([br_l + c + br_r for c in self.keys])
1371+
cnames_br = ",".join([escape(c) for c in self.keys])
13341372
create_tbl_stmts.append(
13351373
"CONSTRAINT {tbl}_pk PRIMARY KEY ({cnames_br})".format(
13361374
tbl=self.name, cnames_br=cnames_br))
13371375

1338-
create_stmts = ["CREATE TABLE " + self.name + " (\n" +
1376+
create_stmts = ["CREATE TABLE " + escape(self.name) + " (\n" +
13391377
',\n '.join(create_tbl_stmts) + "\n)"]
13401378

13411379
ix_cols = [cname for cname, _, is_index in column_names_and_types
13421380
if is_index]
13431381
if len(ix_cols):
13441382
cnames = "_".join(ix_cols)
1345-
cnames_br = ",".join([br_l + c + br_r for c in ix_cols])
1383+
cnames_br = ",".join([escape(c) for c in ix_cols])
13461384
create_stmts.append(
1347-
"CREATE INDEX ix_{tbl}_{cnames} ON {tbl} ({cnames_br})".format(
1348-
tbl=self.name, cnames=cnames, cnames_br=cnames_br))
1385+
"CREATE INDEX " + escape("ix_"+self.name+"_"+cnames) +
1386+
"ON " + escape(self.name) + " (" + cnames_br + ")")
13491387

13501388
return create_stmts
13511389

@@ -1519,19 +1557,23 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
15191557
table.insert(chunksize)
15201558

15211559
def has_table(self, name, schema=None):
1560+
escape = _SQL_GET_IDENTIFIER[self.flavor]
1561+
esc_name = escape(name)
1562+
wld = _SQL_WILDCARD[self.flavor]
15221563
flavor_map = {
15231564
'sqlite': ("SELECT name FROM sqlite_master "
1524-
"WHERE type='table' AND name='%s';") % name,
1525-
'mysql': "SHOW TABLES LIKE '%s'" % name}
1565+
"WHERE type='table' AND name=%s;") % wld,
1566+
'mysql': "SHOW TABLES LIKE %s" % wld}
15261567
query = flavor_map.get(self.flavor)
15271568

1528-
return len(self.execute(query).fetchall()) > 0
1569+
return len(self.execute(query, [name,]).fetchall()) > 0
15291570

15301571
def get_table(self, table_name, schema=None):
15311572
return None # not supported in fallback mode
15321573

15331574
def drop_table(self, name, schema=None):
1534-
drop_sql = "DROP TABLE %s" % name
1575+
escape = _SQL_GET_IDENTIFIER[self.flavor]
1576+
drop_sql = "DROP TABLE %s" % escape(name)
15351577
self.execute(drop_sql)
15361578

15371579
def _create_sql_schema(self, frame, table_name, keys=None, dtype=None):

pandas/io/tests/test_sql.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ def test_uquery(self):
873873
def _get_sqlite_column_type(self, schema, column):
874874

875875
for col in schema.split('\n'):
876-
if col.split()[0].strip('[]') == column:
876+
if col.split()[0].strip('""') == column:
877877
return col.split()[1]
878878
raise ValueError('Column %s not found' % (column))
879879

@@ -1667,6 +1667,24 @@ def test_notnull_dtype(self):
16671667
self.assertEqual(self._get_sqlite_column_type(tbl, 'Int'), 'INTEGER')
16681668
self.assertEqual(self._get_sqlite_column_type(tbl, 'Float'), 'REAL')
16691669

1670+
def test_illegal_names(self):
1671+
# For sqlite, these should work fine
1672+
df = DataFrame([[1, 2], [3, 4]], columns=['a', 'b'])
1673+
1674+
# Raise error on blank
1675+
self.assertRaises(ValueError, df.to_sql, "", self.conn,
1676+
flavor=self.flavor)
1677+
1678+
for ndx, weird_name in enumerate(['test_weird_name]','test_weird_name[',
1679+
'test_weird_name`','test_weird_name"', 'test_weird_name\'']):
1680+
df.to_sql(weird_name, self.conn, flavor=self.flavor)
1681+
sql.table_exists(weird_name, self.conn)
1682+
1683+
df2 = DataFrame([[1, 2], [3, 4]], columns=['a', weird_name])
1684+
c_tbl = 'test_weird_col_name%d'%ndx
1685+
df.to_sql(c_tbl, self.conn, flavor=self.flavor)
1686+
sql.table_exists(c_tbl, self.conn)
1687+
16701688

16711689
class TestMySQLLegacy(TestSQLiteFallback):
16721690
"""
@@ -1758,6 +1776,19 @@ def test_to_sql_save_index(self):
17581776
def test_to_sql_save_index(self):
17591777
self._to_sql_save_index()
17601778

1779+
def test_illegal_names(self):
1780+
# For MySQL, these should raise ValueError
1781+
for ndx, illegal_name in enumerate(['test_illegal_name]','test_illegal_name[',
1782+
'test_illegal_name`','test_illegal_name"', 'test_illegal_name\'', '']):
1783+
df = DataFrame([[1, 2], [3, 4]], columns=['a', 'b'])
1784+
self.assertRaises(ValueError, df.to_sql, illegal_name, self.conn,
1785+
flavor=self.flavor, index=False)
1786+
1787+
df2 = DataFrame([[1, 2], [3, 4]], columns=['a', illegal_name])
1788+
c_tbl = 'test_illegal_col_name%d'%ndx
1789+
self.assertRaises(ValueError, df2.to_sql, 'test_illegal_col_name',
1790+
self.conn, flavor=self.flavor, index=False)
1791+
17611792

17621793
#------------------------------------------------------------------------------
17631794
#--- Old tests from 0.13.1 (before refactor using sqlalchemy)
@@ -1854,7 +1885,7 @@ def test_schema(self):
18541885
frame = tm.makeTimeDataFrame()
18551886
create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],)
18561887
lines = create_sql.splitlines()
1857-
self.assertTrue('PRIMARY KEY ([A],[B])' in create_sql)
1888+
self.assertTrue('PRIMARY KEY ("A","B")' in create_sql)
18581889
cur = self.db.cursor()
18591890
cur.execute(create_sql)
18601891

0 commit comments

Comments
 (0)