Skip to content

Commit 0c4acba

Browse files
author
Jesse
authored
[PECO-1263] Separate get_status and poll_for_status methods (#313)
Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent 40676b7 commit 0c4acba

File tree

3 files changed

+51
-35
lines changed

3 files changed

+51
-35
lines changed

src/databricks/sql/ae.py

+18 -11
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,10 @@
1515
from databricks.sql.thrift_api.TCLIService import ttypes
1616

1717

18+
class AsyncExecutionException(Exception):
19+
pass
20+
21+
1822
@dataclass
1923
class FakeCursor:
2024
active_op_handle: Optional[ttypes.TOperationHandle]
@@ -50,8 +54,7 @@ def _toperationstate_to_ae_status(
5054

5155
class AsyncExecution:
5256
"""
53-
A class that represents an async execution of a query. Exposes just two methods:
54-
get_result_or_status and cancel
57+
A class that represents an async execution of a query.
5558
5659
AsyncExecutions are effectively connectionless. But because thrift_backend is entangled
5760
with client.py, the AsyncExecution needs access to both a Connection and a ThriftBackend
@@ -83,18 +86,20 @@ def __init__(
8386
status: AsyncExecutionStatus
8487
query_id: UUID
8588

86-
def get_result_or_status(self) -> Union["ResultSet", AsyncExecutionStatus]:
87-
"""Get the result of the async execution. If execution has not completed, return False."""
89+
def get_result(self) -> "ResultSet":
90+
"""Get a result set for this async execution
91+
92+
Raises an exception if the query is still running or has been canceled.
93+
"""
8894

8995
if self.status == AsyncExecutionStatus.CANCELED:
90-
return self.status
96+
raise AsyncExecutionException("Query was canceled: %s" % self.query_id)
97+
if self.is_running:
98+
raise AsyncExecutionException("Query is still running: %s" % self.query_id)
9199
if self.status == AsyncExecutionStatus.FINISHED:
92100
self._thrift_fetch_result()
93101
if self.status == AsyncExecutionStatus.FETCHED:
94102
return self._result_set
95-
else:
96-
self._thrift_get_operation_status()
97-
return self.status
98103

99104
def cancel(self) -> None:
100105
"""Cancel the query"""
@@ -106,11 +111,13 @@ def _thrift_cancel_operation(self) -> None:
106111
_output = self._thrift_backend.async_cancel_command(self.t_operation_handle)
107112
self.status = AsyncExecutionStatus.CANCELED
108113

109-
def _thrift_get_operation_status(self) -> None:
110-
"""Execute GetOperationStatusReq and map thrift execution status to DbsqlAsyncExecutionStatus"""
114+
def poll_for_status(self) -> None:
115+
"""Check the thrift server for the status of this operation and set self.status
116+
117+
This will result in an error if the operation has been canceled or aborted at the server"""
111118

112119
_output = self._thrift_backend._poll_for_status(self.t_operation_handle)
113-
self.status = _toperationstate_to_ae_status(_output)
120+
self.status = _toperationstate_to_ae_status(_output.operationState)
114121

115122
def _thrift_fetch_result(self) -> None:
116123
"""Execute TFetchResultReq and store the result"""

src/databricks/sql/thrift_backend.py

+1 -4
Original file line number | Diff line number | Diff line change
@@ -775,10 +775,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
775775
num_rows,
776776
) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
777777
elif t_row_set.arrowBatches is not None:
778-
(
779-
arrow_table,
780-
num_rows,
781-
) = convert_arrow_based_set_to_arrow_table(
778+
(arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
782779
t_row_set.arrowBatches, lz4_compressed, schema_bytes
783780
)
784781
else:

tests/e2e/test_execute_async.py

+32 -20
Original file line number | Diff line number | Diff line change
@@ -1,50 +1,62 @@
11
from tests.e2e.test_driver import PySQLPytestTestCase
22

3-
from databricks.sql.ae import AsyncExecutionStatus as AsyncExecutionStatus
4-
3+
from databricks.sql.ae import (
4+
AsyncExecutionStatus,
5+
AsyncExecutionException,
6+
AsyncExecution,
7+
)
8+
import pytest
59
import time
610

7-
LONG_RUNNING_QUERY = """
11+
LONG_RUNNING_QUERY = """
812
SELECT SUM(A.id - B.id)
913
FROM range(1000000000) A CROSS JOIN range(100000000) B
1014
GROUP BY (A.id - B.id)
1115
"""
1216

17+
1318
class TestExecuteAsync(PySQLPytestTestCase):
19+
@pytest.fixture
20+
def long_running_ae(self, scope="function") -> AsyncExecution:
21+
"""Start a long-running query so we can make assertions about it."""
22+
with self.connection() as conn:
23+
ae = conn.execute_async(LONG_RUNNING_QUERY)
24+
yield ae
25+
26+
# cancellation is idempotent
27+
ae.cancel()
1428

1529
def test_basic_api(self):
16-
"""This is a WIP test of the basic API defined in PECO-1263
17-
"""
30+
"""This is a WIP test of the basic API defined in PECO-1263"""
1831
# This is taken directly from the design doc
1932

2033
with self.connection() as conn:
2134
ae = conn.execute_async("select :param `col`", {"param": 1})
2235
while ae.is_running:
23-
ae.get_result_or_status()
36+
ae.poll_for_status()
2437
time.sleep(1)
25-
26-
result = ae.get_result_or_status().fetchone()
2738

28-
assert result.col == 1
39+
result = ae.get_result().fetchone()
2940

30-
def test_cancel_running_query(self):
31-
"""Start a long-running query and cancel it
32-
"""
41+
assert result.col == 1
3342

34-
with self.connection() as conn:
35-
ae = conn.execute_async(LONG_RUNNING_QUERY)
36-
time.sleep(2)
37-
ae.cancel()
43+
def test_cancel_running_query(self, long_running_ae: AsyncExecution):
44+
long_running_ae.cancel()
45+
assert long_running_ae.status == AsyncExecutionStatus.CANCELED
3846

39-
status = ae.get_result_or_status()
47+
def test_cant_get_results_while_running(self, long_running_ae: AsyncExecution):
48+
with pytest.raises(AsyncExecutionException, match="Query is still running"):
49+
long_running_ae.get_result()
4050

41-
assert ae.status == AsyncExecutionStatus.CANCELED
51+
def test_cant_get_results_after_cancel(self, long_running_ae: AsyncExecution):
52+
long_running_ae.cancel()
53+
with pytest.raises(AsyncExecutionException, match="Query was canceled"):
54+
long_running_ae.get_result()
4255

43-
4456

4557
def test_staging_operation(self):
4658
"""We need to test what happens with a staging operation since this query won't have a result set
4759
that user needs. It could be sufficient to notify users that they shouldn't use this API for staging/volumes
4860
queries...
4961
"""
50-
assert False
62+
assert False

0 commit comments

Comments
 (0)