8
8
from databricks .sql .backend .databricks_client import DatabricksClient
9
9
from databricks .sql .backend .types import SessionId , CommandId , CommandState , BackendType
10
10
from databricks .sql .exc import Error , NotSupportedError
11
- from databricks .sql .sea . http_client import SEAHttpClient
11
+ from databricks .sql .backend . utils . http_client import CustomHttpClient
12
12
from databricks .sql .thrift_api .TCLIService import ttypes
13
13
from databricks .sql .types import SSLOptions
14
14
15
15
logger = logging .getLogger (__name__ )
16
16
17
17
18
- class SEADatabricksClient (DatabricksClient ):
18
+ class SeaDatabricksClient (DatabricksClient ):
19
19
"""
20
20
Statement Execution API (SEA) implementation of the DatabricksClient interface.
21
21
@@ -67,28 +67,10 @@ def __init__(
67
67
self ._max_download_threads = kwargs .get ("max_download_threads" , 10 )
68
68
69
69
# Extract warehouse ID from http_path
70
- # Format could be either:
71
- # - /sql/1.0/endpoints/{warehouse_id}
72
- # - /sql/1.0/warehouses/{warehouse_id}
73
- path_parts = http_path .strip ("/" ).split ("/" )
74
- self .warehouse_id = None
75
-
76
- if len (path_parts ) >= 3 :
77
- if path_parts [- 2 ] in ["endpoints" , "warehouses" ]:
78
- self .warehouse_id = path_parts [- 1 ]
79
- logger .debug (
80
- f"Extracted warehouse ID: { self .warehouse_id } from path: { http_path } "
81
- )
82
-
83
- if not self .warehouse_id :
84
- logger .warning (
85
- "Could not extract warehouse ID from http_path: %s. "
86
- "Session creation may fail if warehouse ID is required." ,
87
- http_path ,
88
- )
70
+ self .warehouse_id = self ._extract_warehouse_id (http_path )
89
71
90
72
# Initialize HTTP client
91
- self .http_client = SEAHttpClient (
73
+ self .http_client = CustomHttpClient (
92
74
server_hostname = server_hostname ,
93
75
port = port ,
94
76
http_path = http_path ,
@@ -98,6 +80,41 @@ def __init__(
98
80
** kwargs ,
99
81
)
100
82
83
+ def _extract_warehouse_id (self , http_path : str ) -> str :
84
+ """
85
+ Extract the warehouse ID from the HTTP path.
86
+
87
+ The warehouse ID is expected to be the last segment of the path when the
88
+ second-to-last segment is either 'warehouses' or 'endpoints'.
89
+ This matches the JDBC implementation which supports both formats.
90
+
91
+ Args:
92
+ http_path: The HTTP path from which to extract the warehouse ID
93
+
94
+ Returns:
95
+ The extracted warehouse ID
96
+
97
+ Raises:
98
+ Error: If the warehouse ID cannot be extracted from the path
99
+ """
100
+ path_parts = http_path .strip ("/" ).split ("/" )
101
+ warehouse_id = None
102
+
103
+ if len (path_parts ) >= 3 and path_parts [- 2 ] in ["warehouses" , "endpoints" ]:
104
+ warehouse_id = path_parts [- 1 ]
105
+ logger .debug (f"Extracted warehouse ID: { warehouse_id } from path: { http_path } " )
106
+
107
+ if not warehouse_id :
108
+ error_message = (
109
+ f"Could not extract warehouse ID from http_path: { http_path } . "
110
+ f"Expected format: /path/to/warehouses/{{warehouse_id}} or "
111
+ f"/path/to/endpoints/{{warehouse_id}}"
112
+ )
113
+ logger .error (error_message )
114
+ raise ValueError (error_message )
115
+
116
+ return warehouse_id
117
+
101
118
@property
102
119
def staging_allowed_local_path (self ) -> Union [None , str , List [str ]]:
103
120
"""Get the allowed local paths for staging operations."""
@@ -115,7 +132,7 @@ def max_download_threads(self) -> int:
115
132
116
133
def open_session (
117
134
self ,
118
- session_configuration : Optional [Dict [str , Any ]],
135
+ session_configuration : Optional [Dict [str , str ]],
119
136
catalog : Optional [str ],
120
137
schema : Optional [str ],
121
138
) -> SessionId :
@@ -141,36 +158,23 @@ def open_session(
141
158
schema ,
142
159
)
143
160
144
- # Prepare request payload
145
- request_data : Dict [str , Any ] = {}
146
-
147
- if self .warehouse_id :
148
- request_data ["warehouse_id" ] = self .warehouse_id
149
-
161
+ request_data : Dict [str , Any ] = {"warehouse_id" : self .warehouse_id }
150
162
if session_configuration :
151
- # The SEA API expects "session_confs" as the key for session configuration
152
163
request_data ["session_confs" ] = session_configuration
153
-
154
164
if catalog :
155
165
request_data ["catalog" ] = catalog
156
-
157
166
if schema :
158
167
request_data ["schema" ] = schema
159
168
160
- # Make API request
161
169
response = self .http_client ._make_request (
162
170
method = "POST" , path = self .SESSION_PATH , data = request_data
163
171
)
164
172
165
- # Extract session ID from response
166
173
session_id = response .get ("session_id" )
167
174
if not session_id :
168
175
raise Error ("Failed to create session: No session ID returned" )
169
176
170
- # Create and return SessionId object
171
- return SessionId .from_sea_session_id (
172
- session_id , {"warehouse_id" : self .warehouse_id }
173
- )
177
+ return SessionId .from_sea_session_id (session_id )
174
178
175
179
def close_session (self , session_id : SessionId ) -> None :
176
180
"""
@@ -185,16 +189,11 @@ def close_session(self, session_id: SessionId) -> None:
185
189
"""
186
190
logger .debug ("SEADatabricksClient.close_session(session_id=%s)" , session_id )
187
191
188
- # Validate session ID
189
192
if session_id .backend_type != BackendType .SEA :
190
193
raise ValueError ("Not a valid SEA session ID" )
191
-
192
194
sea_session_id = session_id .to_sea_session_id ()
193
195
194
- # Make API request with warehouse_id as a query parameter
195
- request_data = {}
196
- if self .warehouse_id :
197
- request_data ["warehouse_id" ] = self .warehouse_id
196
+ request_data = {"warehouse_id" : self .warehouse_id }
198
197
199
198
self .http_client ._make_request (
200
199
method = "DELETE" ,
0 commit comments