Skip to content

Test parameter escaping #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 14, 2022
6 changes: 4 additions & 2 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from databricks.sql import *
from databricks.sql.exc import OperationalError
from databricks.sql.thrift_backend import ThriftBackend
from databricks.sql.utils import ExecuteResponse, ParamEscaper
from databricks.sql.utils import ExecuteResponse, ParamEscaper, inject_parameters
from databricks.sql.types import Row
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
Expand Down Expand Up @@ -309,7 +309,9 @@ def execute(
:returns self
"""
if parameters is not None:
operation = operation % self.escaper.escape_args(parameters)
operation = inject_parameters(
operation, self.escaper.escape_args(parameters)
)

self._check_not_closed()
self._close_and_clear_active_result_set()
Expand Down
8 changes: 6 additions & 2 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Iterable
import datetime
from enum import Enum

from typing import Dict
import pyarrow


Expand Down Expand Up @@ -146,7 +146,7 @@ def escape_string(self, item):
# This is good enough when backslashes are literal, newlines are just followed, and the way
# to escape a single quote is to put two single quotes.
# (i.e. only special character is single quote)
return "'{}'".format(item.replace("'", "''"))
return "'{}'".format(item.replace("\\", "\\\\").replace("'", "\\'"))

def escape_sequence(self, item):
l = map(str, map(self.escape_item, item))
Expand All @@ -172,3 +172,7 @@ def escape_item(self, item):
return self.escape_datetime(item, self._DATE_FORMAT)
else:
raise exc.ProgrammingError("Unsupported object {}".format(item))


def inject_parameters(operation: str, parameters: Dict[str, str]) -> str:
    """Substitute parameters into a SQL operation string.

    The substitution is plain Python ``%``-formatting with a mapping, so
    ``operation`` is expected to use ``%(name)s`` style placeholders.  The
    values are inserted verbatim: no escaping or quoting happens here, so
    callers must pass values that are already safe to embed in SQL.

    :param operation: SQL text containing ``%(name)s`` placeholders
    :param parameters: mapping of placeholder name to its replacement text
    :returns: ``operation`` with every placeholder replaced
    :raises KeyError: if a placeholder names a key missing from ``parameters``
    """
    return operation % parameters
14 changes: 14 additions & 0 deletions tests/e2e/driver_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,20 @@ def test_get_columns(self):
for table in table_names:
cursor.execute('DROP TABLE IF EXISTS {}'.format(table))

def test_escape_single_quotes(self):
    """Check that a single quote round-trips through the server.

    Covers both a backslash-escaped quote written directly in the SQL text
    and a quote supplied through the ``parameters`` argument of ``execute``.
    """
    with self.cursor({}) as cursor:
        table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))
        try:
            # Test escape syntax directly
            cursor.execute("CREATE TABLE IF NOT EXISTS {} AS (SELECT 'you\\'re' AS col_1)".format(table_name))
            cursor.execute("SELECT * FROM {} WHERE col_1 LIKE 'you\\'re'".format(table_name))
            rows = cursor.fetchall()
            assert rows[0]["col_1"] == "you're"

            # Test escape syntax in parameter
            cursor.execute("SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format(table_name, table_name), parameters={"var": "you're"})
            rows = cursor.fetchall()
            assert rows[0]["col_1"] == "you're"
        finally:
            # Drop the uniquely-named scratch table even when an assertion
            # fails, so repeated runs do not leave tables behind.
            cursor.execute('DROP TABLE IF EXISTS {}'.format(table_name))

def test_get_schemas(self):
with self.cursor({}) as cursor:
database_name = 'db_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))
Expand Down
150 changes: 150 additions & 0 deletions tests/unit/test_param_escaper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from datetime import date, datetime
import unittest, pytest

from databricks.sql.utils import ParamEscaper, inject_parameters

pe = ParamEscaper()

class TestIndividualFormatters(object):

# Test individual type escapers
def test_escape_number_integer(self):
"""This behaviour falls back to Python's default string formatting of numbers
"""
assert pe.escape_number(100) == 100

def test_escape_number_float(self):
"""This behaviour falls back to Python's default string formatting of numbers
"""
assert pe.escape_number(100.1234) == 100.1234

def test_escape_string_normal(self):
"""
"""

assert pe.escape_string("golly bob howdy") == "'golly bob howdy'"

def test_escape_string_that_includes_special_characters(self):
"""Tests for how special characters are treated.
When passed a string, the `escape_string` method wraps it in single quotes
and escapes any special characters with a back stroke (\)
Example:
IN : his name was 'robert palmer'
OUT: 'his name was \'robert palmer\''
"""

# Testing for the presence of these characters: '"/\😂

assert pe.escape_string("his name was 'robert palmer'") == r"'his name was \'robert palmer\''"

# These tests represent the same user input in the several ways it can be written in Python
# Each argument to `escape_string` evaluates to the same bytes. But Python lets us write it differently.
assert pe.escape_string("his name was \"robert palmer\"") == "'his name was \"robert palmer\"'"
assert pe.escape_string('his name was "robert palmer"') == "'his name was \"robert palmer\"'"
assert pe.escape_string('his name was {}'.format('"robert palmer"')) == "'his name was \"robert palmer\"'"

assert pe.escape_string("his name was robert / palmer") == r"'his name was robert / palmer'"

# If you need to include a single backslash, use an r-string to prevent Python from raising a
# DeprecationWarning for an invalid escape sequence
assert pe.escape_string("his name was robert \\/ palmer") == r"'his name was robert \\/ palmer'"
assert pe.escape_string("his name was robert \\ palmer") == r"'his name was robert \\ palmer'"
assert pe.escape_string("his name was robert \\\\ palmer") == r"'his name was robert \\\\ palmer'"

assert pe.escape_string("his name was robert palmer 😂") == r"'his name was robert palmer 😂'"

# Adding the test from PR #56 to prove escape behaviour

assert pe.escape_string("you're") == r"'you\'re'"

# Adding this test from #51 to prove escape behaviour when the target string involves repeated SQL escape chars
assert pe.escape_string("cat\\'s meow") == r"'cat\\\'s meow'"

# Tests from the docs: https://docs.databricks.com/sql/language-manual/data-types/string-type.html

assert pe.escape_string('Spark') == "'Spark'"
assert pe.escape_string("O'Connell") == r"'O\'Connell'"
assert pe.escape_string("Some\\nText") == r"'Some\\nText'"
assert pe.escape_string("Some\\\\nText") == r"'Some\\\\nText'"
assert pe.escape_string("서울시") == "'서울시'"
assert pe.escape_string("\\\\") == r"'\\\\'"

def test_escape_date_time(self):
INPUT = datetime(1991,8,3,21,55)
FORMAT = "%Y-%m-%d %H:%M:%S"
OUTPUT = "'1991-08-03 21:55:00'"
assert pe.escape_datetime(INPUT, FORMAT) == OUTPUT

def test_escape_date(self):
INPUT = date(1991,8,3)
FORMAT = "%Y-%m-%d"
OUTPUT = "'1991-08-03'"
assert pe.escape_datetime(INPUT, FORMAT) == OUTPUT

def test_escape_sequence_integer(self):
assert pe.escape_sequence([1,2,3,4]) == "(1,2,3,4)"

def test_escape_sequence_float(self):
assert pe.escape_sequence([1.1,2.2,3.3,4.4]) == "(1.1,2.2,3.3,4.4)"

def test_escape_sequence_string(self):
assert pe.escape_sequence(
["his", "name", "was", "robert", "palmer"]) == \
"('his','name','was','robert','palmer')"

def test_escape_sequence_sequence_of_strings(self):
# This is not valid SQL.
INPUT = [["his", "name"], ["was", "robert"], ["palmer"]]
OUTPUT = "(('his','name'),('was','robert'),('palmer'))"

assert pe.escape_sequence(INPUT) == OUTPUT


class TestFullQueryEscaping:
    """Tests that escape parameters and inject them into a complete query.

    Each test combines ``pe.escape_args`` with ``inject_parameters``.
    """

    def test_simple(self):
        """A string parameter is quoted and substituted for its placeholder."""

        template = """
SELECT
field1,
field2,
field3
FROM
table
WHERE
field1 = %(param1)s
"""

        expected = """
SELECT
field1,
field2,
field3
FROM
table
WHERE
field1 = ';DROP ALL TABLES'
"""

        params = {"param1": ";DROP ALL TABLES"}
        rendered = inject_parameters(template, pe.escape_args(params))

        assert rendered == expected

    # NOTE: the skip condition is hard-coded to False, so this test is always skipped.
    @unittest.skipUnless(False, "Thrift server supports native parameter binding.")
    def test_only_bind_in_where_clause(self):
        """Expect an error when a placeholder appears in the SELECT list."""

        template = """
SELECT
%(field)s,
field2,
field3
FROM table
"""

        params = {"field": "Some Value"}

        # Escaping stays inside the raises-block so an error from either step counts.
        with pytest.raises(Exception):
            inject_parameters(template, pe.escape_args(params))