Skip to content

Commit faf9457

Browse files
committed
feat: support query params
1 parent 8c324dc commit faf9457

File tree

4 files changed

+124
-1
lines changed

4 files changed

+124
-1
lines changed

prestodb/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@
1818
from . import client
1919
from . import constants
2020
from . import exceptions
21+
from . import escaper
2122

2223
__version__ = "0.8.3"

prestodb/dbapi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from prestodb import constants
2929
import prestodb.exceptions
30+
import prestodb.escaper
3031
import prestodb.client
3132
import prestodb.redirect
3233
from prestodb.transaction import Transaction, IsolationLevel, NO_TRANSACTION
@@ -232,7 +233,8 @@ def setoutputsize(self, size, column):
232233
raise prestodb.exceptions.NotSupportedError
233234

234235
def execute(self, operation, params=None):
235-
self._query = prestodb.client.PrestoQuery(self._request, sql=operation)
236+
sql = operation if params is None else operation % prestodb.escaper.escape(params)
237+
self._query = prestodb.client.PrestoQuery(self._request, sql=sql)
236238
result = self._query.execute()
237239
self._iterator = iter(result)
238240
return result

prestodb/escaper.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
#
13+
# This code is forked from https://github.com/dropbox/PyHive (the Apache License, Version 2.0)
14+
from __future__ import absolute_import
15+
16+
import datetime
17+
18+
try:
19+
from collections.abc import Iterable
20+
except ImportError:
21+
from collections import Iterable
22+
23+
import prestodb.exceptions
24+
25+
class ParamsEscaper(object):
26+
_DATE_FORMAT = "%Y-%m-%d"
27+
_TIME_FORMAT = "%H:%M:%S.%f"
28+
_DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT)
29+
30+
def escape_args(self, parameters):
31+
if isinstance(parameters, dict):
32+
return {k: self.escape_item(v) for k, v in parameters.items()}
33+
34+
if isinstance(parameters, (list, tuple)):
35+
return tuple(self.escape_item(x) for x in parameters)
36+
37+
raise prestodb.exceptions.ProgrammingError("Unsupported param format: {}".format(parameters))
38+
39+
def escape_number(self, item):
40+
return item
41+
42+
def escape_bytes(self, item):
43+
return self.escape_string(item.decode("utf-8"))
44+
45+
def escape_string(self, item):
46+
# This is good enough when backslashes are literal, newlines are just followed, and the way
47+
# to escape a single quote is to put two single quotes.
48+
# (i.e. only special character is single quote)
49+
return "'{}'".format(item.replace("'", "''"))
50+
51+
def escape_sequence(self, item):
52+
l = map(str, map(self.escape_item, item))
53+
return '(' + ','.join(l) + ')'
54+
55+
def escape_datetime(self, item, format, cutoff=0):
56+
dt_str = item.strftime(format)
57+
formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str
58+
59+
_type = "timestamp" if isinstance(item, datetime.datetime) else "date"
60+
return "{} {}".format(_type, formatted)
61+
62+
def escape_item(self, item):
63+
if item is None:
64+
return 'NULL'
65+
66+
if isinstance(item, (int, float)):
67+
return self.escape_number(item)
68+
69+
if isinstance(item, bytes):
70+
return self.escape_bytes(item)
71+
72+
if isinstance(item, str):
73+
return self.escape_string(item)
74+
75+
if isinstance(item, Iterable):
76+
return self.escape_sequence(item)
77+
78+
if isinstance(item, datetime.datetime):
79+
return self.escape_datetime(item, self._DATETIME_FORMAT)
80+
81+
if isinstance(item, datetime.date):
82+
return self.escape_datetime(item, self._DATE_FORMAT)
83+
84+
raise prestodb.exceptions.ProgrammingError("Unsupported object {}".format(item))
85+
86+
escaper = ParamsEscaper()
87+
88+
def escape(params):
89+
return escaper.escape_args(params)

tests/test_escaper.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
from __future__ import absolute_import
13+
import datetime
14+
import prestodb.escaper
15+
16+
def test_escape_args():
17+
escaper = prestodb.escaper.ParamsEscaper()
18+
19+
assert escaper.escape_args({'foo': 'bar'}) == {'foo': "'bar'"}
20+
assert escaper.escape_args({'foo': 123}) == {'foo': 123}
21+
assert escaper.escape_args({'foo': 123.456}) == {'foo': 123.456}
22+
assert escaper.escape_args({'foo': ['a', 'b', 'c']}) == {'foo': "('a','b','c')"}
23+
assert escaper.escape_args({'foo': ('a', 'b', 'c')}) == {'foo': "('a','b','c')"}
24+
assert escaper.escape_args({'foo': {'a', 'b'}}) in ({'foo': "('a','b')"}, {'foo': "('b','a')"})
25+
assert escaper.escape_args(('bar',)) == ("'bar'",)
26+
assert escaper.escape_args([123]) == (123,)
27+
assert escaper.escape_args((123.456,)) == (123.456,)
28+
assert escaper.escape_args((['a', 'b', 'c'],)) == ("('a','b','c')",)
29+
30+
assert escaper.escape_args((datetime.date(2020, 4, 17),)) == ('date 2020-04-17',)
31+
assert escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)) == ('timestamp 2020-04-17 12:00:00.123456',)

0 commit comments

Comments
 (0)