Skip to content

Commit 756ac17

Browse files
committed
Working prototype of execute_async, get_query_state and get_execution_result
1 parent 925b2a3 commit 756ac17

File tree

3 files changed

+37
-44
lines changed

3 files changed

+37
-44
lines changed

src/databricks/sql/client.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,7 @@ def execute(
733733
self,
734734
operation: str,
735735
parameters: Optional[TParameterCollection] = None,
736-
perform_async = False
736+
async_op=False,
737737
) -> "Cursor":
738738
"""
739739
Execute a query and wait for execution to complete.
@@ -797,14 +797,15 @@ def execute(
797797
cursor=self,
798798
use_cloud_fetch=self.connection.use_cloud_fetch,
799799
parameters=prepared_params,
800-
perform_async=perform_async,
800+
async_op=async_op,
801801
)
802802
self.active_result_set = ResultSet(
803803
self.connection,
804804
execute_response,
805805
self.thrift_backend,
806806
self.buffer_size_bytes,
807807
self.arraysize,
808+
async_op,
808809
)
809810

810811
if execute_response.is_staging_operation:
@@ -814,21 +815,25 @@ def execute(
814815

815816
return self
816817

817-
def execute_async(self,
818-
operation: str,
819-
parameters: Optional[TParameterCollection] = None,):
818+
def execute_async(
819+
self,
820+
operation: str,
821+
parameters: Optional[TParameterCollection] = None,
822+
):
820823
return self.execute(operation, parameters, True)
821824

822-
def get_query_status(self):
825+
def get_query_state(self):
823826
self._check_not_closed()
824-
return self.thrift_backend.get_query_status(self.active_op_handle)
827+
return self.thrift_backend.get_query_state(self.active_op_handle)
825828

826829
def get_execution_result(self):
827830
self._check_not_closed()
828831

829-
operation_state = self.get_query_status()
830-
if operation_state.statusCode == ttypes.TStatusCode.SUCCESS_STATUS or operation_state.statusCode == ttypes.TStatusCode.SUCCESS_WITH_INFO_STATUS:
831-
execute_response=self.thrift_backend.get_execution_result(self.active_op_handle)
832+
operation_state = self.get_query_state()
833+
if operation_state == ttypes.TOperationState.FINISHED_STATE:
834+
execute_response = self.thrift_backend.get_execution_result(
835+
self.active_op_handle, self
836+
)
832837
self.active_result_set = ResultSet(
833838
self.connection,
834839
execute_response,
@@ -844,7 +849,9 @@ def get_execution_result(self):
844849

845850
return self
846851
else:
847-
raise Error(f"get_execution_result failed with status code {operation_state.statusCode}")
852+
raise Error(
853+
f"get_execution_result failed with Operation status {operation_state}"
854+
)
848855

849856
def executemany(self, operation, seq_of_parameters):
850857
"""
@@ -1131,6 +1138,7 @@ def __init__(
11311138
thrift_backend: ThriftBackend,
11321139
result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
11331140
arraysize: int = 10000,
1141+
async_op=False,
11341142
):
11351143
"""
11361144
A ResultSet manages the results of a single command.
@@ -1153,7 +1161,7 @@ def __init__(
11531161
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
11541162
self._next_row_index = 0
11551163

1156-
if execute_response.arrow_queue or True:
1164+
if execute_response.arrow_queue or async_op:
11571165
# In this case the server has taken the fast path and returned an initial batch of
11581166
# results
11591167
self.results = execute_response.arrow_queue

src/databricks/sql/constants.py

-12
This file was deleted.

src/databricks/sql/thrift_backend.py

+17-20
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
769769
arrow_schema_bytes=schema_bytes,
770770
)
771771

772-
def get_execution_result(self, op_handle):
772+
def get_execution_result(self, op_handle, cursor):
773773

774774
assert op_handle is not None
775775

@@ -780,15 +780,15 @@ def get_execution_result(self, op_handle):
780780
False,
781781
op_handle.modifiedRowCount,
782782
),
783-
maxRows=max_rows,
784-
maxBytes=max_bytes,
783+
maxRows=cursor.arraysize,
784+
maxBytes=cursor.buffer_size_bytes,
785785
orientation=ttypes.TFetchOrientation.FETCH_NEXT,
786786
includeResultSetMetadata=True,
787787
)
788788

789789
resp = self.make_request(self._client.FetchResults, req)
790790

791-
t_result_set_metadata_resp = resp.resultSetMetaData
791+
t_result_set_metadata_resp = resp.resultSetMetadata
792792

793793
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
794794
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
@@ -797,15 +797,12 @@ def get_execution_result(self, op_handle):
797797
t_result_set_metadata_resp.schema
798798
)
799799

800-
if pyarrow:
801-
schema_bytes = (
802-
t_result_set_metadata_resp.arrowSchema
803-
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
804-
.serialize()
805-
.to_pybytes()
806-
)
807-
else:
808-
schema_bytes = None
800+
schema_bytes = (
801+
t_result_set_metadata_resp.arrowSchema
802+
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
803+
.serialize()
804+
.to_pybytes()
805+
)
809806

810807
queue = ResultSetQueueFactory.build_queue(
811808
row_set_type=resp.resultSetMetadata.resultFormat,
@@ -820,11 +817,11 @@ def get_execution_result(self, op_handle):
820817
return ExecuteResponse(
821818
arrow_queue=queue,
822819
status=resp.status,
823-
has_been_closed_server_side=has_been_closed_server_side,
820+
has_been_closed_server_side=False,
824821
has_more_rows=has_more_rows,
825822
lz4_compressed=lz4_compressed,
826823
is_staging_operation=is_staging_operation,
827-
command_handle=resp.operationHandle,
824+
command_handle=op_handle,
828825
description=description,
829826
arrow_schema_bytes=schema_bytes,
830827
)
@@ -847,9 +844,9 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
847844
self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
848845
return operation_state
849846

850-
def get_query_status(self, op_handle):
847+
def get_query_state(self, op_handle):
851848
poll_resp = self._poll_for_status(op_handle)
852-
operation_state = poll_resp.status
849+
operation_state = poll_resp.operationState
853850
self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
854851
return operation_state
855852

@@ -883,7 +880,7 @@ def execute_command(
883880
cursor,
884881
use_cloud_fetch=True,
885882
parameters=[],
886-
perform_async=False,
883+
async_op=False,
887884
):
888885
assert session_handle is not None
889886

@@ -914,7 +911,7 @@ def execute_command(
914911
)
915912
resp = self.make_request(self._client.ExecuteStatement, req)
916913

917-
if perform_async:
914+
if async_op:
918915
return self._handle_execute_response_async(resp, cursor)
919916
else:
920917
return self._handle_execute_response(resp, cursor)
@@ -1012,7 +1009,7 @@ def _handle_execute_response(self, resp, cursor):
10121009
final_operation_state = self._wait_until_command_done(
10131010
resp.operationHandle,
10141011
resp.directResults and resp.directResults.operationStatus,
1015-
)
1012+
)
10161013

10171014
return self._results_message_to_execute_response(resp, final_operation_state)
10181015

0 commit comments

Comments
 (0)