Skip to content

feat: support query params #127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions prestodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
from . import client
from . import constants
from . import exceptions
from . import escaper

__version__ = "0.8.3"
4 changes: 3 additions & 1 deletion prestodb/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from prestodb import constants
import prestodb.exceptions
import prestodb.escaper
import prestodb.client
import prestodb.redirect
from prestodb.transaction import Transaction, IsolationLevel, NO_TRANSACTION
Expand Down Expand Up @@ -232,7 +233,8 @@ def setoutputsize(self, size, column):
raise prestodb.exceptions.NotSupportedError

def execute(self, operation, params=None):
self._query = prestodb.client.PrestoQuery(self._request, sql=operation)
sql = operation if params is None else operation % prestodb.escaper.escape(params)
self._query = prestodb.client.PrestoQuery(self._request, sql=sql)
result = self._query.execute()
self._iterator = iter(result)
return result
Expand Down
89 changes: 89 additions & 0 deletions prestodb/escaper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is forked from https://github.com/dropbox/PyHive (the Apache License, Version 2.0)
from __future__ import absolute_import

import datetime

try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable

import prestodb.exceptions

class ParamsEscaper(object):
_DATE_FORMAT = "%Y-%m-%d"
_TIME_FORMAT = "%H:%M:%S.%f"
_DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT)

def escape_args(self, parameters):
if isinstance(parameters, dict):
return {k: self.escape_item(v) for k, v in parameters.items()}

if isinstance(parameters, (list, tuple)):
return tuple(self.escape_item(x) for x in parameters)

raise prestodb.exceptions.ProgrammingError("Unsupported param format: {}".format(parameters))

def escape_number(self, item):
return item

def escape_bytes(self, item):
return self.escape_string(item.decode("utf-8"))

def escape_string(self, item):
# This is good enough when backslashes are literal, newlines are just followed, and the way
# to escape a single quote is to put two single quotes.
# (i.e. only special character is single quote)
return "'{}'".format(item.replace("'", "''"))

def escape_sequence(self, item):
l = map(str, map(self.escape_item, item))
return '(' + ','.join(l) + ')'

def escape_datetime(self, item, format, cutoff=0):
dt_str = item.strftime(format)
formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str

_type = "timestamp" if isinstance(item, datetime.datetime) else "date"
return "{} {}".format(_type, formatted)

def escape_item(self, item):
if item is None:
return 'NULL'

if isinstance(item, (int, float)):
return self.escape_number(item)

if isinstance(item, bytes):
return self.escape_bytes(item)

if isinstance(item, str):
return self.escape_string(item)

if isinstance(item, Iterable):
return self.escape_sequence(item)

if isinstance(item, datetime.datetime):
return self.escape_datetime(item, self._DATETIME_FORMAT)

if isinstance(item, datetime.date):
return self.escape_datetime(item, self._DATE_FORMAT)

raise prestodb.exceptions.ProgrammingError("Unsupported object {}".format(item))

escaper = ParamsEscaper()

def escape(params):
return escaper.escape_args(params)
31 changes: 31 additions & 0 deletions tests/test_escaper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import datetime
import prestodb.escaper

def test_escape_args():
escaper = prestodb.escaper.ParamsEscaper()

assert escaper.escape_args({'foo': 'bar'}) == {'foo': "'bar'"}
assert escaper.escape_args({'foo': 123}) == {'foo': 123}
assert escaper.escape_args({'foo': 123.456}) == {'foo': 123.456}
assert escaper.escape_args({'foo': ['a', 'b', 'c']}) == {'foo': "('a','b','c')"}
assert escaper.escape_args({'foo': ('a', 'b', 'c')}) == {'foo': "('a','b','c')"}
assert escaper.escape_args({'foo': {'a', 'b'}}) in ({'foo': "('a','b')"}, {'foo': "('b','a')"})
assert escaper.escape_args(('bar',)) == ("'bar'",)
assert escaper.escape_args([123]) == (123,)
assert escaper.escape_args((123.456,)) == (123.456,)
assert escaper.escape_args((['a', 'b', 'c'],)) == ("('a','b','c')",)

assert escaper.escape_args((datetime.date(2020, 4, 17),)) == ('date 2020-04-17',)
assert escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)) == ('timestamp 2020-04-17 12:00:00.123456',)