Skip to content

Commit 7f6073d

Browse files
Merge remote-tracking branch 'origin/sea-migration' into fetch-interface
2 parents c91bc37 + 400a8bd commit 7f6073d

File tree

9 files changed

+102
-99
lines changed

9 files changed

+102
-99
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33
44
Implementations of this class are responsible for:
55
- Managing connections to Databricks SQL services
6-
- Handling authentication
76
- Executing SQL queries and commands
87
- Retrieving query results
98
- Fetching metadata about catalogs, schemas, tables, and columns
10-
- Managing error handling and retries
119
"""
1210

1311
from abc import ABC, abstractmethod
@@ -110,11 +108,44 @@ def cancel_command(self, command_id: CommandId) -> None:
110108
pass
111109

112110
@abstractmethod
113-
def close_command(self, command_id: CommandId) -> None:
111+
def close_command(self, command_id: CommandId) -> ttypes.TStatus:
112+
"""
113+
Closes a command and releases associated resources.
114+
115+
This method informs the server that the client is done with the command
116+
and any resources associated with it can be released.
117+
118+
Args:
119+
command_id: The command identifier to close
120+
121+
Returns:
122+
ttypes.TStatus: The status of the close operation
123+
124+
Raises:
125+
ValueError: If the command ID is invalid
126+
OperationalError: If there's an error closing the command
127+
"""
114128
pass
115129

116130
@abstractmethod
117-
def get_query_state(self, command_id: CommandId) -> CommandState:
131+
def get_query_state(self, command_id: CommandId) -> ttypes.TOperationState:
132+
"""
133+
Gets the current state of a query or command.
134+
135+
This method retrieves the current execution state of a command from the server.
136+
137+
Args:
138+
command_id: The command identifier to check
139+
140+
Returns:
141+
ttypes.TOperationState: The current state of the command
142+
143+
Raises:
144+
ValueError: If the command ID is invalid
145+
OperationalError: If there's an error retrieving the state
146+
ServerOperationError: If the command is in an error state
147+
DatabaseError: If the command has been closed unexpectedly
148+
"""
118149
pass
119150

120151
@abstractmethod
@@ -173,30 +204,29 @@ def get_columns(
173204
schema_name: Optional[str] = None,
174205
table_name: Optional[str] = None,
175206
column_name: Optional[str] = None,
176-
) -> "ResultSet":
177-
pass
178-
179-
# == Properties ==
180-
@property
181-
@abstractmethod
182-
def staging_allowed_local_path(self) -> Union[None, str, List[str]]:
207+
) -> ExecuteResponse:
183208
"""
184-
Gets the allowed local paths for staging operations.
209+
Retrieves a list of columns, optionally filtered by catalog, schema, table, and column name patterns.
185210
186-
Returns:
187-
Union[None, str, List[str]]: The allowed local paths for staging operations,
188-
or None if staging is not allowed
189-
"""
190-
pass
211+
This method fetches metadata about columns available in the specified table,
212+
or all tables if not specified.
191213
192-
@property
193-
@abstractmethod
194-
def ssl_options(self) -> SSLOptions:
195-
"""
196-
Gets the SSL options for this client.
214+
Args:
215+
session_id: The session identifier
216+
max_rows: Maximum number of rows to fetch in a single batch
217+
max_bytes: Maximum number of bytes to fetch in a single batch
218+
cursor: The cursor object that will handle the results
219+
catalog_name: Optional catalog name pattern to filter by
220+
schema_name: Optional schema name pattern to filter by
221+
table_name: Optional table name pattern to filter by
222+
column_name: Optional column name pattern to filter by
197223
198224
Returns:
199-
SSLOptions: The SSL configuration options
225+
ExecuteResponse: An object containing the column metadata
226+
227+
Raises:
228+
ValueError: If the session ID is invalid
229+
OperationalError: If there's an error retrieving the columns
200230
"""
201231
pass
202232

src/databricks/sql/backend/thrift_backend.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,18 @@
55
import time
66
import uuid
77
import threading
8-
from typing import List, Union, Any, TYPE_CHECKING
8+
from typing import List, Optional, Union, Any, TYPE_CHECKING
99

1010
if TYPE_CHECKING:
1111
from databricks.sql.client import Cursor
1212

1313
from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
1414
from databricks.sql.backend.types import (
15-
CommandState,
1615
SessionId,
1716
CommandId,
1817
BackendType,
19-
guid_to_hex_id,
2018
)
19+
from databricks.sql.backend.utils import guid_to_hex_id
2120

2221
try:
2322
import pyarrow
@@ -103,7 +102,6 @@ def __init__(
103102
http_headers,
104103
auth_provider: AuthProvider,
105104
ssl_options: SSLOptions,
106-
staging_allowed_local_path: Union[None, str, List[str]] = None,
107105
**kwargs,
108106
):
109107
# Internal arguments in **kwargs:
@@ -162,7 +160,6 @@ def __init__(
162160
else:
163161
raise ValueError("No valid connection settings.")
164162

165-
self._staging_allowed_local_path = staging_allowed_local_path
166163
self._initialize_retry_args(kwargs)
167164
self._use_arrow_native_complex_types = kwargs.get(
168165
"_use_arrow_native_complex_types", True
@@ -236,14 +233,6 @@ def __init__(
236233

237234
self._request_lock = threading.RLock()
238235

239-
@property
240-
def staging_allowed_local_path(self) -> Union[None, str, List[str]]:
241-
return self._staging_allowed_local_path
242-
243-
@property
244-
def ssl_options(self) -> SSLOptions:
245-
return self._ssl_options
246-
247236
@property
248237
def max_download_threads(self) -> int:
249238
return self._max_download_threads

src/databricks/sql/backend/types.py

Lines changed: 14 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,16 @@
11
from enum import Enum
22
from typing import Dict, Optional, Any, Union
3-
import uuid
43
import logging
54

6-
from databricks.sql.thrift_api.TCLIService import ttypes
5+
from databricks.sql.backend.utils import guid_to_hex_id
76

87
logger = logging.getLogger(__name__)
98

109

11-
class CommandState(Enum):
12-
PENDING = "PENDING"
13-
RUNNING = "RUNNING"
14-
SUCCEEDED = "SUCCEEDED"
15-
FAILED = "FAILED"
16-
CLOSED = "CLOSED"
17-
CANCELLED = "CANCELLED"
18-
19-
@classmethod
20-
def from_thrift_state(cls, state: ttypes.TOperationState) -> "CommandState":
21-
if state in (
22-
ttypes.TOperationState.INITIALIZED_STATE,
23-
ttypes.TOperationState.PENDING_STATE,
24-
):
25-
return cls.PENDING
26-
elif state == ttypes.TOperationState.RUNNING_STATE:
27-
return cls.RUNNING
28-
elif state == ttypes.TOperationState.FINISHED_STATE:
29-
return cls.SUCCEEDED
30-
elif state in (
31-
ttypes.TOperationState.ERROR_STATE,
32-
ttypes.TOperationState.TIMEDOUT_STATE,
33-
ttypes.TOperationState.UKNOWN_STATE,
34-
):
35-
return cls.FAILED
36-
elif state == ttypes.TOperationState.CLOSED_STATE:
37-
return cls.CLOSED
38-
elif state == ttypes.TOperationState.CANCELED_STATE:
39-
return cls.CANCELLED
40-
else:
41-
raise ValueError(f"Unknown command state: {state}")
42-
43-
44-
def guid_to_hex_id(guid: bytes) -> str:
45-
"""Return a hexadecimal string instead of bytes
46-
47-
Example:
48-
IN b'\x01\xee\x1d)\xa4\x19\x1d\xb6\xa9\xc0\x8d\xf1\xfe\xbaB\xdd'
49-
OUT '01ee1d29-a419-1db6-a9c0-8df1feba42dd'
50-
51-
If conversion to hexadecimal fails, the original bytes are returned
52-
"""
53-
try:
54-
this_uuid = uuid.UUID(bytes=guid)
55-
except Exception as e:
56-
logger.debug(f"Unable to convert bytes to UUID: {guid!r} -- {str(e)}")
57-
return str(guid)
58-
return str(this_uuid)
59-
60-
6110
class BackendType(Enum):
62-
"""Enum representing the type of backend."""
11+
"""
12+
Enum representing the type of backend
13+
"""
6314

6415
THRIFT = "thrift"
6516
SEA = "sea"
@@ -87,7 +38,7 @@ def __init__(
8738
backend_type: The type of backend (THRIFT or SEA)
8839
guid: The primary identifier for the session
8940
secret: The secret part of the identifier (only used for Thrift)
90-
info: Additional information about the session
41+
properties: Additional information about the session
9142
"""
9243
self.backend_type = backend_type
9344
self.guid = guid
@@ -107,7 +58,12 @@ def __str__(self) -> str:
10758
if self.backend_type == BackendType.SEA:
10859
return str(self.guid)
10960
elif self.backend_type == BackendType.THRIFT:
110-
return f"{self.get_hex_id()}|{guid_to_hex_id(self.secret) if isinstance(self.secret, bytes) else str(self.secret)}"
61+
secret_hex = (
62+
guid_to_hex_id(self.secret)
63+
if isinstance(self.secret, bytes)
64+
else str(self.secret)
65+
)
66+
return f"{self.get_hex_guid()}|{secret_hex}"
11167
return str(self.guid)
11268

11369
@classmethod
@@ -181,13 +137,13 @@ def to_sea_session_id(self):
181137

182138
return self.guid
183139

184-
def get_id(self) -> Any:
140+
def get_guid(self) -> Any:
185141
"""
186142
Get the ID of the session.
187143
"""
188144
return self.guid
189145

190-
def get_hex_id(self) -> str:
146+
def get_hex_guid(self) -> str:
191147
"""
192148
Get a hexadecimal string representation of the session ID.
193149
@@ -316,7 +272,7 @@ def to_sea_statement_id(self):
316272

317273
return self.guid
318274

319-
def to_hex_id(self) -> str:
275+
def to_hex_guid(self) -> str:
320276
"""
321277
Get a hexadecimal string representation of the command ID.
322278
src/databricks/sql/backend/utils/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .guid_utils import guid_to_hex_id
2+
3+
__all__ = ["guid_to_hex_id"]
src/databricks/sql/backend/utils/guid_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import uuid
2+
import logging
3+
4+
logger = logging.getLogger(__name__)
5+
6+
7+
def guid_to_hex_id(guid: bytes) -> str:
8+
"""Return a hexadecimal string instead of bytes
9+
10+
Example:
11+
IN b'\x01\xee\x1d)\xa4\x19\x1d\xb6\xa9\xc0\x8d\xf1\xfe\xbaB\xdd'
12+
OUT '01ee1d29-a419-1db6-a9c0-8df1feba42dd'
13+
14+
If conversion to hexadecimal fails, a string representation of the original
15+
bytes is returned
16+
"""
17+
try:
18+
this_uuid = uuid.UUID(bytes=guid)
19+
except Exception as e:
20+
logger.debug(f"Unable to convert bytes to UUID: {guid!r} -- {str(e)}")
21+
return str(guid)
22+
return str(this_uuid)

src/databricks/sql/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def read(self) -> Optional[OAuthToken]:
247247
self.use_inline_params = self._set_use_inline_params_with_warning(
248248
kwargs.get("use_inline_params", False)
249249
)
250+
self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None)
250251

251252
def _set_use_inline_params_with_warning(self, value: Union[bool, str]):
252253
"""Valid values are True, False, and "silent"
@@ -785,7 +786,7 @@ def execute(
785786

786787
if self.active_result_set and self.active_result_set.is_staging_operation:
787788
self._handle_staging_operation(
788-
staging_allowed_local_path=self.backend.staging_allowed_local_path
789+
staging_allowed_local_path=self.connection.staging_allowed_local_path
789790
)
790791

791792
return self
@@ -881,7 +882,7 @@ def get_async_execution_result(self):
881882

882883
if self.active_result_set and self.active_result_set.is_staging_operation:
883884
self._handle_staging_operation(
884-
staging_allowed_local_path=self.backend.staging_allowed_local_path
885+
staging_allowed_local_path=self.connection.staging_allowed_local_path
885886
)
886887

887888
return self
@@ -1106,7 +1107,7 @@ def query_id(self) -> Optional[str]:
11061107
invoked via the execute method yet, or if cursor was closed.
11071108
"""
11081109
if self.active_command_id is not None:
1109-
return self.active_command_id.to_hex_id()
1110+
return self.active_command_id.to_hex_guid()
11101111
return None
11111112

11121113
@property

src/databricks/sql/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ def get_session_id(self) -> SessionId:
116116

117117
def get_id(self):
118118
"""Get the raw session ID (backend-specific)"""
119-
return self._session_id.get_id()
119+
return self._session_id.get_guid()
120120

121121
def get_id_hex(self) -> str:
122122
"""Get the session ID in hex format"""
123-
return self._session_id.get_hex_id()
123+
return self._session_id.get_hex_guid()
124124

125125
def close(self) -> None:
126126
"""Close the underlying session."""

tests/unit/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
133133

134134
result_set = ThriftResultSet(
135135
connection=mock_connection,
136+
backend=mock_backend,
136137
execute_response=Mock(),
137138
thrift_client=mock_backend,
138139
)

tests/unit/test_thrift_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class ThriftBackendTestSuite(unittest.TestCase):
5353
open_session_resp = ttypes.TOpenSessionResp(
5454
status=okay_status,
5555
serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4,
56+
sessionHandle=session_handle,
5657
)
5758

5859
metadata_resp = ttypes.TGetResultSetMetadataResp(

0 commit comments

Comments
 (0)