Skip to content

Commit 7b0cbed

Browse files
1 parent dac08f2 commit 7b0cbed

File tree

4 files changed

+38
-12
lines changed

4 files changed

+38
-12
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Dict, Tuple, List, Optional, Any, Union
33

44
from databricks.sql.thrift_api.TCLIService import ttypes
5-
from databricks.sql.backend.types import SessionId, CommandId
5+
from databricks.sql.backend.types import SessionId, CommandId, CommandState
66
from databricks.sql.utils import ExecuteResponse
77
from databricks.sql.types import SSLOptions
88

@@ -54,7 +54,7 @@ def close_command(self, command_id: CommandId) -> ttypes.TStatus:
5454
pass
5555

5656
@abstractmethod
57-
def get_query_state(self, command_id: CommandId) -> ttypes.TOperationState:
57+
def get_query_state(self, command_id: CommandId) -> CommandState:
5858
pass
5959

6060
@abstractmethod

src/databricks/sql/backend/thrift_backend.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
1111
from databricks.sql.backend.types import (
12+
CommandState,
1213
SessionId,
1314
CommandId,
1415
BackendType,
@@ -903,15 +904,15 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
903904
self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
904905
return operation_state
905906

906-
def get_query_state(self, command_id: CommandId) -> "TOperationState":
907+
def get_query_state(self, command_id: CommandId) -> CommandState:
907908
thrift_handle = command_id.to_thrift_handle()
908909
if not thrift_handle:
909910
raise ValueError("Not a valid Thrift command ID")
910911

911912
poll_resp = self._poll_for_status(thrift_handle)
912913
operation_state = poll_resp.operationState
913914
self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp)
914-
return operation_state
915+
return CommandState.from_thrift_state(operation_state)
915916

916917
@staticmethod
917918
def _check_direct_results_for_error(t_spark_direct_results):

src/databricks/sql/backend/types.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,38 @@
33
import uuid
44
import logging
55

6+
from databricks.sql.thrift_api.TCLIService import ttypes
7+
68
logger = logging.getLogger(__name__)
79

810

11+
class CommandState(Enum):
12+
PENDING = "PENDING"
13+
RUNNING = "RUNNING"
14+
SUCCEEDED = "SUCCEEDED"
15+
FAILED = "FAILED"
16+
CLOSED = "CLOSED"
17+
CANCELLED = "CANCELLED"
18+
19+
@classmethod
20+
def from_thrift_state(cls, state: ttypes.TOperationState) -> "CommandState":
21+
match state:
22+
case ttypes.TOperationState.INITIALIZED_STATE | ttypes.TOperationState.PENDING_STATE:
23+
return cls.PENDING
24+
case ttypes.TOperationState.RUNNING_STATE:
25+
return cls.RUNNING
26+
case ttypes.TOperationState.FINISHED_STATE:
27+
return cls.SUCCEEDED
28+
case ttypes.TOperationState.ERROR_STATE | ttypes.TOperationState.TIMEDOUT_STATE | ttypes.TOperationState.UKNOWN_STATE:
29+
return cls.FAILED
30+
case ttypes.TOperationState.CLOSED_STATE:
31+
return cls.CLOSED
32+
case ttypes.TOperationState.CANCELLED_STATE:
33+
return cls.CANCELLED
34+
case _:
35+
raise ValueError(f"Unknown command state: {state}")
36+
37+
938
def guid_to_hex_id(guid: bytes) -> str:
1039
"""Return a hexadecimal string instead of bytes
1140

src/databricks/sql/client.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
4848
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
4949
from databricks.sql.session import Session
50-
from databricks.sql.backend.types import CommandId, BackendType, SessionId
50+
from databricks.sql.backend.types import CommandId, BackendType, CommandState, SessionId
5151

5252
from databricks.sql.thrift_api.TCLIService.ttypes import (
5353
TSparkParameter,
@@ -846,7 +846,7 @@ def execute_async(
846846

847847
return self
848848

849-
def get_query_state(self) -> "TOperationState":
849+
def get_query_state(self) -> CommandState:
850850
"""
851851
Get the state of the async executing query or basically poll the status of the query
852852
@@ -862,11 +862,7 @@ def is_query_pending(self):
862862
:return:
863863
"""
864864
operation_state = self.get_query_state()
865-
866-
return not operation_state or operation_state in [
867-
ttypes.TOperationState.RUNNING_STATE,
868-
ttypes.TOperationState.PENDING_STATE,
869-
]
865+
return operation_state in [CommandState.PENDING, CommandState.RUNNING]
870866

871867
def get_async_execution_result(self):
872868
"""
@@ -882,7 +878,7 @@ def get_async_execution_result(self):
882878
time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL)
883879

884880
operation_state = self.get_query_state()
885-
if operation_state == ttypes.TOperationState.FINISHED_STATE:
881+
if operation_state == CommandState.SUCCEEDED:
886882
self.active_result_set = self.backend.get_execution_result(
887883
self.active_op_handle, self
888884
)

0 commit comments

Comments
 (0)