diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index e3190d45..b03ff6c2 100644
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -7,7 +7,7 @@
 from databricks.sql import *
 from databricks.sql.exc import OperationalError
 from databricks.sql.thrift_backend import ThriftBackend
-from databricks.sql.utils import ExecuteResponse, ParamEscaper
+from databricks.sql.utils import ExecuteResponse, ParamEscaper, inject_parameters
 from databricks.sql.types import Row
 from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
 from databricks.sql.experimental.oauth_persistence import OAuthPersistence
@@ -309,7 +309,9 @@ def execute(
         :returns self
         """
         if parameters is not None:
-            operation = operation % self.escaper.escape_args(parameters)
+            operation = inject_parameters(
+                operation, self.escaper.escape_args(parameters)
+            )
 
         self._check_not_closed()
         self._close_and_clear_active_result_set()
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index 2961a1f5..1ffac8b1 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -2,7 +2,7 @@
 from collections.abc import Iterable
 import datetime
 from enum import Enum
-
+from typing import Any, Dict
 
 import pyarrow
 
@@ -146,7 +146,8 @@ def escape_string(self, item):
-        # This is good enough when backslashes are literal, newlines are just followed, and the way
-        # to escape a single quote is to put two single quotes.
-        # (i.e. only special character is single quote)
-        return "'{}'".format(item.replace("'", "''"))
+        # Backslash is treated as the escape character inside the SQL string literal, so
+        # double every backslash first and only then backslash-escape single quotes.
+        # The order matters: escaping quotes first would cause the backslash added for
+        # each quote to be doubled by the second pass.
+        return "'{}'".format(item.replace("\\", "\\\\").replace("'", "\\'"))
 
     def escape_sequence(self, item):
         l = map(str, map(self.escape_item, item))
@@ -172,3 +173,14 @@ def escape_item(self, item):
             return self.escape_datetime(item, self._DATE_FORMAT)
         else:
             raise exc.ProgrammingError("Unsupported object {}".format(item))
+
+
+def inject_parameters(operation: str, parameters: Dict[str, Any]) -> str:
+    """Interpolate already-escaped parameters into a SQL operation.
+
+    :param operation: SQL text containing pyformat markers such as ``%(name)s``.
+    :param parameters: mapping of marker name to its escaped value, as returned
+        by ``ParamEscaper.escape_args``.
+    :returns: the operation with every marker replaced by its parameter value.
+    """
+    return operation % parameters
diff --git a/tests/e2e/driver_tests.py b/tests/e2e/driver_tests.py
index 9e400770..43973724 100644
--- a/tests/e2e/driver_tests.py
+++ b/tests/e2e/driver_tests.py
@@ -288,6 +288,24 @@ def test_get_columns(self):
                 for table in table_names:
                     cursor.execute('DROP TABLE IF EXISTS {}'.format(table))
 
+    def test_escape_single_quotes(self):
+        with self.cursor({}) as cursor:
+            table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))
+            try:
+                # Test escape syntax directly
+                cursor.execute("CREATE TABLE IF NOT EXISTS {} AS (SELECT 'you\\'re' AS col_1)".format(table_name))
+                cursor.execute("SELECT * FROM {} WHERE col_1 LIKE 'you\\'re'".format(table_name))
+                rows = cursor.fetchall()
+                assert rows[0]["col_1"] == "you're"
+
+                # Test escape syntax in parameter
+                cursor.execute("SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format(table_name, table_name), parameters={"var": "you're"})
+                rows = cursor.fetchall()
+                assert rows[0]["col_1"] == "you're"
+            finally:
+                # Always drop the scratch table so repeated runs do not leak tables
+                cursor.execute('DROP TABLE IF EXISTS {}'.format(table_name))
+
     def test_get_schemas(self):
         with self.cursor({}) as cursor:
             database_name = 'db_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))
diff --git a/tests/unit/test_param_escaper.py b/tests/unit/test_param_escaper.py
new file mode 100644
index 00000000..cb5758aa
--- /dev/null
+++ b/tests/unit/test_param_escaper.py
@@ -0,0 +1,151 @@
+from datetime import date, datetime
+import unittest
+
+import pytest
+
+from databricks.sql.utils import ParamEscaper, inject_parameters
+
+pe = ParamEscaper()
+
+class TestIndividualFormatters(object):
+
+    # Test individual type escapers
+    def test_escape_number_integer(self):
+        """This behaviour falls back to Python's default string formatting of numbers
+        """
+        assert pe.escape_number(100) == 100
+
+    def test_escape_number_float(self):
+        """This behaviour falls back to Python's default string formatting of numbers
+        """
+        assert pe.escape_number(100.1234) == 100.1234
+
+    def test_escape_string_normal(self):
+        """A string without special characters is only wrapped in single quotes."""
+
+        assert pe.escape_string("golly bob howdy") == "'golly bob howdy'"
+
+    def test_escape_string_that_includes_special_characters(self):
+        r"""Tests for how special characters are treated.
+
+        When passed a string, the `escape_string` method wraps it in single quotes
+        and escapes any special characters with a backslash (\)
+
+        Example:
+
+        IN : his name was 'robert palmer'
+        OUT: 'his name was \'robert palmer\''
+        """
+
+        # Testing for the presence of these characters: '"/\😂
+
+        assert pe.escape_string("his name was 'robert palmer'") == r"'his name was \'robert palmer\''"
+
+        # These tests represent the same user input in the several ways it can be written in Python
+        # Each argument to `escape_string` evaluates to the same bytes. But Python lets us write it differently.
+        assert pe.escape_string("his name was \"robert palmer\"") == "'his name was \"robert palmer\"'"
+        assert pe.escape_string('his name was "robert palmer"') == "'his name was \"robert palmer\"'"
+        assert pe.escape_string('his name was {}'.format('"robert palmer"')) == "'his name was \"robert palmer\"'"
+
+        assert pe.escape_string("his name was robert / palmer") == r"'his name was robert / palmer'"
+
+        # If you need to include a single backslash, use an r-string to prevent Python from raising a
+        # DeprecationWarning for an invalid escape sequence
+        assert pe.escape_string("his name was robert \\/ palmer") == r"'his name was robert \\/ palmer'"
+        assert pe.escape_string("his name was robert \\ palmer") == r"'his name was robert \\ palmer'"
+        assert pe.escape_string("his name was robert \\\\ palmer") == r"'his name was robert \\\\ palmer'"
+
+        assert pe.escape_string("his name was robert palmer 😂") == r"'his name was robert palmer 😂'"
+
+        # Adding the test from PR #56 to prove escape behaviour
+
+        assert pe.escape_string("you're") == r"'you\'re'"
+
+        # Adding this test from #51 to prove escape behaviour when the target string involves repeated SQL escape chars
+        assert pe.escape_string("cat\\'s meow") == r"'cat\\\'s meow'"
+
+        # Tests from the docs: https://docs.databricks.com/sql/language-manual/data-types/string-type.html
+
+        assert pe.escape_string('Spark') == "'Spark'"
+        assert pe.escape_string("O'Connell") == r"'O\'Connell'"
+        assert pe.escape_string("Some\\nText") == r"'Some\\nText'"
+        assert pe.escape_string("Some\\\\nText") == r"'Some\\\\nText'"
+        assert pe.escape_string("서울시") == "'서울시'"
+        assert pe.escape_string("\\\\") == r"'\\\\'"
+
+    def test_escape_date_time(self):
+        INPUT = datetime(1991,8,3,21,55)
+        FORMAT = "%Y-%m-%d %H:%M:%S"
+        OUTPUT = "'1991-08-03 21:55:00'"
+        assert pe.escape_datetime(INPUT, FORMAT) == OUTPUT
+
+    def test_escape_date(self):
+        INPUT = date(1991,8,3)
+        FORMAT = "%Y-%m-%d"
+        OUTPUT = "'1991-08-03'"
+        assert pe.escape_datetime(INPUT, FORMAT) == OUTPUT
+
+    def test_escape_sequence_integer(self):
+        assert pe.escape_sequence([1,2,3,4]) == "(1,2,3,4)"
+
+    def test_escape_sequence_float(self):
+        assert pe.escape_sequence([1.1,2.2,3.3,4.4]) == "(1.1,2.2,3.3,4.4)"
+
+    def test_escape_sequence_string(self):
+        assert pe.escape_sequence(
+            ["his", "name", "was", "robert", "palmer"]) == \
+            "('his','name','was','robert','palmer')"
+
+    def test_escape_sequence_sequence_of_strings(self):
+        # This is not valid SQL.
+        INPUT = [["his", "name"], ["was", "robert"], ["palmer"]]
+        OUTPUT = "(('his','name'),('was','robert'),('palmer'))"
+
+        assert pe.escape_sequence(INPUT) == OUTPUT
+
+
+class TestFullQueryEscaping(object):
+
+    def test_simple(self):
+
+        INPUT = """
+        SELECT
+          field1,
+          field2,
+          field3
+        FROM
+          table
+        WHERE
+          field1 = %(param1)s
+        """
+
+        OUTPUT = """
+        SELECT
+          field1,
+          field2,
+          field3
+        FROM
+          table
+        WHERE
+          field1 = ';DROP ALL TABLES'
+        """
+
+        args = {"param1": ";DROP ALL TABLES"}
+
+        assert inject_parameters(INPUT, pe.escape_args(args)) == OUTPUT
+
+    @unittest.skipUnless(False, "Thrift server supports native parameter binding.")
+    def test_only_bind_in_where_clause(self):
+
+        INPUT = """
+        SELECT
+          %(field)s,
+          field2,
+          field3
+        FROM table
+        """
+
+        args = {"field": "Some Value"}
+
+        with pytest.raises(Exception):
+            inject_parameters(INPUT, pe.escape_args(args))