Skip to content

Commit bf046ff

Browse files
author
Jesse
authored
[PECO-1263] Add get_async_execution method (#314)
Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent 0c4acba commit bf046ff

File tree

3 files changed

+319
-50
lines changed

3 files changed

+319
-50
lines changed

src/databricks/sql/ae.py

+124-35
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from enum import Enum
22
from typing import Optional, Union, TYPE_CHECKING
3+
from databricks.sql.exc import RequestError
34
from databricks.sql.results import ResultSet
45

6+
from datetime import datetime
7+
58
from dataclasses import dataclass
69

710
if TYPE_CHECKING:
@@ -19,11 +22,23 @@ class AsyncExecutionException(Exception):
1922
pass
2023

2124

25+
class AsyncExecutionUnrecoverableResultException(AsyncExecutionException):
    """Signals that the result for this query id can never be retrieved."""
29+
30+
2231
@dataclass
class FakeCursor:
    """Stand-in passed to the thrift backend where a cursor argument is expected."""

    # Operation handle slot; AsyncExecution constructs this object with None here.
    active_op_handle: Optional[ttypes.TOperationHandle]
2534

2635

36+
@dataclass
class FakeExecuteStatementResponse:
    """Placeholder used in place of a ``ttypes.TExecuteStatementResp``.

    AsyncExecution builds one of these when it is constructed without a real
    execute-statement response.
    """

    # AsyncExecution sets this to False when it builds the placeholder.
    directResults: bool
    # Thrift handle of the operation this placeholder refers to.
    operationHandle: ttypes.TOperationHandle
40+
41+
2742
class AsyncExecutionStatus(Enum):
2843
"""An enum that represents the status of an async execution"""
2944

@@ -35,6 +50,7 @@ class AsyncExecutionStatus(Enum):
3550

3651
# todo: when is this ever evaluated?
3752
ABORTED = 5
53+
UNKNOWN = 6
3854

3955

4056
def _toperationstate_to_ae_status(
@@ -54,52 +70,77 @@ def _toperationstate_to_ae_status(
5470

5571
class AsyncExecution:
5672
"""
57-
A class that represents an async execution of a query.
58-
59-
AsyncExecutions are effectively connectionless. But because thrift_backend is entangled
60-
with client.py, the AsyncExecution needs access to both a Connection and a ThriftBackend
61-
62-
This will need to be refactored for cleanliness in the future.
73+
A handle for a query execution on Databricks.
6374
"""
6475

6576
_connection: "Connection"
6677
_thrift_backend: "ThriftBackend"
6778
_result_set: Optional["ResultSet"]
68-
_execute_statement_response: Optional[ttypes.TExecuteStatementResp]
79+
_execute_statement_response: Optional[
80+
Union[FakeExecuteStatementResponse, ttypes.TExecuteStatementResp]
81+
]
82+
_last_sync_timestamp: Optional[datetime] = None
83+
_result_set: Optional["ResultSet"] = None
6984

7085
def __init__(
7186
self,
7287
thrift_backend: "ThriftBackend",
7388
connection: "Connection",
7489
query_id: UUID,
7590
query_secret: UUID,
76-
status: AsyncExecutionStatus,
77-
execute_statement_response: Optional[ttypes.TExecuteStatementResp] = None,
91+
status: Optional[AsyncExecutionStatus] = AsyncExecutionStatus.UNKNOWN,
92+
execute_statement_response: Optional[
93+
Union[FakeExecuteStatementResponse, ttypes.TExecuteStatementResp]
94+
] = None,
7895
):
7996
self._connection = connection
8097
self._thrift_backend = thrift_backend
81-
self._execute_statement_response = execute_statement_response
8298
self.query_id = query_id
8399
self.query_secret = query_secret
84100
self.status = status
85101

102+
if execute_statement_response:
103+
self._execute_statement_response = execute_statement_response
104+
else:
105+
self._execute_statement_response = FakeExecuteStatementResponse(
106+
directResults=False, operationHandle=self.t_operation_handle
107+
)
108+
86109
status: AsyncExecutionStatus
87110
query_id: UUID
88111

89-
def get_result(self) -> "ResultSet":
90-
"""Get a result set for this async execution
112+
def get_result(
113+
self,
114+
) -> "ResultSet":
115+
"""Attempt to get the result of this query and set self.status to FETCHED.
116+
117+
IMPORTANT: Generally, you'll call this method only after checking that the query is finished.
118+
But you can call it at any time. If you call this method while the query is still running,
119+
your code will block indefinitely until the query completes! This will be changed in a
120+
subsequent release (PECO-1291)
121+
122+
If you have already called get_result successfully, this method will return the same ResultSet
123+
as before without making an additional roundtrip to the server.
91124
92-
Raises an exception if the query is still running or has been canceled.
125+
Raises an AsyncExecutionUnrecoverableResultException if the query was canceled or aborted
126+
at the server, so a result will never be available.
93127
"""
94128

95-
if self.status == AsyncExecutionStatus.CANCELED:
96-
raise AsyncExecutionException("Query was canceled: %s" % self.query_id)
97-
if self.is_running:
98-
raise AsyncExecutionException("Query is still running: %s" % self.query_id)
99-
if self.status == AsyncExecutionStatus.FINISHED:
100-
self._thrift_fetch_result()
101-
if self.status == AsyncExecutionStatus.FETCHED:
102-
return self._result_set
129+
# this isn't recoverable
130+
if self.status in [AsyncExecutionStatus.ABORTED, AsyncExecutionStatus.CANCELED]:
131+
raise AsyncExecutionUnrecoverableResultException(
132+
"Result for %s is not recoverable. Query status is %s"
133+
% (self.query_id, self.status),
134+
)
135+
136+
return self._get_result_set()
137+
138+
def _get_result_set(self) -> "ResultSet":
    """Return the cached ResultSet, fetching and caching it on first use.

    A fetch also moves self.status to FETCHED; later calls reuse the cached
    object without calling _thrift_fetch_result again.
    """
    cached = self._result_set
    if cached is not None:
        return cached

    fetched = self._thrift_fetch_result()
    self._result_set = fetched
    self.status = AsyncExecutionStatus.FETCHED
    return fetched
103144

104145
def cancel(self) -> None:
105146
"""Cancel the query"""
@@ -111,42 +152,60 @@ def _thrift_cancel_operation(self) -> None:
111152
_output = self._thrift_backend.async_cancel_command(self.t_operation_handle)
112153
self.status = AsyncExecutionStatus.CANCELED
113154

114-
def poll_for_status(self) -> None:
115-
"""Check the thrift server for the status of this operation and set self.status
155+
def _thrift_get_operation_status(self) -> ttypes.TGetOperationStatusResp:
    """Execute TGetOperationStatusReq for this execution's operation handle.

    Returns the response produced by the thrift backend's status poll.

    Raises an AsyncExecutionException if the query_id:query_secret pair is not found
    on the server. Any other RequestError is re-raised unchanged.
    """
    try:
        return self._thrift_backend._poll_for_status(self.t_operation_handle)
    except RequestError as e:
        # NOTE(review): assumes RequestError.message may be None; the `or ""`
        # keeps the substring test from raising TypeError in that case.
        if "RESOURCE_DOES_NOT_EXIST" in (e.message or ""):
            raise AsyncExecutionException(
                "Query not found: %s" % self.query_id
            ) from e
        # Do not swallow unrelated request failures: falling through here would
        # return None and make callers fail later on `resp.operationState`.
        raise
116167

117-
This will result in an error if the operation has been canceled or aborted at the server"""
168+
def serialize(self) -> str:
169+
"""Return a string representing the query_id and secret of this AsyncExecution.
118170
119-
_output = self._thrift_backend._poll_for_status(self.t_operation_handle)
120-
self.status = _toperationstate_to_ae_status(_output.operationState)
171+
Use this to preserve a reference to the query_id and query_secret."""
172+
return f"{self.query_id}:{self.query_secret}"
121173

122-
def _thrift_fetch_result(self) -> None:
123-
"""Execute TFetchResultReq and store the result"""
174+
def sync_status(self) -> None:
175+
"""Synchronise the status of this AsyncExecution with the server query execution state."""
124176

125-
# A cursor is required here to hook into the thrift_backend result fetching API
126-
# TODO: need to rewrite this to use a generic result fetching API so we can
127-
# support JSON and Thrift binary result formats in addition to arrow.
177+
resp = self._thrift_get_operation_status()
178+
self.status = _toperationstate_to_ae_status(resp.operationState)
179+
self._last_sync_timestamp = datetime.now()
128180

129-
# in the case of direct results this creates a second cursor...how can I avoid that?
181+
def _thrift_fetch_result(self) -> "ResultSet":
182+
"""Execute TFetchResultReq"""
130183

131184
er = self._thrift_backend._handle_execute_response(
132185
self._execute_statement_response, FakeCursor(None)
133186
)
134187

135-
self._result_set = ResultSet(
188+
return ResultSet(
136189
connection=self._connection,
137190
execute_response=er,
138191
thrift_backend=self._connection.thrift_backend,
139192
)
140193

141-
self.status = AsyncExecutionStatus.FETCHED
142-
143194
@property
def is_running(self) -> bool:
    """True while the last known status is RUNNING or PENDING."""
    in_flight_states = [
        AsyncExecutionStatus.RUNNING,
        AsyncExecutionStatus.PENDING,
    ]
    return self.status in in_flight_states
149200

201+
@property
def is_canceled(self) -> bool:
    """True when the last known status is CANCELED."""
    current = self.status
    return current == AsyncExecutionStatus.CANCELED
204+
205+
@property
def is_finished(self) -> bool:
    """True when the last known status is FINISHED."""
    current = self.status
    return current == AsyncExecutionStatus.FINISHED
208+
150209
@property
151210
def t_operation_handle(self) -> ttypes.TOperationHandle:
152211
"""Return the current AsyncExecution as a Thrift TOperationHandle"""
@@ -161,6 +220,11 @@ def t_operation_handle(self) -> ttypes.TOperationHandle:
161220

162221
return handle
163222

223+
@property
def last_sync_timestamp(self) -> Optional[datetime]:
    """When self.status was last synced with the server, or None if it never was."""
    synced_at = self._last_sync_timestamp
    return synced_at
227+
164228
@classmethod
165229
def from_thrift_response(
166230
cls,
@@ -180,3 +244,28 @@ def from_thrift_response(
180244
),
181245
execute_statement_response=resp,
182246
)
247+
248+
@classmethod
def from_query_id_and_secret(
    cls,
    connection: "Connection",
    thrift_backend: "ThriftBackend",
    query_id: UUID,
    query_secret: UUID,
) -> "AsyncExecution":
    """Build an AsyncExecution for an existing query from its id and secret.

    The new object's status is synced with the server before it is returned,
    which also serves as the check that the pair is known to the server.

    Raises an AsyncExecutionException if the query_id:query_secret pair is not found on the server.
    """
    execution = cls(
        thrift_backend=thrift_backend,
        connection=connection,
        query_secret=query_secret,
        query_id=query_id,
    )

    # Server round-trip: validates the pair and populates execution.status.
    execution.sync_status()
    return execution

src/databricks/sql/client.py

+31
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828

2929
from databricks.sql.ae import AsyncExecution, AsyncExecutionStatus
30+
from uuid import UUID
3031

3132
logger = logging.getLogger(__name__)
3233

@@ -403,6 +404,36 @@ def execute_async(
403404
resp=execute_statement_resp,
404405
)
405406

407+
def get_async_execution(
    self, query_id: Union[str, UUID], query_secret: Union[str, UUID]
) -> "AsyncExecution":
    """Look up an existing query and wrap it in an AsyncExecution.

    Args:
        query_id: Id of the query to retrieve, as a UUID or its hex string.
        query_secret: Secret of the query to retrieve, as a UUID or its hex string.

    Returns:
        An AsyncExecution object that can be used to poll for status and retrieve results.
    """
    # Normalise both identifiers to UUID; string inputs are parsed as hex.
    parsed_id = query_id if isinstance(query_id, UUID) else UUID(hex=query_id)
    parsed_secret = (
        query_secret if isinstance(query_secret, UUID) else UUID(hex=query_secret)
    )

    return AsyncExecution.from_query_id_and_secret(
        connection=self,
        thrift_backend=self.thrift_backend,
        query_id=parsed_id,
        query_secret=parsed_secret,
    )
436+
406437

407438
class Cursor:
408439
def __init__(

0 commit comments

Comments
 (0)