From 337b94d3f5be6d9beca2386f7d0188dfc3e8fa08 Mon Sep 17 00:00:00 2001
From: Artemy Kolchinsky
Date: Wed, 3 Dec 2014 14:44:44 -0500
Subject: [PATCH] BUG: Dynamically created table names allow SQL injection

Cleanup doc

Check for empty identifiers

Tests fix

Tests pass

Doc update

Error catching
---
 doc/source/whatsnew/v0.16.0.txt |   1 +
 pandas/io/sql.py                | 115 ++++++++++++++++++++++++++------
 pandas/io/tests/test_sql.py     |  35 ++++++++-
 3 files changed, 120 insertions(+), 31 deletions(-)

diff --git a/doc/source/whatsnew/v0.16.0.txt b/doc/source/whatsnew/v0.16.0.txt
index b3ac58a9fb84a..7a47cc68fb627 100644
--- a/doc/source/whatsnew/v0.16.0.txt
+++ b/doc/source/whatsnew/v0.16.0.txt
@@ -106,6 +106,7 @@ Enhancements
 
 - ``tseries.frequencies.to_offset()`` now accepts ``Timedelta`` as input (:issue:`9064`)
 - ``Timedelta`` will now accept nanoseconds keyword in constructor (:issue:`9273`)
+- SQL code now safely escapes table and column names (:issue:`8986`)
 
 Performance
 ~~~~~~~~~~~
diff --git a/pandas/io/sql.py b/pandas/io/sql.py
index b4318bdc2a3bf..87c86e8ef91a8 100644
--- a/pandas/io/sql.py
+++ b/pandas/io/sql.py
@@ -1239,18 +1239,73 @@ def _create_sql_schema(self, frame, table_name, keys=None):
 }
 
 
+def _get_unicode_name(name):
+    """Return ``name`` after a strict UTF-8 encode/decode round trip.
+
+    Raises ValueError if ``name`` cannot be converted to UTF-8.
+    """
+    try:
+        uname = name.encode("utf-8", "strict").decode("utf-8")
+    except UnicodeError:
+        raise ValueError("Cannot convert identifier to UTF-8: '%s'" % name)
+    return uname
+
+
+def _get_valid_mysql_name(name):
+    """Validate ``name`` as a MySQL identifier and return it in backticks.
+
+    Only characters permitted in unquoted identifiers are accepted, see
+    http://dev.mysql.com/doc/refman/5.0/en/identifiers.html
+
+    Raises ValueError if ``name`` is empty, entirely numeric, or contains
+    a character outside the permitted set.
+    """
+    uname = _get_unicode_name(name)
+    if not len(uname):
+        raise ValueError("Empty table or column name specified")
+
+    # ASCII: digits, basic Latin letters, dollar, underscore
+    basere = r'[0-9a-zA-Z$_]'
+    for c in uname:
+        if not re.match(basere, c):
+            # Extended: U+0080 .. U+FFFF (both ends inclusive)
+            if not (0x80 <= ord(c) <= 0xFFFF):
+                raise ValueError("Invalid MySQL identifier '%s'" % uname)
+    if re.match(r'[0-9]+$', uname):
+        raise ValueError('MySQL identifier cannot be entirely numeric')
+
+    return '`' + uname + '`'
+
+
+def _get_valid_sqlite_name(name):
+    """Validate ``name`` as a SQLite identifier and return it double-quoted.
+
+    See http://stackoverflow.com/questions/6514274/how-do-you-escape-strings-for-sqlite-table-column-names-in-python
+
+    Raises ValueError if ``name`` is empty, cannot be encoded as UTF-8, or
+    contains a NUL character.
+    """
+    uname = _get_unicode_name(name)
+    if not len(uname):
+        raise ValueError("Empty table or column name specified")
+
+    nul_index = uname.find("\x00")
+    if nul_index >= 0:
+        raise ValueError('SQLite identifier cannot contain NULs')
+    # Escape embedded double quotes by doubling them, then wrap in quotes
+    return '"' + uname.replace('"', '""') + '"'
+
+
 # SQL enquote and wildcard symbols
-_SQL_SYMB = {
-    'mysql': {
-        'br_l': '`',
-        'br_r': '`',
-        'wld': '%s'
-    },
-    'sqlite': {
-        'br_l': '[',
-        'br_r': ']',
-        'wld': '?'
-    }
+_SQL_WILDCARD = {
+    'mysql': '%s',
+    'sqlite': '?'
+}
+
+# Validate and return escaped identifier
+_SQL_GET_IDENTIFIER = {
+    'mysql': _get_valid_mysql_name,
+    'sqlite': _get_valid_sqlite_name,
 }
 
 
@@ -1276,18 +1331,17 @@ def _execute_create(self):
     def insert_statement(self):
         names = list(map(str, self.frame.columns))
         flv = self.pd_sql.flavor
-        br_l = _SQL_SYMB[flv]['br_l']  # left val quote char
-        br_r = _SQL_SYMB[flv]['br_r']  # right val quote char
-        wld = _SQL_SYMB[flv]['wld']  # wildcard char
+        wld = _SQL_WILDCARD[flv]  # wildcard char
+        escape = _SQL_GET_IDENTIFIER[flv]
 
         if self.index is not None:
             [names.insert(0, idx) for idx in self.index[::-1]]
 
-        bracketed_names = [br_l + column + br_r for column in names]
+        bracketed_names = [escape(column) for column in names]
         col_names = ','.join(bracketed_names)
         wildcards = ','.join([wld] * len(names))
         insert_statement = 'INSERT INTO %s (%s) VALUES (%s)' % (
-            self.name, col_names, wildcards)
+            escape(self.name), col_names, wildcards)
         return insert_statement
 
     def _execute_insert(self, conn, keys, data_iter):
@@ -1309,29 +1363,28 @@ def _create_table_setup(self):
             warnings.warn(_SAFE_NAMES_WARNING)
 
         flv = self.pd_sql.flavor
+        escape = _SQL_GET_IDENTIFIER[flv]
 
-        br_l = _SQL_SYMB[flv]['br_l']  # left val quote char
-        br_r = _SQL_SYMB[flv]['br_r']  # right val quote char
+        create_tbl_stmts = [escape(cname) + ' ' + ctype
+                            for cname, ctype, _ in column_names_and_types]
 
-        create_tbl_stmts = [(br_l + '%s' + br_r + ' %s') % (cname, col_type)
-                            for cname, col_type, _ in column_names_and_types]
         if self.keys is not None and len(self.keys):
-            cnames_br = ",".join([br_l + c + br_r for c in self.keys])
+            cnames_br = ",".join([escape(c) for c in self.keys])
             create_tbl_stmts.append(
                 "CONSTRAINT {tbl}_pk PRIMARY KEY ({cnames_br})".format(
                 tbl=self.name, cnames_br=cnames_br))
 
-        create_stmts = ["CREATE TABLE " + self.name + " (\n" +
+        create_stmts = ["CREATE TABLE " + escape(self.name) + " (\n" +
                         ',\n '.join(create_tbl_stmts) + "\n)"]
 
         ix_cols = [cname for cname, _, is_index in column_names_and_types
                    if is_index]
         if len(ix_cols):
             cnames = "_".join(ix_cols)
-            cnames_br = ",".join([br_l + c + br_r for c in ix_cols])
+            cnames_br = ",".join([escape(c) for c in ix_cols])
             create_stmts.append(
-                "CREATE INDEX ix_{tbl}_{cnames} ON {tbl} ({cnames_br})".format(
-                tbl=self.name, cnames=cnames, cnames_br=cnames_br))
+                "CREATE INDEX " + escape("ix_"+self.name+"_"+cnames) +
+                " ON " + escape(self.name) + " (" + cnames_br + ")")
 
         return create_stmts
 
@@ -1505,19 +1558,23 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
         table.insert(chunksize)
 
     def has_table(self, name, schema=None):
+        escape = _SQL_GET_IDENTIFIER[self.flavor]
+        escape(name)  # validation only: raises ValueError on an illegal name
+        wld = _SQL_WILDCARD[self.flavor]
         flavor_map = {
             'sqlite': ("SELECT name FROM sqlite_master "
-                       "WHERE type='table' AND name='%s';") % name,
-            'mysql': "SHOW TABLES LIKE '%s'" % name}
+                       "WHERE type='table' AND name=%s;") % wld,
+            'mysql': "SHOW TABLES LIKE %s" % wld}
         query = flavor_map.get(self.flavor)
 
-        return len(self.execute(query).fetchall()) > 0
+        return len(self.execute(query, [name,]).fetchall()) > 0
 
     def get_table(self, table_name, schema=None):
         return None  # not supported in fallback mode
 
     def drop_table(self, name, schema=None):
-        drop_sql = "DROP TABLE %s" % name
+        escape = _SQL_GET_IDENTIFIER[self.flavor]
+        drop_sql = "DROP TABLE %s" % escape(name)
         self.execute(drop_sql)
 
     def _create_sql_schema(self, frame, table_name, keys=None):
diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py
index b185d530e056c..804d925790a6e 100644
--- a/pandas/io/tests/test_sql.py
+++ b/pandas/io/tests/test_sql.py
@@ -865,7 +865,7 @@ def test_uquery(self):
     def _get_sqlite_column_type(self, schema, column):
 
         for col in schema.split('\n'):
-            if col.split()[0].strip('[]') == column:
+            if col.split()[0].strip('""') == column:
                 return col.split()[1]
         raise ValueError('Column %s not found' % (column))
 
@@ -1630,6 +1630,24 @@ def test_notnull_dtype(self):
         self.assertEqual(self._get_sqlite_column_type(tbl, 'Int'), 'INTEGER')
         self.assertEqual(self._get_sqlite_column_type(tbl, 'Float'), 'REAL')
 
+    def test_illegal_names(self):
+        # For sqlite, these should work fine
+        df = DataFrame([[1, 2], [3, 4]], columns=['a', 'b'])
+
+        # Raise error on blank
+        self.assertRaises(ValueError, df.to_sql, "", self.conn,
+            flavor=self.flavor)
+
+        for ndx, weird_name in enumerate(['test_weird_name]','test_weird_name[',
+            'test_weird_name`','test_weird_name"', 'test_weird_name\'']):
+            df.to_sql(weird_name, self.conn, flavor=self.flavor)
+            sql.table_exists(weird_name, self.conn)
+
+            df2 = DataFrame([[1, 2], [3, 4]], columns=['a', weird_name])
+            c_tbl = 'test_weird_col_name%d'%ndx
+            df2.to_sql(c_tbl, self.conn, flavor=self.flavor)
+            sql.table_exists(c_tbl, self.conn)
+
 
 class TestMySQLLegacy(TestSQLiteFallback):
     """
@@ -1721,6 +1739,19 @@ def test_to_sql_save_index(self):
     def test_to_sql_save_index(self):
         self._to_sql_save_index()
 
+    def test_illegal_names(self):
+        # For MySQL, these should raise ValueError
+        for ndx, illegal_name in enumerate(['test_illegal_name]','test_illegal_name[',
+            'test_illegal_name`','test_illegal_name"', 'test_illegal_name\'', '']):
+            df = DataFrame([[1, 2], [3, 4]], columns=['a', 'b'])
+            self.assertRaises(ValueError, df.to_sql, illegal_name, self.conn,
+                flavor=self.flavor, index=False)
+
+            df2 = DataFrame([[1, 2], [3, 4]], columns=['a', illegal_name])
+            c_tbl = 'test_illegal_col_name%d'%ndx
+            self.assertRaises(ValueError, df2.to_sql, c_tbl,
+                self.conn, flavor=self.flavor, index=False)
+
 
 #------------------------------------------------------------------------------
 #--- Old tests from 0.13.1 (before refactor using sqlalchemy)
@@ -1817,7 +1848,7 @@ def test_schema(self):
         frame = tm.makeTimeDataFrame()
         create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],)
         lines = create_sql.splitlines()
-        self.assertTrue('PRIMARY KEY ([A],[B])' in create_sql)
+        self.assertTrue('PRIMARY KEY ("A","B")' in create_sql)
         cur = self.db.cursor()
         cur.execute(create_sql)
 