Skip to content

Commit 9d898a8

Browse files
committed
Acknowledge reception of data in TrinoResult
1 parent cffd2b2 commit 9d898a8

File tree

3 files changed

+14
-16
lines changed

3 files changed

+14
-16
lines changed

trino/client.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -594,28 +594,26 @@ class TrinoResult(object):
594594

595595
def __init__(self, query, rows=None):
596596
self._query = query
597-
self._rows = rows or []
597+
# Initial rows from the first POST request
598+
self._rows = rows
598599
self._rownumber = 0
599600

600601
@property
601602
def rownumber(self) -> int:
602603
return self._rownumber
603604

604605
def __iter__(self):
605-
# Initial fetch from the first POST request
606-
for row in self._rows:
607-
self._rownumber += 1
608-
yield row
609-
self._rows = None
610-
611-
# Subsequent fetches from GET requests until next_uri is empty.
612-
while not self._query.finished:
613-
rows = self._query.fetch()
614-
for row in rows:
606+
# A query only transitions to a FINISHED state when the results are fully consumed:
607+
# The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
608+
while not self._query.finished or self._rows is not None:
609+
next_rows = self._query.fetch() if not self._query.finished else None
610+
for row in self._rows:
615611
self._rownumber += 1
616612
logger.debug("row %s", row)
617613
yield row
618614

615+
self._rows = next_rows
616+
619617
@property
620618
def response_headers(self):
621619
return self._query.response_headers
@@ -641,7 +639,7 @@ def __init__(
641639
self._request = request
642640
self._update_type = None
643641
self._sql = sql
644-
self._result = TrinoResult(self)
642+
self._result: Optional[TrinoResult] = None
645643
self._response_headers = None
646644
self._experimental_python_types = experimental_python_types
647645
self._row_mapper: Optional[RowMapper] = None

trino/dbapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def _prepare_statement(self, operation, statement_name):
322322
operation=operation
323323
)
324324

325-
# Send prepare statement. Copy the _request object to avoid poluting the
325+
# Send prepare statement. Copy the _request object to avoid polluting the
326326
# one that is going to be used to execute the actual operation.
327327
query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
328328
experimental_python_types=self._experimental_pyton_types)

trino/sqlalchemy/dialect.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st
231231
"""
232232
).strip()
233233
res = connection.execute(sql.text(query), schema=schema, view=view_name)
234-
return res.scalar()
234+
return res.scalar_one_or_none()
235235

236236
def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
237237
if not self.has_table(connection, table_name, schema):
@@ -284,7 +284,7 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str
284284
sql.text(query),
285285
catalog_name=catalog_name, schema_name=schema_name, table_name=table_name
286286
)
287-
return dict(text=res.scalar())
287+
return dict(text=res.scalar_one_or_none())
288288
except error.TrinoQueryError as e:
289289
if e.error_name in (
290290
error.PERMISSION_DENIED,
@@ -326,7 +326,7 @@ def _get_server_version_info(self, connection: Connection) -> Any:
326326
query = "SELECT version()"
327327
try:
328328
res = connection.execute(sql.text(query))
329-
version = res.scalar()
329+
version = res.scalar_one()
330330
return tuple([version])
331331
except exc.ProgrammingError as e:
332332
logger.debug(f"Failed to get server version: {e.orig.message}")

0 commit comments

Comments
 (0)