Skip to content

Commit 4447b2e

Browse files
committed
Acknowledge reception of data in TrinoResult
1 parent ac4c458 commit 4447b2e

File tree

3 files changed

+14
-16
lines changed

3 files changed

+14
-16
lines changed

trino/client.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,8 @@ class TrinoResult(object):
594594

595595
def __init__(self, query, rows=None, experimental_python_types: bool = False):
596596
self._query = query
597-
self._rows = rows or []
597+
# Initial rows from the first POST request
598+
self._rows = rows
598599
self._rownumber = 0
599600
self._experimental_python_types = experimental_python_types
600601

@@ -603,20 +604,17 @@ def rownumber(self) -> int:
603604
return self._rownumber
604605

605606
def __iter__(self):
606-
# Initial fetch from the first POST request
607-
for row in self._rows:
608-
self._rownumber += 1
609-
yield self._map_row(self._experimental_python_types, row, self._query.columns)
610-
self._rows = None
611-
612-
# Subsequent fetches from GET requests until next_uri is empty.
613-
while not self._query.finished:
614-
rows = self._query.fetch()
615-
for row in rows:
607+
# A query only transitions to a FINISHED state when the results are fully consumed:
608+
# The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
609+
while not self._query.finished or self._rows is not None:
610+
next_rows = self._query.fetch() if not self._query.finished else None
611+
for row in self._rows:
616612
self._rownumber += 1
617613
logger.debug("row %s", row)
618614
yield self._map_row(self._experimental_python_types, row, self._query.columns)
619615

616+
self._rows = next_rows
617+
620618
@property
621619
def response_headers(self):
622620
return self._query.response_headers
@@ -723,7 +721,7 @@ def __init__(
723721
self._request = request
724722
self._update_type = None
725723
self._sql = sql
726-
self._result = TrinoResult(self, experimental_python_types=experimental_python_types)
724+
self._result: Optional[TrinoResult] = None
727725
self._response_headers = None
728726
self._experimental_python_types = experimental_python_types
729727

trino/dbapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def _prepare_statement(self, operation, statement_name):
322322
operation=operation
323323
)
324324

325-
# Send prepare statement. Copy the _request object to avoid poluting the
325+
# Send prepare statement. Copy the _request object to avoid polluting the
326326
# one that is going to be used to execute the actual operation.
327327
query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
328328
experimental_python_types=self._experimental_pyton_types)

trino/sqlalchemy/dialect.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st
231231
"""
232232
).strip()
233233
res = connection.execute(sql.text(query), schema=schema, view=view_name)
234-
return res.scalar()
234+
return res.scalar_one_or_none()
235235

236236
def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
237237
if not self.has_table(connection, table_name, schema):
@@ -282,7 +282,7 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str
282282
sql.text(query),
283283
catalog_name=catalog_name, schema_name=schema_name, table_name=table_name
284284
)
285-
return dict(text=res.scalar())
285+
return dict(text=res.scalar_one_or_none())
286286
except error.TrinoQueryError as e:
287287
if e.error_name in (
288288
error.PERMISSION_DENIED,
@@ -324,7 +324,7 @@ def _get_server_version_info(self, connection: Connection) -> Any:
324324
query = "SELECT version()"
325325
try:
326326
res = connection.execute(sql.text(query))
327-
version = res.scalar()
327+
version = res.scalar_one()
328328
return tuple([version])
329329
except exc.ProgrammingError as e:
330330
logger.debug(f"Failed to get server version: {e.orig.message}")

0 commit comments

Comments
 (0)