Skip to content

Commit b44b298

Browse files
committed
Refractored the async code
1 parent 8bf4442 commit b44b298

File tree

3 files changed

+55
-26
lines changed

3 files changed

+55
-26
lines changed

src/databricks/sql/client.py

+50-9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import time
12
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
23

34
import pandas
@@ -430,6 +431,8 @@ def __init__(
430431
self.escaper = ParamEscaper()
431432
self.lastrowid = None
432433

434+
self.ASYNC_DEFAULT_POLLING_INTERVAL = 2
435+
433436
# The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently.
434437
def __enter__(self) -> "Cursor":
435438
return self
@@ -733,7 +736,6 @@ def execute(
733736
self,
734737
operation: str,
735738
parameters: Optional[TParameterCollection] = None,
736-
async_op=False,
737739
) -> "Cursor":
738740
"""
739741
Execute a query and wait for execution to complete.
@@ -802,15 +804,14 @@ def execute(
802804
cursor=self,
803805
use_cloud_fetch=self.connection.use_cloud_fetch,
804806
parameters=prepared_params,
805-
async_op=async_op,
807+
async_op=False,
806808
)
807809
self.active_result_set = ResultSet(
808810
self.connection,
809811
execute_response,
810812
self.thrift_backend,
811813
self.buffer_size_bytes,
812814
self.arraysize,
813-
async_op,
814815
)
815816

816817
if execute_response.is_staging_operation:
@@ -829,12 +830,43 @@ def execute_async(
829830
830831
Execute a query and do not wait for it to complete and just move ahead
831832
832-
Internally it calls execute function with async_op=True
833833
:param operation:
834834
:param parameters:
835835
:return:
836836
"""
837-
self.execute(operation, parameters, True)
837+
param_approach = self._determine_parameter_approach(parameters)
838+
if param_approach == ParameterApproach.NONE:
839+
prepared_params = NO_NATIVE_PARAMS
840+
prepared_operation = operation
841+
842+
elif param_approach == ParameterApproach.INLINE:
843+
prepared_operation, prepared_params = self._prepare_inline_parameters(
844+
operation, parameters
845+
)
846+
elif param_approach == ParameterApproach.NATIVE:
847+
normalized_parameters = self._normalize_tparametercollection(parameters)
848+
param_structure = self._determine_parameter_structure(normalized_parameters)
849+
transformed_operation = transform_paramstyle(
850+
operation, normalized_parameters, param_structure
851+
)
852+
prepared_operation, prepared_params = self._prepare_native_parameters(
853+
transformed_operation, normalized_parameters, param_structure
854+
)
855+
856+
self._check_not_closed()
857+
self._close_and_clear_active_result_set()
858+
self.thrift_backend.execute_command(
859+
operation=prepared_operation,
860+
session_handle=self.connection._session_handle,
861+
max_rows=self.arraysize,
862+
max_bytes=self.buffer_size_bytes,
863+
lz4_compression=self.connection.lz4_compression,
864+
cursor=self,
865+
use_cloud_fetch=self.connection.use_cloud_fetch,
866+
parameters=prepared_params,
867+
async_op=True,
868+
)
869+
838870
return self
839871

840872
def get_query_state(self) -> "TOperationState":
@@ -846,15 +878,25 @@ def get_query_state(self) -> "TOperationState":
846878
self._check_not_closed()
847879
return self.thrift_backend.get_query_state(self.active_op_handle)
848880

849-
def get_execution_result(self):
881+
def get_async_execution_result(self):
850882
"""
851883
852884
Checks for the status of the async executing query and fetches the result if the query is finished
853-
If executed sets the active_result_set to the obtained result
885+
Otherwise it will keep polling the status of the query till there is a Not pending state
854886
:return:
855887
"""
856888
self._check_not_closed()
857889

890+
def is_executing(operation_state) -> "bool":
891+
return not operation_state or operation_state in [
892+
ttypes.TOperationState.RUNNING_STATE,
893+
ttypes.TOperationState.PENDING_STATE,
894+
]
895+
896+
while(is_executing(self.get_query_state())):
897+
# Poll after some default time
898+
time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL)
899+
858900
operation_state = self.get_query_state()
859901
if operation_state == ttypes.TOperationState.FINISHED_STATE:
860902
execute_response = self.thrift_backend.get_execution_result(
@@ -1164,7 +1206,6 @@ def __init__(
11641206
thrift_backend: ThriftBackend,
11651207
result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
11661208
arraysize: int = 10000,
1167-
async_op=False,
11681209
):
11691210
"""
11701211
A ResultSet manages the results of a single command.
@@ -1187,7 +1228,7 @@ def __init__(
11871228
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
11881229
self._next_row_index = 0
11891230

1190-
if execute_response.arrow_queue or async_op:
1231+
if execute_response.arrow_queue:
11911232
# In this case the server has taken the fast path and returned an initial batch of
11921233
# results
11931234
self.results = execute_response.arrow_queue

src/databricks/sql/thrift_backend.py

+1-14
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ def execute_command(
914914
resp = self.make_request(self._client.ExecuteStatement, req)
915915

916916
if async_op:
917-
return self._handle_execute_response_async(resp, cursor)
917+
self._handle_execute_response_async(resp, cursor)
918918
else:
919919
return self._handle_execute_response(resp, cursor)
920920

@@ -1018,19 +1018,6 @@ def _handle_execute_response(self, resp, cursor):
10181018
def _handle_execute_response_async(self, resp, cursor):
10191019
cursor.active_op_handle = resp.operationHandle
10201020
self._check_direct_results_for_error(resp.directResults)
1021-
operation_status = resp.status.statusCode
1022-
1023-
return ExecuteResponse(
1024-
arrow_queue=None,
1025-
status=operation_status,
1026-
has_been_closed_server_side=None,
1027-
has_more_rows=None,
1028-
lz4_compressed=None,
1029-
is_staging_operation=None,
1030-
command_handle=resp.operationHandle,
1031-
description=None,
1032-
arrow_schema_bytes=None,
1033-
)
10341021

10351022
def fetch_results(
10361023
self,

tests/e2e/test_driver.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class PySQLPytestTestCase:
7979
}
8080
arraysize = 1000
8181
buffer_size_bytes = 104857600
82+
POLLING_INTERVAL = 2
8283

8384
@pytest.fixture(autouse=True)
8485
def get_details(self, connection_details):
@@ -187,12 +188,12 @@ def isExecuting(operation_state):
187188
with self.cursor() as cursor:
188189
cursor.execute_async(long_running_query)
189190

190-
## Polling after every 10 seconds
191+
## Polling after every POLLING_INTERVAL seconds
191192
while isExecuting(cursor.get_query_state()):
192-
time.sleep(10)
193+
time.sleep(self.POLLING_INTERVAL)
193194
log.info("Polling the status in test_execute_async")
194195

195-
cursor.get_execution_result()
196+
cursor.get_async_execution_result()
196197
result = cursor.fetchall()
197198

198199
assert result[0].asDict() == {"count(1)": 0}

0 commit comments

Comments
 (0)