diff --git a/src/databricks/sql/ae.py b/src/databricks/sql/ae.py index a0308f1b..0751e1bb 100644 --- a/src/databricks/sql/ae.py +++ b/src/databricks/sql/ae.py @@ -1,7 +1,10 @@ from enum import Enum from typing import Optional, Union, TYPE_CHECKING +from databricks.sql.exc import RequestError from databricks.sql.results import ResultSet +from datetime import datetime + from dataclasses import dataclass if TYPE_CHECKING: @@ -19,11 +22,23 @@ class AsyncExecutionException(Exception): pass +class AsyncExecutionUnrecoverableResultException(AsyncExecutionException): + """Raised when a result can never be retrieved for this query id.""" + + pass + + @dataclass class FakeCursor: active_op_handle: Optional[ttypes.TOperationHandle] +@dataclass +class FakeExecuteStatementResponse: + directResults: bool + operationHandle: ttypes.TOperationHandle + + class AsyncExecutionStatus(Enum): """An enum that represents the status of an async execution""" @@ -35,6 +50,7 @@ class AsyncExecutionStatus(Enum): # todo: when is this ever evaluated? ABORTED = 5 + UNKNOWN = 6 def _toperationstate_to_ae_status( @@ -54,18 +70,17 @@ def _toperationstate_to_ae_status( class AsyncExecution: """ - A class that represents an async execution of a query. - - AsyncExecutions are effectively connectionless. But because thrift_backend is entangled - with client.py, the AsyncExecution needs access to both a Connection and a ThriftBackend - - This will need to be refactored for cleanliness in the future. + A handle for a query execution on Databricks. 
""" _connection: "Connection" _thrift_backend: "ThriftBackend" _result_set: Optional["ResultSet"] - _execute_statement_response: Optional[ttypes.TExecuteStatementResp] + _execute_statement_response: Optional[ + Union[FakeExecuteStatementResponse, ttypes.TExecuteStatementResp] + ] + _last_sync_timestamp: Optional[datetime] = None + _result_set: Optional["ResultSet"] = None def __init__( self, @@ -73,33 +88,59 @@ def __init__( connection: "Connection", query_id: UUID, query_secret: UUID, - status: AsyncExecutionStatus, - execute_statement_response: Optional[ttypes.TExecuteStatementResp] = None, + status: Optional[AsyncExecutionStatus] = AsyncExecutionStatus.UNKNOWN, + execute_statement_response: Optional[ + Union[FakeExecuteStatementResponse, ttypes.TExecuteStatementResp] + ] = None, ): self._connection = connection self._thrift_backend = thrift_backend - self._execute_statement_response = execute_statement_response self.query_id = query_id self.query_secret = query_secret self.status = status + if execute_statement_response: + self._execute_statement_response = execute_statement_response + else: + self._execute_statement_response = FakeExecuteStatementResponse( + directResults=False, operationHandle=self.t_operation_handle + ) + status: AsyncExecutionStatus query_id: UUID - def get_result(self) -> "ResultSet": - """Get a result set for this async execution + def get_result( + self, + ) -> "ResultSet": + """Attempt to get the result of this query and set self.status to FETCHED. + + IMPORTANT: Generally, you'll call this method only after checking that the query is finished. + But you can call it at any time. If you call this method while the query is still running, + your code will block indefinitely until the query completes! This will be changed in a + subsequent release (PECO-1291) + + If you have already called get_result successfully, this method will return the same ResultSet + as before without making an additional roundtrip to the server. 
- Raises an exception if the query is still running or has been canceled. + Raises an AsyncExecutionUnrecoverableResultException if the query was canceled or aborted + at the server, so a result will never be available. """ - if self.status == AsyncExecutionStatus.CANCELED: - raise AsyncExecutionException("Query was canceled: %s" % self.query_id) - if self.is_running: - raise AsyncExecutionException("Query is still running: %s" % self.query_id) - if self.status == AsyncExecutionStatus.FINISHED: - self._thrift_fetch_result() - if self.status == AsyncExecutionStatus.FETCHED: - return self._result_set + # this isn't recoverable + if self.status in [AsyncExecutionStatus.ABORTED, AsyncExecutionStatus.CANCELED]: + raise AsyncExecutionUnrecoverableResultException( + "Result for %s is not recoverable. Query status is %s" % (self.query_id, self.status), + ) + + return self._get_result_set() + + def _get_result_set(self) -> "ResultSet": + if self._result_set is None: + self._result_set = self._thrift_fetch_result() + self.status = AsyncExecutionStatus.FETCHED + + return self._result_set def cancel(self) -> None: """Cancel the query""" @@ -111,35 +152,45 @@ def _thrift_cancel_operation(self) -> None: _output = self._thrift_backend.async_cancel_command(self.t_operation_handle) self.status = AsyncExecutionStatus.CANCELED - def poll_for_status(self) -> None: - """Check the thrift server for the status of this operation and set self.status + def _thrift_get_operation_status(self) -> ttypes.TGetOperationStatusResp: + """Execute TGetOperationStatusReq + + Raises an AsyncExecutionException if the query_id:query_secret pair is not found on the server. 
+ """ + try: + return self._thrift_backend._poll_for_status(self.t_operation_handle) + except RequestError as e: + if "RESOURCE_DOES_NOT_EXIST" in e.message: + raise AsyncExecutionException( + "Query not found: %s" % self.query_id + ) from e - This will result in an error if the operation has been canceled or aborted at the server""" + def serialize(self) -> str: + """Return a string representing the query_id and secret of this AsyncExecution. - _output = self._thrift_backend._poll_for_status(self.t_operation_handle) - self.status = _toperationstate_to_ae_status(_output.operationState) + Use this to preserve a reference to the query_id and query_secret.""" + return f"{self.query_id}:{self.query_secret}" - def _thrift_fetch_result(self) -> None: - """Execute TFetchResultReq and store the result""" + def sync_status(self) -> None: + """Synchronise the status of this AsyncExecution with the server query execution state.""" - # A cursor is required here to hook into the thrift_backend result fetching API - # TODO: need to rewrite this to use a generic result fetching API so we can - # support JSON and Thrift binary result formats in addition to arrow. + resp = self._thrift_get_operation_status() + self.status = _toperationstate_to_ae_status(resp.operationState) + self._last_sync_timestamp = datetime.now() - # in the case of direct results this creates a second cursor...how can I avoid that? 
+ def _thrift_fetch_result(self) -> "ResultSet": + """Execute TFetchResultReq""" er = self._thrift_backend._handle_execute_response( self._execute_statement_response, FakeCursor(None) ) - self._result_set = ResultSet( + return ResultSet( connection=self._connection, execute_response=er, thrift_backend=self._connection.thrift_backend, ) - self.status = AsyncExecutionStatus.FETCHED - @property def is_running(self) -> bool: return self.status in [ @@ -147,6 +198,14 @@ def is_running(self) -> bool: AsyncExecutionStatus.PENDING, ] + @property + def is_canceled(self) -> bool: + return self.status == AsyncExecutionStatus.CANCELED + + @property + def is_finished(self) -> bool: + return self.status == AsyncExecutionStatus.FINISHED + @property def t_operation_handle(self) -> ttypes.TOperationHandle: """Return the current AsyncExecution as a Thrift TOperationHandle""" @@ -161,6 +220,11 @@ def t_operation_handle(self) -> ttypes.TOperationHandle: return handle + @property + def last_sync_timestamp(self) -> Optional[datetime]: + """The timestamp of the last time self.status was synced with the server""" + return self._last_sync_timestamp + @classmethod def from_thrift_response( cls, @@ -180,3 +244,28 @@ def from_thrift_response( ), execute_statement_response=resp, ) + + @classmethod + def from_query_id_and_secret( + cls, + connection: "Connection", + thrift_backend: "ThriftBackend", + query_id: UUID, + query_secret: UUID, + ) -> "AsyncExecution": + """Return a valid AsyncExecution object from a query_id and query_secret. + + Raises an AsyncExecutionException if the query_id:query_secret pair is not found on the server. 
+ """ + + # build a copy of this execution + ae = cls( + connection=connection, + thrift_backend=thrift_backend, + query_id=query_id, + query_secret=query_secret, + ) + # check to make sure this is a valid one + ae.sync_status() + + return ae diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index fc7d1bc1..97b86ba0 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -27,6 +27,7 @@ from databricks.sql.ae import AsyncExecution, AsyncExecutionStatus +from uuid import UUID logger = logging.getLogger(__name__) @@ -403,6 +404,36 @@ def execute_async( resp=execute_statement_resp, ) + def get_async_execution( + self, query_id: Union[str, UUID], query_secret: Union[str, UUID] + ) -> "AsyncExecution": + """Get an AsyncExecution object for an existing query. + + Args: + query_id: The query id of the query to retrieve + query_secret: The query secret of the query to retrieve + + Returns: + An AsyncExecution object that can be used to poll for status and retrieve results. 
+ """ + + if isinstance(query_id, UUID): + _qid = query_id + else: + _qid = UUID(hex=query_id) + + if isinstance(query_secret, UUID): + _qs = query_secret + else: + _qs = UUID(hex=query_secret) + + return AsyncExecution.from_query_id_and_secret( + connection=self, + thrift_backend=self.thrift_backend, + query_id=_qid, + query_secret=_qs, + ) + class Cursor: def __init__( diff --git a/tests/e2e/test_execute_async.py b/tests/e2e/test_execute_async.py index b73bd36e..6fde60ea 100644 --- a/tests/e2e/test_execute_async.py +++ b/tests/e2e/test_execute_async.py @@ -8,11 +8,25 @@ import pytest import time -LONG_RUNNING_QUERY = """ +import threading + +BASE_LONG_QUERY = """ SELECT SUM(A.id - B.id) -FROM range(1000000000) A CROSS JOIN range(100000000) B +FROM range({val}) A CROSS JOIN range({val}) B GROUP BY (A.id - B.id) """ +GT_ONE_MINUTE_VALUE = 100000000 + +# Arrived at this value through some manual testing on a serverless SQL warehouse +# The goal here is to have a query that takes longer than five seconds (therefore bypassing directResults) +# but not so long that I can't attempt to fetch its results in a reasonable amount of time +GT_FIVE_SECONDS_VALUE = 90000 + +LONG_RUNNING_QUERY = BASE_LONG_QUERY.format(val=GT_ONE_MINUTE_VALUE) +LONG_ISH_QUERY = BASE_LONG_QUERY.format(val=GT_FIVE_SECONDS_VALUE) + +# This query should always return in < 5 seconds and therefore should be a direct result +DIRECT_RESULTS_QUERY = "select :param `col`" class TestExecuteAsync(PySQLPytestTestCase): @@ -26,37 +40,172 @@ def long_running_ae(self, scope="function") -> AsyncExecution: # cancellation is idempotent ae.cancel() - def test_basic_api(self): + @pytest.fixture + def long_ish_ae(self, scope="function") -> AsyncExecution: + """Start a long-running query so we can make assertions about it.""" + with self.connection() as conn: + ae = conn.execute_async(LONG_ISH_QUERY) + yield ae + + def test_execute_async(self): """This is a WIP test of the basic API defined in PECO-1263""" # This is 
taken directly from the design doc with self.connection() as conn: - ae = conn.execute_async("select :param `col`", {"param": 1}) + ae = conn.execute_async(DIRECT_RESULTS_QUERY, {"param": 1}) while ae.is_running: - ae.poll_for_status() + ae.sync_status() time.sleep(1) result = ae.get_result().fetchone() assert result.col == 1 + def test_direct_results_query_canary(self): + """This test verifies that on the current endpoint, the DIRECT_RESULTS_QUERY returned a thrift operation state + other than FINISHED_STATE. If this test fails, it means the SQL warehouse got slower at executing this query + """ + + with self.connection() as conn: + ae = conn.execute_async(DIRECT_RESULTS_QUERY, {"param": 1}) + assert not ae.is_running + def test_cancel_running_query(self, long_running_ae: AsyncExecution): long_running_ae.cancel() assert long_running_ae.status == AsyncExecutionStatus.CANCELED - def test_cant_get_results_while_running(self, long_running_ae: AsyncExecution): - with pytest.raises(AsyncExecutionException, match="Query is still running"): - long_running_ae.get_result() - def test_cant_get_results_after_cancel(self, long_running_ae: AsyncExecution): long_running_ae.cancel() - with pytest.raises(AsyncExecutionException, match="Query was canceled"): + with pytest.raises(AsyncExecutionException, match="CANCELED"): long_running_ae.get_result() + def test_get_async_execution_can_check_status( + self, long_running_ae: AsyncExecution + ): + query_id, query_secret = str(long_running_ae.query_id), str( + long_running_ae.query_secret + ) - def test_staging_operation(self): - """We need to test what happens with a staging operation since this query won't have a result set - that user needs. It could be sufficient to notify users that they shouldn't use this API for staging/volumes - queries... 
+ with self.connection() as conn: + ae = conn.get_async_execution(query_id, query_secret) + assert ae.is_running + + def test_get_async_execution_can_cancel_across_threads( + self, long_running_ae: AsyncExecution + ): + query_id, query_secret = str(long_running_ae.query_id), str( + long_running_ae.query_secret + ) + + def cancel_query_in_separate_thread(query_id, query_secret): + with self.connection() as conn: + ae = conn.get_async_execution(query_id, query_secret) + ae.cancel() + + threading.Thread( + target=cancel_query_in_separate_thread, args=(query_id, query_secret) + ).start() + + time.sleep(5) + + long_running_ae.sync_status() + assert long_running_ae.status == AsyncExecutionStatus.CANCELED + + def test_long_ish_query_canary(self, long_ish_ae: AsyncExecution): + """This test verifies that on the current endpoint, the LONG_ISH_QUERY requires + at least one sync_status call before it is finished. If this test fails, it means + the SQL warehouse got faster at executing this query and we should increment the value + of GT_FIVE_SECONDS_VALUE + + It would be easier to do this if Databricks SQL had a SLEEP() function :/ """ - assert False + + poll_count = 0 + while long_ish_ae.is_running: + time.sleep(1) + long_ish_ae.sync_status() + poll_count += 1 + + assert poll_count > 0 + + def test_get_async_execution_and_get_results_without_direct_results( + self, long_ish_ae: AsyncExecution + ): + while long_ish_ae.is_running: + time.sleep(1) + long_ish_ae.sync_status() + + result = long_ish_ae.get_result().fetchone() + assert len(result) == 1 + + def test_get_async_execution_with_bogus_query_id(self): + with self.connection() as conn: + with pytest.raises(AsyncExecutionException, match="Query not found"): + ae = conn.get_async_execution( + "bedc786d-64da-45d4-99da-5d3603525803", + "ba469f82-cf3f-454e-b575-f4dcd58dd692", + ) + + def test_get_async_execution_with_badly_formed_query_id(self): + with self.connection() as conn: + with pytest.raises( + ValueError, match="badly 
formed hexadecimal UUID string" + ): + ae = conn.get_async_execution("foo", "bar") + + def test_serialize(self, long_running_ae: AsyncExecution): + serialized = long_running_ae.serialize() + query_id, query_secret = serialized.split(":") + + with self.connection() as conn: + ae = conn.get_async_execution(query_id, query_secret) + assert ae.is_running + + def test_get_async_execution_no_results_when_direct_results_were_sent(self): + """It remains to be seen whether results can be fetched repeatedly from a "picked up" execution.""" + + with self.connection() as conn: + ae = conn.execute_async(DIRECT_RESULTS_QUERY, {"param": 1}) + query_id, query_secret = ae.serialize().split(":") + ae.get_result() + + with self.connection() as conn: + with pytest.raises(AsyncExecutionException, match="Query not found"): + ae_late = conn.get_async_execution(query_id, query_secret) + + def test_get_async_execution_and_fetch_results(self, long_ish_ae: AsyncExecution): + query_id, query_secret = long_ish_ae.serialize().split(":") + + with self.connection() as conn: + ae = conn.get_async_execution(query_id, query_secret) + + while ae.is_running: + time.sleep(1) + ae.sync_status() + + result = ae.get_result().fetchone() + + assert len(result) == 1 + + def test_get_async_execution_twice(self): + """This test demonstrates that the original AsyncExecution object can fetch a result + and a separate AsyncExecution object can also fetch a result. + """ + with self.connection() as conn_1, self.connection() as conn_2: + ae_1 = conn_1.execute_async(LONG_ISH_QUERY) + + query_id, query_secret = ae_1.serialize().split(":") + ae_2 = conn_2.get_async_execution(query_id, query_secret) + + while ae_1.is_running: + time.sleep(1) + ae_1.sync_status() + + result_1 = ae_1.get_result().fetchone() + assert len(result_1) == 1 + + ae_2.sync_status() + assert ae_2.status == AsyncExecutionStatus.FINISHED + + result_2 = ae_2.get_result().fetchone() + assert len(result_2) == 1