Skip to content

Commit bce1598

Browse files
author
Jesse Whitehouse
committed
Add unit tests for inject_parameters
Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent 49d5b85 commit bce1598

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

tests/unit/test_param_escaper.py

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from datetime import date, datetime
2+
import unittest, pytest
3+
4+
from databricks.sql.utils import ParamEscaper, inject_parameters
5+
6+
pe = ParamEscaper()
7+
8+
class TestIndividualFormatters(object):
9+
10+
# Test individual type escapers
11+
def test_escape_number_integer(self):
12+
"""This behaviour falls back to Python's default string formatting of numbers
13+
"""
14+
assert pe.escape_number(100) == 100
15+
16+
def test_escape_number_float(self):
17+
"""This behaviour falls back to Python's default string formatting of numbers
18+
"""
19+
assert pe.escape_number(100.1234) == 100.1234
20+
21+
def test_escape_string_normal(self):
22+
"""
23+
"""
24+
25+
assert pe.escape_string("golly bob howdy") == "'golly bob howdy'"
26+
27+
def test_escape_string_that_includes_quotes(self):
28+
# Databricks queries support just one special character: a single quote mark
29+
# These are escaped by doubling:
30+
# e.g. INPUT: his name was 'robert palmer'
31+
# e.g. OUTPUT: 'his name was ''robert palmer'''
32+
33+
assert pe.escape_string("his name was 'robert palmer'") == "'his name was ''robert palmer'''"
34+
35+
def test_escape_date_time(self):
36+
INPUT = datetime(1991,8,3,21,55)
37+
OUTPUT = "1991-08-03 21:55:00"
38+
assert pe.escape_datetime(INPUT, OUTPUT)
39+
40+
def test_escape_date(self):
41+
INPUT = date(1991,8,3)
42+
OUTPUT = "1991-08-03"
43+
assert pe.escape_datetime(INPUT, OUTPUT)
44+
45+
def test_escape_sequence_integer(self):
46+
assert pe.escape_sequence([1,2,3,4]) == "(1,2,3,4)"
47+
48+
def test_escape_sequence_float(self):
49+
assert pe.escape_sequence([1.1,2.2,3.3,4.4]) == "(1.1,2.2,3.3,4.4)"
50+
51+
def test_escape_sequence_string(self):
52+
assert pe.escape_sequence(
53+
["his", "name", "was", "robert", "palmer"]) == \
54+
"('his','name','was','robert','palmer')"
55+
56+
def test_escape_sequence_sequence_of_strings(self):
57+
# This is not valid SQL.
58+
INPUT = [["his", "name"], ["was", "robert"], ["palmer"]]
59+
OUTPUT = "(('his','name'),('was','robert'),('palmer'))"
60+
61+
assert pe.escape_sequence(INPUT) == OUTPUT
62+
63+
64+
class TestFullQueryEscaping(object):
65+
66+
def test_simple(self):
67+
68+
INPUT = """
69+
SELECT
70+
field1,
71+
field2,
72+
field3
73+
FROM
74+
table
75+
WHERE
76+
field1 = %(param1)s
77+
"""
78+
79+
OUTPUT = """
80+
SELECT
81+
field1,
82+
field2,
83+
field3
84+
FROM
85+
table
86+
WHERE
87+
field1 = ';DROP ALL TABLES'
88+
"""
89+
90+
args = {"param1": ";DROP ALL TABLES"}
91+
92+
assert inject_parameters(INPUT, pe.escape_args(args)) == OUTPUT
93+
94+
@unittest.skipUnless(False, "Thrift server supports native parameter binding.")
95+
def test_only_bind_in_where_clause(self):
96+
97+
INPUT = """
98+
SELECT
99+
%(field)s,
100+
field2,
101+
field3
102+
FROM table
103+
"""
104+
105+
args = {"field": "Some Value"}
106+
107+
with pytest.raises(Exception):
108+
inject_parameters(INPUT, pe.escape_args(args))

0 commit comments

Comments
 (0)