Skip to content

Commit e74c693

Browse files
author
Jesse Whitehouse
committed
Revert "Attempt to completely remove query_secret and send an empty UUID."
This reverts commit 1c58313. Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent 1c58313 commit e74c693

File tree

3 files changed

+39
-22
lines changed

3 files changed

+39
-22
lines changed

src/databricks/sql/ae.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(
8787
thrift_backend: "ThriftBackend",
8888
connection: "Connection",
8989
query_id: UUID,
90+
query_secret: UUID,
9091
status: Optional[AsyncExecutionStatus] = AsyncExecutionStatus.UNKNOWN,
9192
execute_statement_response: Optional[
9293
Union[FakeExecuteStatementResponse, ttypes.TExecuteStatementResp]
@@ -95,6 +96,7 @@ def __init__(
9596
self._connection = connection
9697
self._thrift_backend = thrift_backend
9798
self.query_id = query_id
99+
self.query_secret = query_secret
98100
self.status = status
99101

100102
if execute_statement_response:
@@ -153,7 +155,7 @@ def _thrift_cancel_operation(self) -> None:
153155
def _thrift_get_operation_status(self) -> ttypes.TGetOperationStatusResp:
154156
"""Execute TGetOperationStatusReq
155157
156-
Raises an AsyncExecutionError if the query_id is not found on the server.
158+
Raises an AsyncExecutionError if the query_id:query_secret pair is not found on the server.
157159
"""
158160
try:
159161
return self._thrift_backend._poll_for_status(self.t_operation_handle)
@@ -164,10 +166,10 @@ def _thrift_get_operation_status(self) -> ttypes.TGetOperationStatusResp:
164166
) from e
165167

166168
def serialize(self) -> str:
167-
"""Return a hex string representing the query_id of this AsyncExecution.
169+
"""Return a string representing the query_id and secret of this AsyncExecution.
168170
169-
Use this to preserve a reference to the query_id"""
170-
return f"{self.query_id}"
171+
Use this to preserve a reference to the query_id and query_secret."""
172+
return f"{self.query_id}:{self.query_secret}"
171173

172174
def sync_status(self) -> None:
173175
"""Synchronise the status of this AsyncExecution with the server query execution state."""
@@ -210,7 +212,7 @@ def t_operation_handle(self) -> ttypes.TOperationHandle:
210212

211213
handle = ttypes.TOperationHandle(
212214
operationId=ttypes.THandleIdentifier(
213-
guid=self.query_id.bytes, secret=UUID(int=0).bytes
215+
guid=self.query_id.bytes, secret=self.query_secret.bytes
214216
),
215217
operationType=ttypes.TOperationType.EXECUTE_STATEMENT,
216218
hasResultSet=True,
@@ -236,6 +238,7 @@ def from_thrift_response(
236238
connection=connection,
237239
thrift_backend=thrift_backend,
238240
query_id=UUID(bytes=resp.operationHandle.operationId.guid),
241+
query_secret=UUID(bytes=resp.operationHandle.operationId.secret),
239242
status=_toperationstate_to_ae_status(
240243
resp.directResults.operationStatus.operationState
241244
),
@@ -248,17 +251,19 @@ def from_query_id_and_secret(
248251
connection: "Connection",
249252
thrift_backend: "ThriftBackend",
250253
query_id: UUID,
254+
query_secret: UUID,
251255
) -> "AsyncExecution":
252-
"""Return a valid AsyncExecution object from a query_id.
256+
"""Return a valid AsyncExecution object from a query_id and query_secret.
253257
254-
Raises an AsyncExecutionException if the query_id pair is not found on the server.
258+
Raises an AsyncExecutionException if the query_id:query_secret pair is not found on the server.
255259
"""
256260

257261
# build a copy of this execution
258262
ae = cls(
259263
connection=connection,
260264
thrift_backend=thrift_backend,
261265
query_id=query_id,
266+
query_secret=query_secret,
262267
)
263268
# check to make sure this is a valid one
264269
ae.sync_status()

src/databricks/sql/client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,12 +405,13 @@ def execute_async(
405405
)
406406

407407
def get_async_execution(
408-
self, query_id: Union[str, UUID]
408+
self, query_id: Union[str, UUID], query_secret: Union[str, UUID]
409409
) -> "AsyncExecution":
410410
"""Get an AsyncExecution object for an existing query.
411411
412412
Args:
413413
query_id: The query id of the query to retrieve
414+
query_secret: The query secret of the query to retrieve
414415
415416
Returns:
416417
An AsyncExecution object that can be used to poll for status and retrieve results.
@@ -421,10 +422,16 @@ def get_async_execution(
421422
else:
422423
_qid = UUID(hex=query_id)
423424

425+
if isinstance(query_secret, UUID):
426+
_qs = query_secret
427+
else:
428+
_qs = UUID(hex=query_secret)
429+
424430
return AsyncExecution.from_query_id_and_secret(
425431
connection=self,
426432
thrift_backend=self.thrift_backend,
427433
query_id=_qid,
434+
query_secret=_qs,
428435
)
429436

430437

tests/e2e/test_execute_async.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,24 +82,28 @@ def test_cant_get_results_after_cancel(self, long_running_ae: AsyncExecution):
8282
def test_get_async_execution_can_check_status(
8383
self, long_running_ae: AsyncExecution
8484
):
85-
query_id = long_running_ae.serialize()
85+
query_id, query_secret = str(long_running_ae.query_id), str(
86+
long_running_ae.query_secret
87+
)
8688

8789
with self.connection() as conn:
88-
ae = conn.get_async_execution(query_id)
90+
ae = conn.get_async_execution(query_id, query_secret)
8991
assert ae.is_running
9092

9193
def test_get_async_execution_can_cancel_across_threads(
9294
self, long_running_ae: AsyncExecution
9395
):
94-
query_id = long_running_ae.serialize()
96+
query_id, query_secret = str(long_running_ae.query_id), str(
97+
long_running_ae.query_secret
98+
)
9599

96-
def cancel_query_in_separate_thread(query_id):
100+
def cancel_query_in_separate_thread(query_id, query_secret):
97101
with self.connection() as conn:
98-
ae = conn.get_async_execution(query_id)
102+
ae = conn.get_async_execution(query_id, query_secret)
99103
ae.cancel()
100104

101105
threading.Thread(
102-
target=cancel_query_in_separate_thread, args=(query_id)
106+
target=cancel_query_in_separate_thread, args=(query_id, query_secret)
103107
).start()
104108

105109
time.sleep(5)
@@ -150,29 +154,30 @@ def test_get_async_execution_with_badly_formed_query_id(self):
150154
ae = conn.get_async_execution("foo", "bar")
151155

152156
def test_serialize(self, long_running_ae: AsyncExecution):
153-
query_id = long_running_ae.serialize()
157+
serialized = long_running_ae.serialize()
158+
query_id, query_secret = serialized.split(":")
154159

155160
with self.connection() as conn:
156-
ae = conn.get_async_execution(query_id)
161+
ae = conn.get_async_execution(query_id, query_secret)
157162
assert ae.is_running
158163

159164
def test_get_async_execution_no_results_when_direct_results_were_sent(self):
160165
"""It remains to be seen whether results can be fetched repeatedly from a "picked up" execution."""
161166

162167
with self.connection() as conn:
163168
ae = conn.execute_async(DIRECT_RESULTS_QUERY, {"param": 1})
164-
query_id = ae.serialize()
169+
query_id, query_secret = ae.serialize().split(":")
165170
ae.get_result()
166171

167172
with self.connection() as conn:
168173
with pytest.raises(AsyncExecutionException, match="Query not found"):
169-
ae_late = conn.get_async_execution(query_id)
174+
ae_late = conn.get_async_execution(query_id, query_secret)
170175

171176
def test_get_async_execution_and_fetch_results(self, long_ish_ae: AsyncExecution):
172-
query_id = long_ish_ae.serialize()
177+
query_id, query_secret = long_ish_ae.serialize().split(":")
173178

174179
with self.connection() as conn:
175-
ae = conn.get_async_execution(query_id)
180+
ae = conn.get_async_execution(query_id, query_secret)
176181

177182
while ae.is_running:
178183
time.sleep(1)
@@ -189,8 +194,8 @@ def test_get_async_execution_twice(self):
189194
with self.connection() as conn_1, self.connection() as conn_2:
190195
ae_1 = conn_1.execute_async(LONG_ISH_QUERY)
191196

192-
query_id = ae_1.serialize()
193-
ae_2 = conn_2.get_async_execution(query_id)
197+
query_id, query_secret = ae_1.serialize().split(":")
198+
ae_2 = conn_2.get_async_execution(query_id, query_secret)
194199

195200
while ae_1.is_running:
196201
time.sleep(1)

0 commit comments

Comments
 (0)