Skip to content

Commit b6d7c0c

Browse files
review metadata ops
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent e1c7091 commit b6d7c0c

File tree

3 files changed

+214
-57
lines changed

3 files changed

+214
-57
lines changed

src/databricks/sql/backend/sea_backend.py

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ def _extract_warehouse_id(self, http_path: str) -> str:
101101
102102
The warehouse ID is expected to be the last segment of the path when the
103103
second-to-last segment is either 'warehouses' or 'endpoints'.
104-
This matches the JDBC implementation which supports both formats.
105104
106105
Args:
107106
http_path: The HTTP path from which to extract the warehouse ID
@@ -182,7 +181,6 @@ def open_session(
182181
schema=schema,
183182
)
184183

185-
# Send the request
186184
response = self.http_client._make_request(
187185
method="POST", path=self.SESSION_PATH, data=request.to_dict()
188186
)
@@ -219,16 +217,14 @@ def close_session(self, session_id: SessionId) -> None:
219217
raise ValueError("Not a valid SEA session ID")
220218
sea_session_id = session_id.to_sea_session_id()
221219

222-
# Create the request model
223220
request = DeleteSessionRequest(
224221
warehouse_id=self.warehouse_id, session_id=sea_session_id
225222
)
226223

227-
# Send the request
228224
self.http_client._make_request(
229225
method="DELETE",
230226
path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
231-
data=request.to_dict(),
227+
params=request.to_dict(),
232228
)
233229

234230
def execute_command(
@@ -279,11 +275,8 @@ def execute_command(
279275
)
280276
)
281277

282-
# Determine format and disposition based on use_cloud_fetch
283278
format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY"
284279
disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE"
285-
286-
# Create the request model
287280
request = ExecuteStatementRequest(
288281
warehouse_id=self.warehouse_id,
289282
session_id=sea_session_id,
@@ -297,15 +290,10 @@ def execute_command(
297290
parameters=sea_parameters if sea_parameters else None,
298291
)
299292

300-
# Execute the statement
301293
response_data = self.http_client._make_request(
302294
method="POST", path=self.STATEMENT_PATH, data=request.to_dict()
303295
)
304-
305-
# Parse the response
306296
response = ExecuteStatementResponse.from_dict(response_data)
307-
308-
# Create a command ID from the statement ID
309297
statement_id = response.statement_id
310298
if not statement_id:
311299
raise ServerOperationError(
@@ -344,7 +332,6 @@ def execute_command(
344332
},
345333
)
346334

347-
# Get the final result
348335
return self.get_execution_result(command_id, cursor)
349336

350337
def cancel_command(self, command_id: CommandId) -> None:
@@ -362,10 +349,7 @@ def cancel_command(self, command_id: CommandId) -> None:
362349

363350
sea_statement_id = command_id.to_sea_statement_id()
364351

365-
# Create the request model
366352
request = CancelStatementRequest(statement_id=sea_statement_id)
367-
368-
# Send the cancel request
369353
self.http_client._make_request(
370354
method="POST",
371355
path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id),
@@ -387,10 +371,7 @@ def close_command(self, command_id: CommandId) -> None:
387371

388372
sea_statement_id = command_id.to_sea_statement_id()
389373

390-
# Create the request model
391374
request = CloseStatementRequest(statement_id=sea_statement_id)
392-
393-
# Send the close request - SEA uses DELETE for closing statements
394375
self.http_client._make_request(
395376
method="DELETE",
396377
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
@@ -415,10 +396,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
415396

416397
sea_statement_id = command_id.to_sea_statement_id()
417398

418-
# Create the request model
419399
request = GetStatementRequest(statement_id=sea_statement_id)
420-
421-
# Get the statement status
422400
response_data = self.http_client._make_request(
423401
method="GET",
424402
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
@@ -427,8 +405,6 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
427405

428406
# Parse the response
429407
response = GetStatementResponse.from_dict(response_data)
430-
431-
# Return the state directly since it's already a CommandState
432408
return response.status.state
433409

434410
def get_execution_result(
@@ -509,10 +485,12 @@ def get_schemas(
509485
catalog_name: Optional[str] = None,
510486
schema_name: Optional[str] = None,
511487
) -> "ResultSet":
512-
"""Get schemas by executing 'SHOW SCHEMAS [IN catalog]'."""
513-
operation = "SHOW SCHEMAS"
514-
if catalog_name:
515-
operation += f" IN `{catalog_name}`"
488+
"""Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'."""
489+
if not catalog_name:
490+
raise ValueError("Catalog name is required for get_schemas")
491+
492+
operation = f"SHOW SCHEMAS IN `{catalog_name}`"
493+
516494
if schema_name:
517495
operation += f" LIKE '{schema_name}'"
518496

@@ -542,13 +520,18 @@ def get_tables(
542520
table_name: Optional[str] = None,
543521
table_types: Optional[List[str]] = None,
544522
) -> "ResultSet":
545-
"""Get tables by executing 'SHOW TABLES [IN catalog.schema]'."""
546-
operation = "SHOW TABLES"
523+
"""Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'."""
524+
if not catalog_name:
525+
raise ValueError("Catalog name is required for get_tables")
526+
527+
operation = "SHOW TABLES IN " + (
528+
"ALL CATALOGS"
529+
if catalog_name in [None, "*", "%"]
530+
else f"CATALOG `{catalog_name}`"
531+
)
547532

548-
if catalog_name and schema_name:
549-
operation += f" IN `{catalog_name}`.`{schema_name}`"
550-
elif schema_name:
551-
operation += f" IN `{schema_name}`"
533+
if schema_name:
534+
operation += f" SCHEMA LIKE '{schema_name}'"
552535

553536
if table_name:
554537
operation += f" LIKE '{table_name}'"
@@ -579,11 +562,11 @@ def get_columns(
579562
table_name: Optional[str] = None,
580563
column_name: Optional[str] = None,
581564
) -> "ResultSet":
582-
"""Get columns by executing 'SHOW COLUMNS IN catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'."""
565+
"""Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'."""
583566
if not catalog_name:
584567
raise ValueError("Catalog name is required for get_columns")
585568

586-
operation = f"SHOW COLUMNS IN `{catalog_name}`"
569+
operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`"
587570

588571
if schema_name:
589572
operation += f" SCHEMA LIKE '{schema_name}'"

src/databricks/sql/backend/utils/http_client.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,11 @@ def _get_auth_headers(self) -> Dict[str, str]:
8787
return headers
8888

8989
def _make_request(
90-
self, method: str, path: str, data: Optional[Dict[str, Any]] = None
90+
self,
91+
method: str,
92+
path: str,
93+
data: Optional[Dict[str, Any]] = None,
94+
params: Optional[Dict[str, Any]] = None,
9195
) -> Dict[str, Any]:
9296
"""
9397
Make an HTTP request to the SEA endpoint.
@@ -109,12 +113,18 @@ def _make_request(
109113
logger.debug(f"making {method} request to {url}")
110114

111115
try:
116+
args = {
117+
"url": url,
118+
"headers": headers,
119+
"json": data,
120+
"params": params,
121+
}
112122
if method.upper() == "GET":
113-
response = self.session.get(url, headers=headers, params=data)
123+
response = self.session.get(**args)
114124
elif method.upper() == "POST":
115-
response = self.session.post(url, headers=headers, json=data)
125+
response = self.session.post(**args)
116126
elif method.upper() == "DELETE":
117-
response = self.session.delete(url, headers=headers, params=data)
127+
response = self.session.delete(**args)
118128
else:
119129
raise ValueError(f"Unsupported HTTP method: {method}")
120130

@@ -130,7 +140,11 @@ def _make_request(
130140

131141
# Log response content (but limit it for large responses)
132142
content_str = json.dumps(result, indent=4, sort_keys=True)
133-
content_str = content_str[:1000] + "..." if len(content_str) > 1000 else content_str
143+
content_str = (
144+
content_str[:1000] + "..."
145+
if len(content_str) > 1000
146+
else content_str
147+
)
134148
logger.debug(f"Response content: {content_str}")
135149

136150
return result

0 commit comments

Comments
 (0)