diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index 4e0ab941..8ea81e12 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -1,3 +1,4 @@
+import time
 from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
 
 import pandas
@@ -47,6 +48,7 @@
 
 from databricks.sql.thrift_api.TCLIService.ttypes import (
     TSparkParameter,
+    TOperationState,
 )
 
 
@@ -430,6 +432,8 @@ def __init__(
         self.escaper = ParamEscaper()
         self.lastrowid = None
 
+        self.ASYNC_DEFAULT_POLLING_INTERVAL = 2
+
     # The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently.
     def __enter__(self) -> "Cursor":
         return self
@@ -796,6 +800,7 @@ def execute(
             cursor=self,
             use_cloud_fetch=self.connection.use_cloud_fetch,
             parameters=prepared_params,
+            async_op=False,
         )
         self.active_result_set = ResultSet(
             self.connection,
@@ -812,6 +817,114 @@ def execute(
 
         return self
 
+    def execute_async(
+        self,
+        operation: str,
+        parameters: Optional[TParameterCollection] = None,
+    ) -> "Cursor":
+        """
+        Submit a query for execution and return without waiting for it to finish.
+
+        Any active result set is closed and cleared first and no new one is
+        attached here; poll with get_query_state() and collect the outcome with
+        get_async_execution_result().
+
+        :param operation: SQL text to execute.
+        :param parameters: Optional parameters to bind to the operation.
+        :return: This cursor.
+        """
+        param_approach = self._determine_parameter_approach(parameters)
+        if param_approach == ParameterApproach.NONE:
+            prepared_params = NO_NATIVE_PARAMS
+            prepared_operation = operation
+
+        elif param_approach == ParameterApproach.INLINE:
+            prepared_operation, prepared_params = self._prepare_inline_parameters(
+                operation, parameters
+            )
+        elif param_approach == ParameterApproach.NATIVE:
+            normalized_parameters = self._normalize_tparametercollection(parameters)
+            param_structure = self._determine_parameter_structure(normalized_parameters)
+            transformed_operation = transform_paramstyle(
+                operation, normalized_parameters, param_structure
+            )
+            prepared_operation, prepared_params = self._prepare_native_parameters(
+                transformed_operation, normalized_parameters, param_structure
+            )
+
+        self._check_not_closed()
+        self._close_and_clear_active_result_set()
+        self.thrift_backend.execute_command(
+            operation=prepared_operation,
+            session_handle=self.connection._session_handle,
+            max_rows=self.arraysize,
+            max_bytes=self.buffer_size_bytes,
+            lz4_compression=self.connection.lz4_compression,
+            cursor=self,
+            use_cloud_fetch=self.connection.use_cloud_fetch,
+            parameters=prepared_params,
+            async_op=True,
+        )
+
+        return self
+
+    def get_query_state(self) -> "TOperationState":
+        """
+        Poll the backend once for the state of the operation in active_op_handle.
+
+        :return: The operation state reported by the Thrift backend.
+        """
+        self._check_not_closed()
+        return self.thrift_backend.get_query_state(self.active_op_handle)
+
+    def get_async_execution_result(self):
+        """
+        Wait for the asynchronously submitted query and load its result set.
+
+        Polls get_query_state() every ASYNC_DEFAULT_POLLING_INTERVAL seconds
+        while the state is empty, RUNNING or PENDING. If the query ends in the
+        FINISHED state its results are fetched into active_result_set.
+
+        :return: This cursor, with active_result_set populated.
+        :raises Error: If the query stops in any state other than FINISHED.
+        """
+        self._check_not_closed()
+
+        executing_states = (
+            ttypes.TOperationState.RUNNING_STATE,
+            ttypes.TOperationState.PENDING_STATE,
+        )
+
+        # Keep the last polled state so the terminal state is not fetched
+        # again with a second round trip once the loop ends.
+        operation_state = self.get_query_state()
+        while not operation_state or operation_state in executing_states:
+            time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL)
+            operation_state = self.get_query_state()
+
+        if operation_state != ttypes.TOperationState.FINISHED_STATE:
+            raise Error(
+                f"get_execution_result failed with Operation status {operation_state}"
+            )
+
+        execute_response = self.thrift_backend.get_execution_result(
+            self.active_op_handle, self
+        )
+        self.active_result_set = ResultSet(
+            self.connection,
+            execute_response,
+            self.thrift_backend,
+            self.buffer_size_bytes,
+            self.arraysize,
+        )
+
+        if execute_response.is_staging_operation:
+            self._handle_staging_operation(
+                staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path
+            )
+
+        return self
+
     def executemany(self, operation, seq_of_parameters):
         """
Execute the operation once for every set of passed in parameters. diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index cf5cd906..dbfd5936 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -7,6 +7,8 @@ import threading from typing import List, Union +from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState + try: import pyarrow except ImportError: @@ -769,6 +771,63 @@ def _results_message_to_execute_response(self, resp, operation_state): arrow_schema_bytes=schema_bytes, ) + def get_execution_result(self, op_handle, cursor): + + assert op_handle is not None + + req = ttypes.TFetchResultsReq( + operationHandle=ttypes.TOperationHandle( + op_handle.operationId, + op_handle.operationType, + False, + op_handle.modifiedRowCount, + ), + maxRows=cursor.arraysize, + maxBytes=cursor.buffer_size_bytes, + orientation=ttypes.TFetchOrientation.FETCH_NEXT, + includeResultSetMetadata=True, + ) + + resp = self.make_request(self._client.FetchResults, req) + + t_result_set_metadata_resp = resp.resultSetMetadata + + lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + has_more_rows = resp.hasMoreRows + description = self._hive_schema_to_description( + t_result_set_metadata_resp.schema + ) + + schema_bytes = ( + t_result_set_metadata_resp.arrowSchema + or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + .serialize() + .to_pybytes() + ) + + queue = ResultSetQueueFactory.build_queue( + row_set_type=resp.resultSetMetadata.resultFormat, + t_row_set=resp.results, + arrow_schema_bytes=schema_bytes, + max_download_threads=self.max_download_threads, + lz4_compressed=lz4_compressed, + description=description, + ssl_options=self._ssl_options, + ) + + return ExecuteResponse( + arrow_queue=queue, + status=resp.status, + has_been_closed_server_side=False, + has_more_rows=has_more_rows, + 
lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, + command_handle=op_handle, + description=description, + arrow_schema_bytes=schema_bytes, + ) + def _wait_until_command_done(self, op_handle, initial_operation_status_resp): if initial_operation_status_resp: self._check_command_not_in_error_or_closed_state( @@ -787,6 +846,12 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) return operation_state + def get_query_state(self, op_handle) -> "TOperationState": + poll_resp = self._poll_for_status(op_handle) + operation_state = poll_resp.operationState + self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) + return operation_state + @staticmethod def _check_direct_results_for_error(t_spark_direct_results): if t_spark_direct_results: @@ -817,6 +882,7 @@ def execute_command( cursor, use_cloud_fetch=True, parameters=[], + async_op=False, ): assert session_handle is not None @@ -846,7 +912,11 @@ def execute_command( parameters=parameters, ) resp = self.make_request(self._client.ExecuteStatement, req) - return self._handle_execute_response(resp, cursor) + + if async_op: + self._handle_execute_response_async(resp, cursor) + else: + return self._handle_execute_response(resp, cursor) def get_catalogs(self, session_handle, max_rows, max_bytes, cursor): assert session_handle is not None @@ -945,6 +1015,10 @@ def _handle_execute_response(self, resp, cursor): return self._results_message_to_execute_response(resp, final_operation_state) + def _handle_execute_response_async(self, resp, cursor): + cursor.active_op_handle = resp.operationHandle + self._check_direct_results_for_error(resp.directResults) + def fetch_results( self, op_handle, diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index cfd1e969..2f0881cd 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -36,6 +36,7 @@ compare_dbr_versions, 
is_thrift_v5_plus, ) +from databricks.sql.thrift_api.TCLIService import ttypes from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin from tests.e2e.common.large_queries_mixin import LargeQueriesMixin from tests.e2e.common.timestamp_tests import TimestampTestsMixin @@ -78,6 +79,7 @@ class PySQLPytestTestCase: } arraysize = 1000 buffer_size_bytes = 104857600 + POLLING_INTERVAL = 2 @pytest.fixture(autouse=True) def get_details(self, connection_details): @@ -175,6 +177,27 @@ def test_cloud_fetch(self): for i in range(len(cf_result)): assert cf_result[i] == noop_result[i] + def test_execute_async(self): + def isExecuting(operation_state): + return not operation_state or operation_state in [ + ttypes.TOperationState.RUNNING_STATE, + ttypes.TOperationState.PENDING_STATE, + ] + + long_running_query = "SELECT COUNT(*) FROM RANGE(10000 * 16) x JOIN RANGE(10000) y ON FROM_UNIXTIME(x.id * y.id, 'yyyy-MM-dd') LIKE '%not%a%date%'" + with self.cursor() as cursor: + cursor.execute_async(long_running_query) + + ## Polling after every POLLING_INTERVAL seconds + while isExecuting(cursor.get_query_state()): + time.sleep(self.POLLING_INTERVAL) + log.info("Polling the status in test_execute_async") + + cursor.get_async_execution_result() + result = cursor.fetchall() + + assert result[0].asDict() == {"count(1)": 0} + # Exclude Retry tests because they require specific setups, and LargeQueries too slow for core # tests