Skip to content

Commit cf6af76

Browse files
committed
Fix prepared statement handling
The prepared statement handling code assumed that, for each query, we would always receive some non-empty response even after the initial response. That is not a valid assumption: it only worked because Trino used to send empty fake results even for queries which don't return results (like PREPARE and DEALLOCATE), and it no longer holds as of trinodb/trino@bc794cd. The other problem with the code was that it leaked HTTP protocol details into dbapi.py and worked around this by keeping a deep copy of the request object from the PREPARE execution and re-using it for the actual query execution. The new code fixes both issues by processing the prepared statement headers as they are received and storing the resulting set of active prepared statements on the ClientSession object. The ClientSession's set of prepared statements is then rendered into the prepared statement request header in TrinoRequest. Since the ClientSession is created once and reused for the entire Connection, this also means that we can now actually implement re-use of prepared statements within a single Connection.
1 parent efb6680 commit cf6af76

File tree

4 files changed

+64
-108
lines changed

4 files changed

+64
-108
lines changed

tests/unit/test_client.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -881,21 +881,6 @@ def __call__(self, *args, **kwargs):
881881
return http_response
882882

883883

884-
def test_trino_result_response_headers():
885-
"""
886-
Validates that the `TrinoResult.response_headers` property returns the
887-
headers associated to the TrinoQuery instance provided to the `TrinoResult`
888-
class.
889-
"""
890-
mock_trino_query = mock.Mock(respone_headers={
891-
'X-Trino-Fake-1': 'one',
892-
'X-Trino-Fake-2': 'two',
893-
})
894-
895-
result = TrinoResult(query=mock_trino_query, rows=[])
896-
assert result.response_headers == mock_trino_query.response_headers
897-
898-
899884
def test_trino_query_response_headers(sample_get_response_data):
900885
"""
901886
Validates that the `TrinoQuery.execute` function can take additional headers

trino/client.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(
125125
self._extra_credential = extra_credential
126126
self._client_tags = client_tags
127127
self._role = role
128+
self._prepared_statements = {}
128129
self._object_lock = threading.Lock()
129130

130131
@property
@@ -206,6 +207,15 @@ def role(self, role):
206207
with self._object_lock:
207208
self._role = role
208209

210+
@property
211+
def prepared_statements(self):
212+
return self._prepared_statements
213+
214+
@prepared_statements.setter
215+
def prepared_statements(self, prepared_statements):
216+
with self._object_lock:
217+
self._prepared_statements = prepared_statements
218+
209219

210220
def get_header_values(headers, header):
211221
return [val.strip() for val in headers[header].split(",")]
@@ -218,6 +228,12 @@ def get_session_property_values(headers, header):
218228
for k, v in (kv.split("=", 1) for kv in kvs)
219229
]
220230

231+
def get_prepared_statement_values(headers, header):
232+
kvs = get_header_values(headers, header)
233+
return [
234+
(k.strip(), urllib.parse.unquote_plus(v.strip()))
235+
for k, v in (kv.split("=", 1) for kv in kvs)
236+
]
221237

222238
class TrinoStatus(object):
223239
def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None):
@@ -392,6 +408,13 @@ def http_headers(self) -> Dict[str, str]:
392408
for name, value in self._client_session.properties.items()
393409
)
394410

411+
if len(self._client_session.prepared_statements) != 0:
412+
# ``name`` must not contain ``=``
413+
headers[constants.HEADER_PREPARED_STATEMENT] = ",".join(
414+
"{}={}".format(name, urllib.parse.quote_plus(statement))
415+
for name, statement in self._client_session.prepared_statements.items()
416+
)
417+
395418
# merge custom http headers
396419
for key in self._client_session.headers:
397420
if key in headers.keys():
@@ -556,6 +579,18 @@ def process(self, http_response) -> TrinoStatus:
556579
if constants.HEADER_SET_ROLE in http_response.headers:
557580
self._client_session.role = http_response.headers[constants.HEADER_SET_ROLE]
558581

582+
if constants.HEADER_ADDED_PREPARE in http_response.headers:
583+
for name, statement in get_prepared_statement_values(
584+
http_response.headers, constants.HEADER_ADDED_PREPARE
585+
):
586+
self._client_session.prepared_statements[name] = statement
587+
588+
if constants.HEADER_DEALLOCATED_PREPARE in http_response.headers:
589+
for name in get_header_values(
590+
http_response.headers, constants.HEADER_DEALLOCATED_PREPARE
591+
):
592+
self._client_session.prepared_statements.pop(name)
593+
559594
self._next_uri = response.get("nextUri")
560595

561596
return TrinoStatus(
@@ -622,10 +657,6 @@ def __iter__(self):
622657

623658
self._rows = next_rows
624659

625-
@property
626-
def response_headers(self):
627-
return self._query.response_headers
628-
629660

630661
class TrinoQuery(object):
631662
"""Represent the execution of a SQL statement by Trino."""
@@ -648,7 +679,6 @@ def __init__(
648679
self._update_type = None
649680
self._sql = sql
650681
self._result: Optional[TrinoResult] = None
651-
self._response_headers = None
652682
self._experimental_python_types = experimental_python_types
653683
self._row_mapper: Optional[RowMapper] = None
654684

@@ -705,7 +735,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
705735
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
706736
self._result = TrinoResult(self, rows)
707737

708-
# Execute should block until at least one row is received
738+
# Execute should block until at least one row is received or query is finished or cancelled
709739
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
710740
self._result.rows += self.fetch()
711741
return self._result
@@ -725,7 +755,6 @@ def fetch(self) -> List[List[Any]]:
725755
status = self._request.process(response)
726756
self._update_state(status)
727757
logger.debug(status)
728-
self._response_headers = response.headers
729758
if status.next_uri is None:
730759
self._finished = True
731760

@@ -763,10 +792,6 @@ def finished(self) -> bool:
763792
def cancelled(self) -> bool:
764793
return self._cancelled
765794

766-
@property
767-
def response_headers(self):
768-
return self._response_headers
769-
770795

771796
def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts):
772797
def wrapper(func):

trino/dbapi.py

Lines changed: 28 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -295,58 +295,42 @@ def warnings(self):
295295
return self._query.warnings
296296
return None
297297

298+
def _new_request_with_session_from(self, request):
    """
    Returns a new request with the `ClientSession` set to the one from the
    given request.

    :param request: existing request whose `_client_session` should be
        shared with the newly created request.

    :return: a fresh request obtained from `self.connection._create_request()`
        whose `_client_session` is the same object as `request._client_session`.
    """
    # Use a distinct local name: rebinding `request` here would discard the
    # caller's request and turn the session assignment below into a no-op
    # self-assignment, so the session would never actually be shared.
    new_request = self.connection._create_request()
    new_request._client_session = request._client_session
    return new_request
306+
298307
def setinputsizes(self, sizes):
299308
raise trino.exceptions.NotSupportedError
300309

301310
def setoutputsize(self, size, column):
302311
raise trino.exceptions.NotSupportedError
303312

304-
def _prepare_statement(self, operation, statement_name):
313+
def _prepare_statement(self, statement, name):
305314
"""
306-
Prepends the given `operation` with "PREPARE <statement_name> FROM" and
307-
executes as a prepare statement.
308-
309-
:param operation: sql to be executed.
310-
:param statement_name: name that will be assigned to the prepare
311-
statement.
312-
313-
:raises trino.exceptions.FailedToObtainAddedPrepareHeader: Error raised
314-
when unable to find the 'X-Trino-Added-Prepare' for the PREPARE
315-
statement request.
315+
Registers a prepared statement for the provided `statement` with the
316+
`name` assigned to it.
316317
317-
:return: string representing the value of the 'X-Trino-Added-Prepare'
318-
header.
318+
:param statement: sql to be executed.
319+
:param name: name that will be assigned to the prepared statement.
319320
"""
320-
sql = 'PREPARE {statement_name} FROM {operation}'.format(
321-
statement_name=statement_name,
322-
operation=operation
323-
)
324-
325-
# Send prepare statement. Copy the _request object to avoid polluting the
326-
# one that is going to be used to execute the actual operation.
327-
query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
321+
sql = f"PREPARE {name} FROM {statement}"
322+
# TODO: Evaluate whether we can avoid the piggybacking on current request
323+
query = trino.client.TrinoQuery(self._new_request_with_session_from(self._request), sql=sql,
328324
experimental_python_types=self._experimental_pyton_types)
329-
result = query.execute()
330-
331-
# Iterate until the 'X-Trino-Added-Prepare' header is found or
332-
# until there are no more results
333-
for _ in result:
334-
response_headers = result.response_headers
335-
336-
if constants.HEADER_ADDED_PREPARE in response_headers:
337-
return response_headers[constants.HEADER_ADDED_PREPARE]
325+
query.execute()
338326

339-
raise trino.exceptions.FailedToObtainAddedPrepareHeader
340327

341-
def _get_added_prepare_statement_trino_query(
328+
def _execute_prepared_statement(
342329
self,
343330
statement_name,
344331
params
345332
):
346333
sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params))
347-
348-
# No need to deepcopy _request here because this is the actual request
349-
# operation
350334
return trino.client.TrinoQuery(self._request, sql=sql, experimental_python_types=self._experimental_pyton_types)
351335

352336
def _format_prepared_param(self, param):
@@ -422,28 +406,12 @@ def _format_prepared_param(self, param):
422406

423407
raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param))
424408

425-
def _deallocate_prepare_statement(self, added_prepare_header, statement_name):
409+
def _deallocate_prepared_statement(self, statement_name):
426410
sql = 'DEALLOCATE PREPARE ' + statement_name
427-
428-
# Send deallocate statement. Copy the _request object to avoid poluting the
429-
# one that is going to be used to execute the actual operation.
430-
query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
411+
# TODO: Evaluate whether we can avoid the piggybacking on current request
412+
query = trino.client.TrinoQuery(self._new_request_with_session_from(self._request), sql=sql,
431413
experimental_python_types=self._experimental_pyton_types)
432-
result = query.execute(
433-
additional_http_headers={
434-
constants.HEADER_PREPARED_STATEMENT: added_prepare_header
435-
}
436-
)
437-
438-
# Iterate until the 'X-Trino-Deallocated-Prepare' header is found or
439-
# until there are no more results
440-
for _ in result:
441-
response_headers = result.response_headers
442-
443-
if constants.HEADER_DEALLOCATED_PREPARE in response_headers:
444-
return response_headers[constants.HEADER_DEALLOCATED_PREPARE]
445-
446-
raise trino.exceptions.FailedToObtainDeallocatedPrepareHeader
414+
query.execute()
447415

448416
def _generate_unique_statement_name(self):
449417
return 'st_' + uuid.uuid4().hex.replace('-', '')
@@ -456,27 +424,21 @@ def execute(self, operation, params=None):
456424
)
457425

458426
statement_name = self._generate_unique_statement_name()
459-
# Send prepare statement
460-
added_prepare_header = self._prepare_statement(
461-
operation, statement_name
462-
)
427+
self._prepare_statement(operation, statement_name)
463428

464429
try:
465430
# Send execute statement and assign the return value to `results`
466431
# as it will be returned by the function
467-
self._query = self._get_added_prepare_statement_trino_query(
432+
self._query = self._execute_prepared_statement(
468433
statement_name, params
469434
)
470-
result = self._query.execute(
471-
additional_http_headers={
472-
constants.HEADER_PREPARED_STATEMENT: added_prepare_header
473-
}
474-
)
435+
result = self._query.execute()
475436
finally:
476437
# Send deallocate statement
477438
# At this point the query can be deallocated since it has already
478439
# been executed
479-
self._deallocate_prepare_statement(added_prepare_header, statement_name)
440+
# TODO: Consider caching prepared statements if requested by caller
441+
self._deallocate_prepared_statement(statement_name)
480442

481443
else:
482444
self._query = trino.client.TrinoQuery(self._request, sql=operation,

trino/exceptions.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -134,22 +134,6 @@ class TrinoUserError(TrinoQueryError, ProgrammingError):
134134
pass
135135

136136

137-
class FailedToObtainAddedPrepareHeader(Error):
138-
"""
139-
Raise this exception when unable to find the 'X-Trino-Added-Prepare'
140-
header in the response of a PREPARE statement request.
141-
"""
142-
pass
143-
144-
145-
class FailedToObtainDeallocatedPrepareHeader(Error):
146-
"""
147-
Raise this exception when unable to find the 'X-Trino-Deallocated-Prepare'
148-
header in the response of a DEALLOCATED statement request.
149-
"""
150-
pass
151-
152-
153137
# client module errors
154138
class HttpError(Exception):
155139
pass

0 commit comments

Comments
 (0)