Skip to content

Commit 328aeb5

Browse files
authored
[ PECO-2065 ] Create the async execution flow for the PySQL Connector (#463)
* Built the basic flow for the async pipeline - testing is remaining * Implemented the flow for the get_execution_result, but the problem of invalid operation handle still persists * Missed adding some files in previous commit * Working prototype of execute_async, get_query_state and get_execution_result * Added integration tests for execute_async * add docs for functions * Refactored the async code * Fixed java doc * Reformatted
1 parent 43fa964 commit 328aeb5

File tree

3 files changed

+203
-1
lines changed

3 files changed

+203
-1
lines changed

src/databricks/sql/client.py

+105
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import time
12
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
23

34
import pandas
@@ -47,6 +48,7 @@
4748

4849
from databricks.sql.thrift_api.TCLIService.ttypes import (
4950
TSparkParameter,
51+
TOperationState,
5052
)
5153

5254

@@ -430,6 +432,8 @@ def __init__(
430432
self.escaper = ParamEscaper()
431433
self.lastrowid = None
432434

435+
self.ASYNC_DEFAULT_POLLING_INTERVAL = 2
436+
433437
# The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently.
434438
def __enter__(self) -> "Cursor":
435439
return self
@@ -796,6 +800,7 @@ def execute(
796800
cursor=self,
797801
use_cloud_fetch=self.connection.use_cloud_fetch,
798802
parameters=prepared_params,
803+
async_op=False,
799804
)
800805
self.active_result_set = ResultSet(
801806
self.connection,
@@ -812,6 +817,106 @@ def execute(
812817

813818
return self
814819

820+
def execute_async(
    self,
    operation: str,
    parameters: Optional[TParameterCollection] = None,
) -> "Cursor":
    """
    Submit a query for execution without waiting for it to complete.

    Unlike ``execute``, this returns as soon as the server has accepted the
    statement. Use ``get_query_state`` to poll progress and
    ``get_async_execution_result`` to block until the results are available.

    :param operation: SQL text to execute; may contain parameter markers.
    :param parameters: optional parameter collection, rendered inline or
        passed natively depending on the determined parameter approach.
    :return: this cursor, for chaining.
    """
    # Fail fast on a closed cursor before doing any parameter preparation.
    self._check_not_closed()

    param_approach = self._determine_parameter_approach(parameters)
    if param_approach == ParameterApproach.NONE:
        prepared_params = NO_NATIVE_PARAMS
        prepared_operation = operation

    elif param_approach == ParameterApproach.INLINE:
        prepared_operation, prepared_params = self._prepare_inline_parameters(
            operation, parameters
        )
    elif param_approach == ParameterApproach.NATIVE:
        normalized_parameters = self._normalize_tparametercollection(parameters)
        param_structure = self._determine_parameter_structure(normalized_parameters)
        transformed_operation = transform_paramstyle(
            operation, normalized_parameters, param_structure
        )
        prepared_operation, prepared_params = self._prepare_native_parameters(
            transformed_operation, normalized_parameters, param_structure
        )

    self._close_and_clear_active_result_set()
    self.thrift_backend.execute_command(
        operation=prepared_operation,
        session_handle=self.connection._session_handle,
        max_rows=self.arraysize,
        max_bytes=self.buffer_size_bytes,
        lz4_compression=self.connection.lz4_compression,
        cursor=self,
        use_cloud_fetch=self.connection.use_cloud_fetch,
        parameters=prepared_params,
        # async_op=True makes the backend return right after the statement
        # is submitted instead of waiting for it to finish.
        async_op=True,
    )

    return self
867+
868+
def get_query_state(self) -> "TOperationState":
    """
    Poll the server once for the current state of the asynchronously
    executing query.

    :return: the query's current TOperationState (e.g. PENDING_STATE,
        RUNNING_STATE, FINISHED_STATE).
    """
    # Guard against use after close; delegates the actual poll to the
    # thrift backend using the operation handle stored by execute_async.
    self._check_not_closed()
    return self.thrift_backend.get_query_state(self.active_op_handle)
876+
877+
def get_async_execution_result(self):
    """
    Block until the asynchronously executed query finishes, then fetch its
    results.

    Polls the operation state every ``ASYNC_DEFAULT_POLLING_INTERVAL``
    seconds while the query is still pending or running. Once the query
    reaches FINISHED_STATE the results are attached to this cursor as
    ``active_result_set``; any other terminal state raises.

    :return: this cursor, with the result set loaded.
    :raises Error: if the query ended in a non-FINISHED terminal state.
    """
    self._check_not_closed()

    def is_executing(operation_state) -> "bool":
        # A missing state (None) is treated as still executing.
        return not operation_state or operation_state in [
            ttypes.TOperationState.RUNNING_STATE,
            ttypes.TOperationState.PENDING_STATE,
        ]

    # Reuse the state observed by the final poll instead of issuing an
    # extra get_query_state RPC after the loop exits (the original flow
    # polled once more after the loop, which also risked acting on a
    # state newer than the one the loop saw).
    operation_state = self.get_query_state()
    while is_executing(operation_state):
        # Poll after some default time
        time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL)
        operation_state = self.get_query_state()

    if operation_state == ttypes.TOperationState.FINISHED_STATE:
        execute_response = self.thrift_backend.get_execution_result(
            self.active_op_handle, self
        )
        self.active_result_set = ResultSet(
            self.connection,
            execute_response,
            self.thrift_backend,
            self.buffer_size_bytes,
            self.arraysize,
        )

        if execute_response.is_staging_operation:
            self._handle_staging_operation(
                staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path
            )

        return self
    else:
        raise Error(
            f"get_execution_result failed with Operation status {operation_state}"
        )
919+
815920
def executemany(self, operation, seq_of_parameters):
816921
"""
817922
Execute the operation once for every set of passed in parameters.

src/databricks/sql/thrift_backend.py

+75-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import threading
88
from typing import List, Union
99

10+
from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
11+
1012
try:
1113
import pyarrow
1214
except ImportError:
@@ -769,6 +771,63 @@ def _results_message_to_execute_response(self, resp, operation_state):
769771
arrow_schema_bytes=schema_bytes,
770772
)
771773

774+
def get_execution_result(self, op_handle, cursor):
    """
    Fetch the first batch of results and result metadata for a previously
    submitted operation.

    Used by the async execution flow once the operation has finished:
    issues a single FetchResults RPC (with result-set metadata included)
    and wraps the response in an ExecuteResponse.

    :param op_handle: Thrift operation handle of the executed statement.
    :param cursor: cursor whose arraysize / buffer_size_bytes bound the fetch.
    :return: ExecuteResponse carrying status, schema and the result queue.
    """
    assert op_handle is not None

    req = ttypes.TFetchResultsReq(
        operationHandle=ttypes.TOperationHandle(
            op_handle.operationId,
            op_handle.operationType,
            False,
            op_handle.modifiedRowCount,
        ),
        maxRows=cursor.arraysize,
        maxBytes=cursor.buffer_size_bytes,
        orientation=ttypes.TFetchOrientation.FETCH_NEXT,
        includeResultSetMetadata=True,
    )

    resp = self.make_request(self._client.FetchResults, req)

    t_result_set_metadata_resp = resp.resultSetMetadata

    lz4_compressed = t_result_set_metadata_resp.lz4Compressed
    is_staging_operation = t_result_set_metadata_resp.isStagingOperation
    has_more_rows = resp.hasMoreRows
    description = self._hive_schema_to_description(
        t_result_set_metadata_resp.schema
    )

    # Prefer the server-provided Arrow schema; otherwise derive one from
    # the Hive schema and serialize it.
    schema_bytes = (
        t_result_set_metadata_resp.arrowSchema
        or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
        .serialize()
        .to_pybytes()
    )

    queue = ResultSetQueueFactory.build_queue(
        # Read from the already-extracted metadata object, consistent with
        # the lz4Compressed / isStagingOperation reads above (was
        # resp.resultSetMetadata.resultFormat).
        row_set_type=t_result_set_metadata_resp.resultFormat,
        t_row_set=resp.results,
        arrow_schema_bytes=schema_bytes,
        max_download_threads=self.max_download_threads,
        lz4_compressed=lz4_compressed,
        description=description,
        ssl_options=self._ssl_options,
    )

    return ExecuteResponse(
        arrow_queue=queue,
        status=resp.status,
        has_been_closed_server_side=False,
        has_more_rows=has_more_rows,
        lz4_compressed=lz4_compressed,
        is_staging_operation=is_staging_operation,
        command_handle=op_handle,
        description=description,
        arrow_schema_bytes=schema_bytes,
    )
830+
772831
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
773832
if initial_operation_status_resp:
774833
self._check_command_not_in_error_or_closed_state(
@@ -787,6 +846,12 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
787846
self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
788847
return operation_state
789848

849+
def get_query_state(self, op_handle) -> "TOperationState":
    """
    Poll the server once and return the operation's current state.

    Raises (via the error/closed-state check) if the operation is found to
    be in an error or closed state.
    """
    status_resp = self._poll_for_status(op_handle)
    self._check_command_not_in_error_or_closed_state(op_handle, status_resp)
    return status_resp.operationState
854+
790855
@staticmethod
791856
def _check_direct_results_for_error(t_spark_direct_results):
792857
if t_spark_direct_results:
@@ -817,6 +882,7 @@ def execute_command(
817882
cursor,
818883
use_cloud_fetch=True,
819884
parameters=[],
885+
async_op=False,
820886
):
821887
assert session_handle is not None
822888

@@ -846,7 +912,11 @@ def execute_command(
846912
parameters=parameters,
847913
)
848914
resp = self.make_request(self._client.ExecuteStatement, req)
849-
return self._handle_execute_response(resp, cursor)
915+
916+
if async_op:
917+
self._handle_execute_response_async(resp, cursor)
918+
else:
919+
return self._handle_execute_response(resp, cursor)
850920

851921
def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
852922
assert session_handle is not None
@@ -945,6 +1015,10 @@ def _handle_execute_response(self, resp, cursor):
9451015

9461016
return self._results_message_to_execute_response(resp, final_operation_state)
9471017

1018+
def _handle_execute_response_async(self, resp, cursor):
    # Async path: store the operation handle on the cursor so the caller
    # can poll state / fetch results later; do not wait for completion.
    cursor.active_op_handle = resp.operationHandle
    # Surface any error the server already reported in direct results.
    self._check_direct_results_for_error(resp.directResults)
1021+
9481022
def fetch_results(
9491023
self,
9501024
op_handle,

tests/e2e/test_driver.py

+23
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
compare_dbr_versions,
3737
is_thrift_v5_plus,
3838
)
39+
from databricks.sql.thrift_api.TCLIService import ttypes
3940
from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin
4041
from tests.e2e.common.large_queries_mixin import LargeQueriesMixin
4142
from tests.e2e.common.timestamp_tests import TimestampTestsMixin
@@ -78,6 +79,7 @@ class PySQLPytestTestCase:
7879
}
7980
arraysize = 1000
8081
buffer_size_bytes = 104857600
82+
POLLING_INTERVAL = 2
8183

8284
@pytest.fixture(autouse=True)
8385
def get_details(self, connection_details):
@@ -175,6 +177,27 @@ def test_cloud_fetch(self):
175177
for i in range(len(cf_result)):
176178
assert cf_result[i] == noop_result[i]
177179

180+
def test_execute_async(self):
    """End-to-end check of the async flow: submit, poll state, fetch result."""

    def still_running(state):
        # No state reported yet, or an explicitly in-flight state,
        # counts as still running.
        return not state or state in [
            ttypes.TOperationState.RUNNING_STATE,
            ttypes.TOperationState.PENDING_STATE,
        ]

    long_running_query = "SELECT COUNT(*) FROM RANGE(10000 * 16) x JOIN RANGE(10000) y ON FROM_UNIXTIME(x.id * y.id, 'yyyy-MM-dd') LIKE '%not%a%date%'"
    with self.cursor() as cursor:
        cursor.execute_async(long_running_query)

        ## Polling after every POLLING_INTERVAL seconds
        while still_running(cursor.get_query_state()):
            time.sleep(self.POLLING_INTERVAL)
            log.info("Polling the status in test_execute_async")

        cursor.get_async_execution_result()
        result = cursor.fetchall()

        assert result[0].asDict() == {"count(1)": 0}
200+
178201

179202
# Exclude Retry tests because they require specific setups, and LargeQueries too slow for core
180203
# tests

0 commit comments

Comments
 (0)