Skip to content

Commit d5f779b

Browse files
committed
Fix prepared statement handling
The prepared statement handling code assumed that for each query we'll always receive some non-empty response even after the initial response which is not a valid assumption. This assumption worked because earlier Trino used to send empty fake results even for queries which don't return results (like PREPARE and DEALLOCATE) but is now invalid with trinodb/trino@bc794cd. The other problem with the code was that it leaked HTTP protocol details into dbapi.py and worked around it by keeping a deep copy of the request object from the PREPARE execution and re-using it for the actual query execution. The new code fixes both issues by processing the prepared statement headers as they are received and storing the resulting set of active prepared statements on the ClientSession object. The ClientSession's set of prepared statements is then rendered into the prepared statement request header in TrinoRequest. Since the ClientSession is created and reused for the entire Connection this also means that we can now actually implement re-use of prepared statements within a single Connection.
1 parent efb6680 commit d5f779b

File tree

5 files changed

+90
-110
lines changed

5 files changed

+90
-110
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,3 +1067,38 @@ def test_set_role_trino_351(run_trino):
10671067
cur.execute("SET ROLE ALL")
10681068
cur.fetchall()
10691069
assert cur._request._client_session.role == "tpch=ALL"
1070+
1071+
1072+
def test_prepared_statements(run_trino):
1073+
_, host, port = run_trino
1074+
1075+
trino_connection = trino.dbapi.Connection(
1076+
host=host, port=port, user="test", catalog="tpch",
1077+
)
1078+
cur = trino_connection.cursor()
1079+
1080+
# Implicit prepared statements must work and deallocate statements on finish
1081+
assert cur._request._client_session.prepared_statements == {}
1082+
cur.execute('SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?', (1,))
1083+
result = cur.fetchall()
1084+
assert result[0][0] == 1
1085+
assert cur._request._client_session.prepared_statements == {}
1086+
1087+
# Explicit prepared statements must also work
1088+
cur.execute('PREPARE test_prepared_statements FROM SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?')
1089+
cur.fetchall()
1090+
assert 'test_prepared_statements' in cur._request._client_session.prepared_statements
1091+
cur.execute('EXECUTE test_prepared_statements USING 1')
1092+
cur.fetchall()
1093+
assert result[0][0] == 1
1094+
1095+
# An implicit prepared statement must not deallocate explicit statements
1096+
cur.execute('SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?', (1,))
1097+
result = cur.fetchall()
1098+
assert result[0][0] == 1
1099+
assert 'test_prepared_statements' in cur._request._client_session.prepared_statements
1100+
1101+
assert 'test_prepared_statements' in cur._request._client_session.prepared_statements
1102+
cur.execute('DEALLOCATE PREPARE test_prepared_statements')
1103+
cur.fetchall()
1104+
assert cur._request._client_session.prepared_statements == {}

tests/unit/test_client.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -881,21 +881,6 @@ def __call__(self, *args, **kwargs):
881881
return http_response
882882

883883

884-
def test_trino_result_response_headers():
885-
"""
886-
Validates that the `TrinoResult.response_headers` property returns the
887-
headers associated to the TrinoQuery instance provided to the `TrinoResult`
888-
class.
889-
"""
890-
mock_trino_query = mock.Mock(respone_headers={
891-
'X-Trino-Fake-1': 'one',
892-
'X-Trino-Fake-2': 'two',
893-
})
894-
895-
result = TrinoResult(query=mock_trino_query, rows=[])
896-
assert result.response_headers == mock_trino_query.response_headers
897-
898-
899884
def test_trino_query_response_headers(sample_get_response_data):
900885
"""
901886
Validates that the `TrinoQuery.execute` function can take addtional headers

trino/client.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(
125125
self._extra_credential = extra_credential
126126
self._client_tags = client_tags
127127
self._role = role
128+
self._prepared_statements: Dict[str, str] = {}
128129
self._object_lock = threading.Lock()
129130

130131
@property
@@ -206,6 +207,15 @@ def role(self, role):
206207
with self._object_lock:
207208
self._role = role
208209

210+
@property
211+
def prepared_statements(self):
212+
return self._prepared_statements
213+
214+
@prepared_statements.setter
215+
def prepared_statements(self, prepared_statements):
216+
with self._object_lock:
217+
self._prepared_statements = prepared_statements
218+
209219

210220
def get_header_values(headers, header):
211221
return [val.strip() for val in headers[header].split(",")]
@@ -219,6 +229,14 @@ def get_session_property_values(headers, header):
219229
]
220230

221231

232+
def get_prepared_statement_values(headers, header):
233+
kvs = get_header_values(headers, header)
234+
return [
235+
(k.strip(), urllib.parse.unquote_plus(v.strip()))
236+
for k, v in (kv.split("=", 1) for kv in kvs)
237+
]
238+
239+
222240
class TrinoStatus(object):
223241
def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None):
224242
self.id = id
@@ -392,6 +410,13 @@ def http_headers(self) -> Dict[str, str]:
392410
for name, value in self._client_session.properties.items()
393411
)
394412

413+
if len(self._client_session.prepared_statements) != 0:
414+
# ``name`` must not contain ``=``
415+
headers[constants.HEADER_PREPARED_STATEMENT] = ",".join(
416+
"{}={}".format(name, urllib.parse.quote_plus(statement))
417+
for name, statement in self._client_session.prepared_statements.items()
418+
)
419+
395420
# merge custom http headers
396421
for key in self._client_session.headers:
397422
if key in headers.keys():
@@ -556,6 +581,18 @@ def process(self, http_response) -> TrinoStatus:
556581
if constants.HEADER_SET_ROLE in http_response.headers:
557582
self._client_session.role = http_response.headers[constants.HEADER_SET_ROLE]
558583

584+
if constants.HEADER_ADDED_PREPARE in http_response.headers:
585+
for name, statement in get_prepared_statement_values(
586+
http_response.headers, constants.HEADER_ADDED_PREPARE
587+
):
588+
self._client_session.prepared_statements[name] = statement
589+
590+
if constants.HEADER_DEALLOCATED_PREPARE in http_response.headers:
591+
for name in get_header_values(
592+
http_response.headers, constants.HEADER_DEALLOCATED_PREPARE
593+
):
594+
self._client_session.prepared_statements.pop(name)
595+
559596
self._next_uri = response.get("nextUri")
560597

561598
return TrinoStatus(
@@ -622,10 +659,6 @@ def __iter__(self):
622659

623660
self._rows = next_rows
624661

625-
@property
626-
def response_headers(self):
627-
return self._query.response_headers
628-
629662

630663
class TrinoQuery(object):
631664
"""Represent the execution of a SQL statement by Trino."""
@@ -648,7 +681,6 @@ def __init__(
648681
self._update_type = None
649682
self._sql = sql
650683
self._result: Optional[TrinoResult] = None
651-
self._response_headers = None
652684
self._experimental_python_types = experimental_python_types
653685
self._row_mapper: Optional[RowMapper] = None
654686

@@ -705,7 +737,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
705737
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
706738
self._result = TrinoResult(self, rows)
707739

708-
# Execute should block until at least one row is received
740+
# Execute should block until at least one row is received or query is finished or cancelled
709741
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
710742
self._result.rows += self.fetch()
711743
return self._result
@@ -725,7 +757,6 @@ def fetch(self) -> List[List[Any]]:
725757
status = self._request.process(response)
726758
self._update_state(status)
727759
logger.debug(status)
728-
self._response_headers = response.headers
729760
if status.next_uri is None:
730761
self._finished = True
731762

@@ -763,10 +794,6 @@ def finished(self) -> bool:
763794
def cancelled(self) -> bool:
764795
return self._cancelled
765796

766-
@property
767-
def response_headers(self):
768-
return self._response_headers
769-
770797

771798
def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts):
772799
def wrapper(func):

trino/dbapi.py

Lines changed: 17 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from decimal import Decimal
2121
from typing import Any, List, Optional # NOQA for mypy types
2222

23-
import copy
2423
import uuid
2524
import datetime
2625
import math
@@ -301,52 +300,25 @@ def setinputsizes(self, sizes):
301300
def setoutputsize(self, size, column):
302301
raise trino.exceptions.NotSupportedError
303302

304-
def _prepare_statement(self, operation, statement_name):
303+
def _prepare_statement(self, statement: str, name: str) -> None:
305304
"""
306-
Prepends the given `operation` with "PREPARE <statement_name> FROM" and
307-
executes as a prepare statement.
305+
Registers a prepared statement for the provided `operation` with the
306+
`name` assigned to it.
308307
309-
:param operation: sql to be executed.
310-
:param statement_name: name that will be assigned to the prepare
311-
statement.
312-
313-
:raises trino.exceptions.FailedToObtainAddedPrepareHeader: Error raised
314-
when unable to find the 'X-Trino-Added-Prepare' for the PREPARE
315-
statement request.
316-
317-
:return: string representing the value of the 'X-Trino-Added-Prepare'
318-
header.
308+
:param statement: sql to be executed.
309+
:param name: name that will be assigned to the prepared statement.
319310
"""
320-
sql = 'PREPARE {statement_name} FROM {operation}'.format(
321-
statement_name=statement_name,
322-
operation=operation
323-
)
324-
325-
# Send prepare statement. Copy the _request object to avoid polluting the
326-
# one that is going to be used to execute the actual operation.
327-
query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
311+
sql = f"PREPARE {name} FROM {statement}"
312+
query = trino.client.TrinoQuery(self.connection._create_request(), sql=sql,
328313
experimental_python_types=self._experimental_pyton_types)
329-
result = query.execute()
314+
query.execute()
330315

331-
# Iterate until the 'X-Trino-Added-Prepare' header is found or
332-
# until there are no more results
333-
for _ in result:
334-
response_headers = result.response_headers
335-
336-
if constants.HEADER_ADDED_PREPARE in response_headers:
337-
return response_headers[constants.HEADER_ADDED_PREPARE]
338-
339-
raise trino.exceptions.FailedToObtainAddedPrepareHeader
340-
341-
def _get_added_prepare_statement_trino_query(
316+
def _execute_prepared_statement(
342317
self,
343318
statement_name,
344319
params
345320
):
346321
sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params))
347-
348-
# No need to deepcopy _request here because this is the actual request
349-
# operation
350322
return trino.client.TrinoQuery(self._request, sql=sql, experimental_python_types=self._experimental_pyton_types)
351323

352324
def _format_prepared_param(self, param):
@@ -422,28 +394,11 @@ def _format_prepared_param(self, param):
422394

423395
raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param))
424396

425-
def _deallocate_prepare_statement(self, added_prepare_header, statement_name):
397+
def _deallocate_prepared_statement(self, statement_name: str) -> None:
426398
sql = 'DEALLOCATE PREPARE ' + statement_name
427-
428-
# Send deallocate statement. Copy the _request object to avoid poluting the
429-
# one that is going to be used to execute the actual operation.
430-
query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
399+
query = trino.client.TrinoQuery(self.connection._create_request(), sql=sql,
431400
experimental_python_types=self._experimental_pyton_types)
432-
result = query.execute(
433-
additional_http_headers={
434-
constants.HEADER_PREPARED_STATEMENT: added_prepare_header
435-
}
436-
)
437-
438-
# Iterate until the 'X-Trino-Deallocated-Prepare' header is found or
439-
# until there are no more results
440-
for _ in result:
441-
response_headers = result.response_headers
442-
443-
if constants.HEADER_DEALLOCATED_PREPARE in response_headers:
444-
return response_headers[constants.HEADER_DEALLOCATED_PREPARE]
445-
446-
raise trino.exceptions.FailedToObtainDeallocatedPrepareHeader
401+
query.execute()
447402

448403
def _generate_unique_statement_name(self):
449404
return 'st_' + uuid.uuid4().hex.replace('-', '')
@@ -456,27 +411,21 @@ def execute(self, operation, params=None):
456411
)
457412

458413
statement_name = self._generate_unique_statement_name()
459-
# Send prepare statement
460-
added_prepare_header = self._prepare_statement(
461-
operation, statement_name
462-
)
414+
self._prepare_statement(operation, statement_name)
463415

464416
try:
465417
# Send execute statement and assign the return value to `results`
466418
# as it will be returned by the function
467-
self._query = self._get_added_prepare_statement_trino_query(
419+
self._query = self._execute_prepared_statement(
468420
statement_name, params
469421
)
470-
result = self._query.execute(
471-
additional_http_headers={
472-
constants.HEADER_PREPARED_STATEMENT: added_prepare_header
473-
}
474-
)
422+
result = self._query.execute()
475423
finally:
476424
# Send deallocate statement
477425
# At this point the query can be deallocated since it has already
478426
# been executed
479-
self._deallocate_prepare_statement(added_prepare_header, statement_name)
427+
# TODO: Consider caching prepared statements if requested by caller
428+
self._deallocate_prepared_statement(statement_name)
480429

481430
else:
482431
self._query = trino.client.TrinoQuery(self._request, sql=operation,

trino/exceptions.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -134,22 +134,6 @@ class TrinoUserError(TrinoQueryError, ProgrammingError):
134134
pass
135135

136136

137-
class FailedToObtainAddedPrepareHeader(Error):
138-
"""
139-
Raise this exception when unable to find the 'X-Trino-Added-Prepare'
140-
header in the response of a PREPARE statement request.
141-
"""
142-
pass
143-
144-
145-
class FailedToObtainDeallocatedPrepareHeader(Error):
146-
"""
147-
Raise this exception when unable to find the 'X-Trino-Deallocated-Prepare'
148-
header in the response of a DEALLOCATED statement request.
149-
"""
150-
pass
151-
152-
153137
# client module errors
154138
class HttpError(Exception):
155139
pass

0 commit comments

Comments
 (0)