Skip to content

Commit 16ff4ec

Browse files
cleanup: removed excess comments, validated decisions
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent d59880c commit 16ff4ec

File tree

4 files changed

+52
-67
lines changed

4 files changed

+52
-67
lines changed

src/databricks/sql/sea/sea_backend.py renamed to src/databricks/sql/backend/sea_backend.py

Lines changed: 43 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
from databricks.sql.backend.databricks_client import DatabricksClient
99
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
1010
from databricks.sql.exc import Error, NotSupportedError
11-
from databricks.sql.sea.http_client import SEAHttpClient
11+
from databricks.sql.backend.utils.http_client import CustomHttpClient
1212
from databricks.sql.thrift_api.TCLIService import ttypes
1313
from databricks.sql.types import SSLOptions
1414

1515
logger = logging.getLogger(__name__)
1616

1717

18-
class SEADatabricksClient(DatabricksClient):
18+
class SeaDatabricksClient(DatabricksClient):
1919
"""
2020
Statement Execution API (SEA) implementation of the DatabricksClient interface.
2121
@@ -67,28 +67,10 @@ def __init__(
6767
self._max_download_threads = kwargs.get("max_download_threads", 10)
6868

6969
# Extract warehouse ID from http_path
70-
# Format could be either:
71-
# - /sql/1.0/endpoints/{warehouse_id}
72-
# - /sql/1.0/warehouses/{warehouse_id}
73-
path_parts = http_path.strip("/").split("/")
74-
self.warehouse_id = None
75-
76-
if len(path_parts) >= 3:
77-
if path_parts[-2] in ["endpoints", "warehouses"]:
78-
self.warehouse_id = path_parts[-1]
79-
logger.debug(
80-
f"Extracted warehouse ID: {self.warehouse_id} from path: {http_path}"
81-
)
82-
83-
if not self.warehouse_id:
84-
logger.warning(
85-
"Could not extract warehouse ID from http_path: %s. "
86-
"Session creation may fail if warehouse ID is required.",
87-
http_path,
88-
)
70+
self.warehouse_id = self._extract_warehouse_id(http_path)
8971

9072
# Initialize HTTP client
91-
self.http_client = SEAHttpClient(
73+
self.http_client = CustomHttpClient(
9274
server_hostname=server_hostname,
9375
port=port,
9476
http_path=http_path,
@@ -98,6 +80,41 @@ def __init__(
9880
**kwargs,
9981
)
10082

83+
def _extract_warehouse_id(self, http_path: str) -> str:
    """
    Extract the warehouse ID from the HTTP path.

    The warehouse ID is expected to be the last segment of the path when the
    second-to-last segment is either 'warehouses' or 'endpoints' and the path
    has at least three segments (e.g. /sql/1.0/warehouses/{warehouse_id}).
    This matches the JDBC implementation which supports both formats.

    Args:
        http_path: The HTTP path from which to extract the warehouse ID

    Returns:
        The extracted warehouse ID

    Raises:
        ValueError: If the warehouse ID cannot be extracted from the path
    """
    # Treat a missing path like an empty one so the caller always gets the
    # documented ValueError instead of an AttributeError on None.
    path_parts = (http_path or "").strip("/").split("/")

    if len(path_parts) >= 3 and path_parts[-2] in ("warehouses", "endpoints"):
        warehouse_id = path_parts[-1]
        if warehouse_id:
            # Lazy %-args: the message is only formatted when DEBUG is enabled.
            logger.debug(
                "Extracted warehouse ID: %s from path: %s", warehouse_id, http_path
            )
            return warehouse_id

    error_message = (
        f"Could not extract warehouse ID from http_path: {http_path}. "
        f"Expected format: /path/to/warehouses/{{warehouse_id}} or "
        f"/path/to/endpoints/{{warehouse_id}}"
    )
    logger.error(error_message)
    raise ValueError(error_message)
117+
101118
@property
102119
def staging_allowed_local_path(self) -> Union[None, str, List[str]]:
103120
"""Get the allowed local paths for staging operations."""
@@ -115,7 +132,7 @@ def max_download_threads(self) -> int:
115132

116133
def open_session(
117134
self,
118-
session_configuration: Optional[Dict[str, Any]],
135+
session_configuration: Optional[Dict[str, str]],
119136
catalog: Optional[str],
120137
schema: Optional[str],
121138
) -> SessionId:
@@ -141,36 +158,23 @@ def open_session(
141158
schema,
142159
)
143160

144-
# Prepare request payload
145-
request_data: Dict[str, Any] = {}
146-
147-
if self.warehouse_id:
148-
request_data["warehouse_id"] = self.warehouse_id
149-
161+
request_data: Dict[str, Any] = {"warehouse_id": self.warehouse_id}
150162
if session_configuration:
151-
# The SEA API expects "session_confs" as the key for session configuration
152163
request_data["session_confs"] = session_configuration
153-
154164
if catalog:
155165
request_data["catalog"] = catalog
156-
157166
if schema:
158167
request_data["schema"] = schema
159168

160-
# Make API request
161169
response = self.http_client._make_request(
162170
method="POST", path=self.SESSION_PATH, data=request_data
163171
)
164172

165-
# Extract session ID from response
166173
session_id = response.get("session_id")
167174
if not session_id:
168175
raise Error("Failed to create session: No session ID returned")
169176

170-
# Create and return SessionId object
171-
return SessionId.from_sea_session_id(
172-
session_id, {"warehouse_id": self.warehouse_id}
173-
)
177+
return SessionId.from_sea_session_id(session_id)
174178

175179
def close_session(self, session_id: SessionId) -> None:
176180
"""
@@ -185,16 +189,11 @@ def close_session(self, session_id: SessionId) -> None:
185189
"""
186190
logger.debug("SEADatabricksClient.close_session(session_id=%s)", session_id)
187191

188-
# Validate session ID
189192
if session_id.backend_type != BackendType.SEA:
190193
raise ValueError("Not a valid SEA session ID")
191-
192194
sea_session_id = session_id.to_sea_session_id()
193195

194-
# Make API request with warehouse_id as a query parameter
195-
request_data = {}
196-
if self.warehouse_id:
197-
request_data["warehouse_id"] = self.warehouse_id
196+
request_data = {"warehouse_id": self.warehouse_id}
198197

199198
self.http_client._make_request(
200199
method="DELETE",

src/databricks/sql/sea/http_client.py renamed to src/databricks/sql/backend/utils/http_client.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
logger = logging.getLogger(__name__)
1111

1212

13-
class SEAHttpClient:
13+
class CustomHttpClient:
1414
"""
1515
HTTP client for Statement Execution API (SEA).
1616
@@ -46,14 +46,11 @@ def __init__(
4646
self.auth_provider = auth_provider
4747
self.ssl_options = ssl_options
4848

49-
# Base URL for API requests
5049
self.base_url = f"https://{server_hostname}:{port}"
5150

52-
# Convert headers list to dictionary
5351
self.headers = dict(http_headers)
5452
self.headers.update({"Content-Type": "application/json"})
5553

56-
# Session retry configuration
5754
self.max_retries = kwargs.get("_retry_stop_after_attempts_count", 30)
5855

5956
# Create a session for connection pooling
@@ -109,15 +106,7 @@ def _make_request(
109106
url = urljoin(self.base_url, path)
110107
headers = {**self.headers, **self._get_auth_headers()}
111108

112-
# Log request details (without sensitive information)
113-
logger.debug(f"Making {method} request to {url}")
114-
logger.debug(f"Headers: {[k for k in headers.keys()]}")
115-
if data:
116-
# Don't log sensitive data like access tokens
117-
safe_data = {
118-
k: v for k, v in data.items() if k not in ["access_token", "token"]
119-
}
120-
logger.debug(f"Request data: {safe_data}")
109+
logger.debug(f"making {method} request to {url}")
121110

122111
try:
123112
if method.upper() == "GET":
@@ -135,7 +124,6 @@ def _make_request(
135124

136125
# Log response details
137126
logger.debug(f"Response status: {response.status_code}")
138-
logger.debug(f"Response headers: {dict(response.headers)}")
139127

140128
# Parse JSON response
141129
if response.content:
@@ -168,7 +156,11 @@ def _make_request(
168156
)
169157
except (ValueError, KeyError):
170158
# If we can't parse the JSON, just log the raw content
171-
content_str = e.response.content.decode('utf-8', errors='replace') if isinstance(e.response.content, bytes) else str(e.response.content)
159+
content_str = (
160+
e.response.content.decode("utf-8", errors="replace")
161+
if isinstance(e.response.content, bytes)
162+
else str(e.response.content)
163+
)
172164
logger.error(
173165
f"Response status: {e.response.status_code}, Raw content: {content_str}"
174166
)

src/databricks/sql/sea/__init__.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

src/databricks/sql/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from databricks.sql import __version__
99
from databricks.sql import USER_AGENT_NAME
1010
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
11-
from databricks.sql.sea.sea_backend import SEADatabricksClient
11+
from databricks.sql.backend.sea_backend import SeaDatabricksClient
1212
from databricks.sql.backend.databricks_client import DatabricksClient
1313
from databricks.sql.backend.types import SessionId, BackendType
1414

@@ -78,7 +78,7 @@ def __init__(
7878
use_sea = kwargs.get("use_sea", False)
7979

8080
if use_sea:
81-
self.backend: DatabricksClient = SEADatabricksClient(
81+
self.backend: DatabricksClient = SeaDatabricksClient(
8282
self.host,
8383
self.port,
8484
http_path,

0 commit comments

Comments
 (0)