Skip to content

Commit 40676b7

Browse files
author
Jesse
authored
[PECO-1263] Extract asyncexecution from thrift backend (#312)
Teach AsyncExecution to cope with results from ThriftBackend This way, thrift_backend.py doesn't need to even know about AsyncExecution This de-complicates the dependency between these two modules. Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent d28e1aa commit 40676b7

File tree

3 files changed

+57
-44
lines changed

3 files changed

+57
-44
lines changed

src/databricks/sql/ae.py

+45-10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Optional, Union, TYPE_CHECKING
33
from databricks.sql.results import ResultSet
44

5+
from dataclasses import dataclass
56

67
if TYPE_CHECKING:
78
from databricks.sql.thrift_backend import ThriftBackend
@@ -14,6 +15,11 @@
1415
from databricks.sql.thrift_api.TCLIService import ttypes
1516

1617

18+
@dataclass
19+
class FakeCursor:
20+
active_op_handle: Optional[ttypes.TOperationHandle]
21+
22+
1723
class AsyncExecutionStatus(Enum):
1824
"""An enum that represents the status of an async execution"""
1925

@@ -22,6 +28,8 @@ class AsyncExecutionStatus(Enum):
2228
FINISHED = 2
2329
CANCELED = 3
2430
FETCHED = 4
31+
32+
# todo: when is this ever evaluated?
2533
ABORTED = 5
2634

2735

@@ -44,21 +52,29 @@ class AsyncExecution:
4452
"""
4553
A class that represents an async execution of a query. Exposes just two methods:
4654
get_result_or_status and cancel
55+
56+
AsyncExecutions are effectively connectionless. But because thrift_backend is entangled
57+
with client.py, the AsyncExecution needs access to both a Connection and a ThriftBackend
58+
59+
This will need to be refactored for cleanliness in the future.
4760
"""
4861

4962
_connection: "Connection"
63+
_thrift_backend: "ThriftBackend"
5064
_result_set: Optional["ResultSet"]
5165
_execute_statement_response: Optional[ttypes.TExecuteStatementResp]
5266

5367
def __init__(
5468
self,
69+
thrift_backend: "ThriftBackend",
5570
connection: "Connection",
5671
query_id: UUID,
5772
query_secret: UUID,
5873
status: AsyncExecutionStatus,
5974
execute_statement_response: Optional[ttypes.TExecuteStatementResp] = None,
6075
):
6176
self._connection = connection
77+
self._thrift_backend = thrift_backend
6278
self._execute_statement_response = execute_statement_response
6379
self.query_id = query_id
6480
self.query_secret = query_secret
@@ -87,13 +103,13 @@ def cancel(self) -> None:
87103
def _thrift_cancel_operation(self) -> None:
88104
"""Execute TCancelOperation"""
89105

90-
_output = self._connection.thrift_backend.async_cancel_command(self.t_operation_handle)
106+
_output = self._thrift_backend.async_cancel_command(self.t_operation_handle)
91107
self.status = AsyncExecutionStatus.CANCELED
92108

93109
def _thrift_get_operation_status(self) -> None:
94110
"""Execute GetOperationStatusReq and map thrift execution status to DbsqlAsyncExecutionStatus"""
95111

96-
_output = self._connection.thrift_backend._poll_for_status(self.t_operation_handle)
112+
_output = self._thrift_backend._poll_for_status(self.t_operation_handle)
97113
self.status = _toperationstate_to_ae_status(_output)
98114

99115
def _thrift_fetch_result(self) -> None:
@@ -104,10 +120,10 @@ def _thrift_fetch_result(self) -> None:
104120
# support JSON and Thrift binary result formats in addition to arrow.
105121

106122
# in the case of direct results this creates a second cursor...how can I avoid that?
107-
with self._connection.cursor() as cursor:
108-
er = self._connection.thrift_backend._handle_execute_response(
109-
self._execute_statement_response, cursor
110-
)
123+
124+
er = self._thrift_backend._handle_execute_response(
125+
self._execute_statement_response, FakeCursor(None)
126+
)
111127

112128
self._result_set = ResultSet(
113129
connection=self._connection,
@@ -123,18 +139,37 @@ def is_running(self) -> bool:
123139
AsyncExecutionStatus.RUNNING,
124140
AsyncExecutionStatus.PENDING,
125141
]
126-
142+
127143
@property
128144
def t_operation_handle(self) -> ttypes.TOperationHandle:
129-
"""Return the current AsyncExecution as a Thrift TOperationHandle
130-
"""
145+
"""Return the current AsyncExecution as a Thrift TOperationHandle"""
131146

132147
handle = ttypes.TOperationHandle(
133148
operationId=ttypes.THandleIdentifier(
134149
guid=self.query_id.bytes, secret=self.query_secret.bytes
135150
),
136151
operationType=ttypes.TOperationType.EXECUTE_STATEMENT,
137-
hasResultSet=True
152+
hasResultSet=True,
138153
)
139154

140155
return handle
156+
157+
@classmethod
158+
def from_thrift_response(
159+
cls,
160+
connection: "Connection",
161+
thrift_backend: "ThriftBackend",
162+
resp: ttypes.TExecuteStatementResp,
163+
) -> "AsyncExecution":
164+
"""This method is meant to be consumed by `client.py`"""
165+
166+
return cls(
167+
connection=connection,
168+
thrift_backend=thrift_backend,
169+
query_id=UUID(bytes=resp.operationHandle.operationId.guid),
170+
query_secret=UUID(bytes=resp.operationHandle.operationId.secret),
171+
status=_toperationstate_to_ae_status(
172+
resp.directResults.operationStatus.operationState
173+
),
174+
execute_statement_response=resp,
175+
)

src/databricks/sql/client.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525

2626
from databricks.sql.results import ResultSet
2727

28-
if TYPE_CHECKING:
29-
from databricks.sql.ae import AsyncExecution, AsyncExecutionStatus
28+
29+
from databricks.sql.ae import AsyncExecution, AsyncExecutionStatus
3030

3131
logger = logging.getLogger(__name__)
3232

@@ -386,7 +386,7 @@ def execute_async(
386386
)
387387

388388
with self.cursor() as cursor:
389-
ae: "AsyncExecution" = self.thrift_backend.async_execute_statement(
389+
execute_statement_resp = self.thrift_backend.async_execute_statement(
390390
statement=prepared_operation,
391391
session_handle=self._session_handle,
392392
max_rows=cursor.arraysize,
@@ -397,8 +397,11 @@ def execute_async(
397397
parameters=prepared_params,
398398
)
399399

400-
# should we log this?
401-
return ae
400+
return AsyncExecution.from_thrift_response(
401+
connection=self,
402+
thrift_backend=self.thrift_backend,
403+
resp=execute_statement_resp,
404+
)
402405

403406

404407
class Cursor:

src/databricks/sql/thrift_backend.py

+4-29
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,10 @@
99
from typing import List, Union, Callable, TYPE_CHECKING, Optional
1010

1111

12-
from databricks.sql.ae import _toperationstate_to_ae_status, AsyncExecution
13-
1412
if TYPE_CHECKING:
1513
from databricks.sql.client import Cursor
1614
from databricks.sql.parameters.native import TSparkParameter
17-
from databricks.sql.ae import AsyncExecution
15+
1816

1917
import pyarrow
2018
import thrift.transport.THttpClient
@@ -430,7 +428,7 @@ def async_execute_statement(
430428
cursor: "Cursor",
431429
use_cloud_fetch: bool = True,
432430
parameters: Optional[List["TSparkParameter"]] = [],
433-
) -> "AsyncExecution":
431+
) -> ttypes.TExecuteStatementResp:
434432
"""Send an ExecuteStatement command to the server, and return an AsyncExecution object.
435433
436434
Args:
@@ -479,28 +477,7 @@ def async_execute_statement(
479477
self._client.ExecuteStatement, req
480478
)
481479

482-
query_id = guid = uuid.UUID(
483-
hex=self.guid_to_hex_id(resp.operationHandle.operationId.guid)
484-
)
485-
486-
secret = uuid.UUID(
487-
hex=self.guid_to_hex_id(resp.operationHandle.operationId.secret)
488-
)
489-
490-
# operationStatus -> TOperationstate
491-
status = _toperationstate_to_ae_status(
492-
resp.directResults.operationStatus.operationState
493-
)
494-
495-
ae = AsyncExecution(
496-
connection=cursor.connection,
497-
query_id=query_id,
498-
query_secret=secret,
499-
status=status,
500-
execute_statement_response=resp,
501-
)
502-
503-
return ae
480+
return resp
504481

505482
# FUTURE: Consider moving to https://github.com/litl/backoff or
506483
# https://github.com/jd/tenacity for retry logic.
@@ -1187,9 +1164,7 @@ def cancel_command(self, active_op_handle):
11871164
req = ttypes.TCancelOperationReq(active_op_handle)
11881165
self.make_request(self._client.CancelOperation, req)
11891166

1190-
def async_cancel_command(
1191-
self, op_handle: ttypes.TOperationHandle
1192-
) -> None:
1167+
def async_cancel_command(self, op_handle: ttypes.TOperationHandle) -> None:
11931168
"""Cancel a query using the thrift operation handle
11941169
11951170
Args:

0 commit comments

Comments
 (0)