Skip to content

Commit d2d9015

Browse files
author
Jesse
authored
Add tests for parameter sanitisation / escaping (#46)
* Refactor so we can unit test `inject_parameters`
* Add unit tests for `inject_parameters`
* Remove inaccurate comment. Per #51, Spark SQL does not support escaping a single quote with a second single quote.
* Closes #51 and adds unit tests plus the integration test provided in #56

Signed-off-by: Jesse Whitehouse <[email protected]>
Co-authored-by: Courtney Holcomb (@courtneyholcomb)
Co-authored-by: @mcannamela
1 parent f99123c commit d2d9015

File tree

4 files changed

+174
-4
lines changed

4 files changed

+174
-4
lines changed

src/databricks/sql/client.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from databricks.sql import *
88
from databricks.sql.exc import OperationalError
99
from databricks.sql.thrift_backend import ThriftBackend
10-
from databricks.sql.utils import ExecuteResponse, ParamEscaper
10+
from databricks.sql.utils import ExecuteResponse, ParamEscaper, inject_parameters
1111
from databricks.sql.types import Row
1212
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
1313
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
@@ -310,7 +310,9 @@ def execute(
310310
:returns self
311311
"""
312312
if parameters is not None:
313-
operation = operation % self.escaper.escape_args(parameters)
313+
operation = inject_parameters(
314+
operation, self.escaper.escape_args(parameters)
315+
)
314316

315317
self._check_not_closed()
316318
self._close_and_clear_active_result_set()

src/databricks/sql/utils.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections.abc import Iterable
33
import datetime
44
from enum import Enum
5-
5+
from typing import Dict
66
import pyarrow
77

88

@@ -146,7 +146,7 @@ def escape_string(self, item):
146146
# Only two characters are escaped here: every backslash is doubled first, and then every
147147
# single quote is prefixed with a backslash. Per #51, Spark SQL does not accept a doubled
148148
# single quote ('') as an escaped quote, so the backslash form (\') is used instead.
149-
return "'{}'".format(item.replace("'", "''"))
149+
return "'{}'".format(item.replace("\\", "\\\\").replace("'", "\\'"))
150150

151151
def escape_sequence(self, item):
152152
l = map(str, map(self.escape_item, item))
@@ -172,3 +172,7 @@ def escape_item(self, item):
172172
return self.escape_datetime(item, self._DATE_FORMAT)
173173
else:
174174
raise exc.ProgrammingError("Unsupported object {}".format(item))
175+
176+
177+
def inject_parameters(operation: str, parameters: Dict[str, str]):
178+
return operation % parameters

tests/e2e/driver_tests.py

+14
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,20 @@ def test_get_columns(self):
288288
for table in table_names:
289289
cursor.execute('DROP TABLE IF EXISTS {}'.format(table))
290290

291+
def test_escape_single_quotes(self):
292+
with self.cursor({}) as cursor:
293+
table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))
294+
# Test escape syntax directly
295+
cursor.execute("CREATE TABLE IF NOT EXISTS {} AS (SELECT 'you\\'re' AS col_1)".format(table_name))
296+
cursor.execute("SELECT * FROM {} WHERE col_1 LIKE 'you\\'re'".format(table_name))
297+
rows = cursor.fetchall()
298+
assert rows[0]["col_1"] == "you're"
299+
300+
# Test escape syntax in parameter
301+
cursor.execute("SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format(table_name, table_name), parameters={"var": "you're"})
302+
rows = cursor.fetchall()
303+
assert rows[0]["col_1"] == "you're"
304+
291305
def test_get_schemas(self):
292306
with self.cursor({}) as cursor:
293307
database_name = 'db_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))

tests/unit/test_param_escaper.py

+150
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from datetime import date, datetime
2+
import unittest, pytest
3+
4+
from databricks.sql.utils import ParamEscaper, inject_parameters
5+
6+
pe = ParamEscaper()
7+
8+
class TestIndividualFormatters(object):
9+
10+
# Test individual type escapers
11+
def test_escape_number_integer(self):
12+
"""This behaviour falls back to Python's default string formatting of numbers
13+
"""
14+
assert pe.escape_number(100) == 100
15+
16+
def test_escape_number_float(self):
17+
"""This behaviour falls back to Python's default string formatting of numbers
18+
"""
19+
assert pe.escape_number(100.1234) == 100.1234
20+
21+
def test_escape_string_normal(self):
22+
"""
23+
"""
24+
25+
assert pe.escape_string("golly bob howdy") == "'golly bob howdy'"
26+
27+
def test_escape_string_that_includes_special_characters(self):
28+
"""Tests for how special characters are treated.
29+
30+
When passed a string, the `escape_string` method wraps it in single quotes
31+
and escapes any special characters with a backslash (\)
32+
33+
Example:
34+
35+
IN : his name was 'robert palmer'
36+
OUT: 'his name was \'robert palmer\''
37+
"""
38+
39+
# Testing for the presence of these characters: '"/\😂
40+
41+
assert pe.escape_string("his name was 'robert palmer'") == r"'his name was \'robert palmer\''"
42+
43+
# These tests represent the same user input in the several ways it can be written in Python
44+
# Each argument to `escape_string` evaluates to the same bytes. But Python lets us write it differently.
45+
assert pe.escape_string("his name was \"robert palmer\"") == "'his name was \"robert palmer\"'"
46+
assert pe.escape_string('his name was "robert palmer"') == "'his name was \"robert palmer\"'"
47+
assert pe.escape_string('his name was {}'.format('"robert palmer"')) == "'his name was \"robert palmer\"'"
48+
49+
assert pe.escape_string("his name was robert / palmer") == r"'his name was robert / palmer'"
50+
51+
# If you need to include a single backslash, use an r-string to prevent Python from raising a
52+
# DeprecationWarning for an invalid escape sequence
53+
assert pe.escape_string("his name was robert \\/ palmer") == r"'his name was robert \\/ palmer'"
54+
assert pe.escape_string("his name was robert \\ palmer") == r"'his name was robert \\ palmer'"
55+
assert pe.escape_string("his name was robert \\\\ palmer") == r"'his name was robert \\\\ palmer'"
56+
57+
assert pe.escape_string("his name was robert palmer 😂") == r"'his name was robert palmer 😂'"
58+
59+
# Adding the test from PR #56 to prove escape behaviour
60+
61+
assert pe.escape_string("you're") == r"'you\'re'"
62+
63+
# Adding this test from #51 to prove escape behaviour when the target string involves repeated SQL escape chars
64+
assert pe.escape_string("cat\\'s meow") == r"'cat\\\'s meow'"
65+
66+
# Tests from the docs: https://docs.databricks.com/sql/language-manual/data-types/string-type.html
67+
68+
assert pe.escape_string('Spark') == "'Spark'"
69+
assert pe.escape_string("O'Connell") == r"'O\'Connell'"
70+
assert pe.escape_string("Some\\nText") == r"'Some\\nText'"
71+
assert pe.escape_string("Some\\\\nText") == r"'Some\\\\nText'"
72+
assert pe.escape_string("서울시") == "'서울시'"
73+
assert pe.escape_string("\\\\") == r"'\\\\'"
74+
75+
def test_escape_date_time(self):
76+
INPUT = datetime(1991,8,3,21,55)
77+
FORMAT = "%Y-%m-%d %H:%M:%S"
78+
OUTPUT = "'1991-08-03 21:55:00'"
79+
assert pe.escape_datetime(INPUT, FORMAT) == OUTPUT
80+
81+
def test_escape_date(self):
82+
INPUT = date(1991,8,3)
83+
FORMAT = "%Y-%m-%d"
84+
OUTPUT = "'1991-08-03'"
85+
assert pe.escape_datetime(INPUT, FORMAT) == OUTPUT
86+
87+
def test_escape_sequence_integer(self):
88+
assert pe.escape_sequence([1,2,3,4]) == "(1,2,3,4)"
89+
90+
def test_escape_sequence_float(self):
91+
assert pe.escape_sequence([1.1,2.2,3.3,4.4]) == "(1.1,2.2,3.3,4.4)"
92+
93+
def test_escape_sequence_string(self):
94+
assert pe.escape_sequence(
95+
["his", "name", "was", "robert", "palmer"]) == \
96+
"('his','name','was','robert','palmer')"
97+
98+
def test_escape_sequence_sequence_of_strings(self):
99+
# This is not valid SQL.
100+
INPUT = [["his", "name"], ["was", "robert"], ["palmer"]]
101+
OUTPUT = "(('his','name'),('was','robert'),('palmer'))"
102+
103+
assert pe.escape_sequence(INPUT) == OUTPUT
104+
105+
106+
class TestFullQueryEscaping(object):
107+
108+
def test_simple(self):
109+
110+
INPUT = """
111+
SELECT
112+
field1,
113+
field2,
114+
field3
115+
FROM
116+
table
117+
WHERE
118+
field1 = %(param1)s
119+
"""
120+
121+
OUTPUT = """
122+
SELECT
123+
field1,
124+
field2,
125+
field3
126+
FROM
127+
table
128+
WHERE
129+
field1 = ';DROP ALL TABLES'
130+
"""
131+
132+
args = {"param1": ";DROP ALL TABLES"}
133+
134+
assert inject_parameters(INPUT, pe.escape_args(args)) == OUTPUT
135+
136+
@unittest.skipUnless(False, "Thrift server supports native parameter binding.")
137+
def test_only_bind_in_where_clause(self):
138+
139+
INPUT = """
140+
SELECT
141+
%(field)s,
142+
field2,
143+
field3
144+
FROM table
145+
"""
146+
147+
args = {"field": "Some Value"}
148+
149+
with pytest.raises(Exception):
150+
inject_parameters(INPUT, pe.escape_args(args))

0 commit comments

Comments
 (0)