Skip to content

Commit b8a170e

Browse files
typing, change DESCRIBE TABLE to SHOW COLUMNS
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent d552695 commit b8a170e

File tree

7 files changed

+76
-59
lines changed

7 files changed

+76
-59
lines changed

src/databricks/sql/backend/models/base.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77
from typing import Dict, List, Any, Optional, Union
88
from dataclasses import dataclass, field
99

10+
from databricks.sql.backend.types import CommandState
11+
1012

1113
@dataclass
1214
class ServiceError:
@@ -20,7 +22,7 @@ class ServiceError:
2022
class StatementStatus:
2123
"""Status information for a statement execution."""
2224

23-
state: str
25+
state: CommandState
2426
error: Optional[ServiceError] = None
2527
sql_state: Optional[str] = None
2628

src/databricks/sql/backend/models/requests.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ class ExecuteStatementRequest:
3636

3737
def to_dict(self) -> Dict[str, Any]:
3838
"""Convert the request to a dictionary for JSON serialization."""
39-
result = {
39+
result: Dict[str, Any] = {
4040
"warehouse_id": self.warehouse_id,
4141
"session_id": self.session_id,
4242
"statement": self.statement,
@@ -118,7 +118,7 @@ class CreateSessionRequest:
118118

119119
def to_dict(self) -> Dict[str, Any]:
120120
"""Convert the request to a dictionary for JSON serialization."""
121-
result = {"warehouse_id": self.warehouse_id}
121+
result: Dict[str, Any] = {"warehouse_id": self.warehouse_id}
122122

123123
if self.session_confs:
124124
result["session_confs"] = self.session_confs

src/databricks/sql/backend/models/responses.py

Lines changed: 22 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,12 @@
77
from typing import Dict, List, Any, Optional, Union
88
from dataclasses import dataclass, field
99

10+
from databricks.sql.backend.types import CommandState
1011
from databricks.sql.backend.models.base import (
1112
StatementStatus,
1213
ResultManifest,
1314
ResultData,
15+
ServiceError,
1416
)
1517

1618

@@ -27,14 +29,17 @@ class ExecuteStatementResponse:
2729
def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
2830
"""Create an ExecuteStatementResponse from a dictionary."""
2931
status_data = data.get("status", {})
32+
error = None
33+
if "error" in status_data:
34+
error_data = status_data["error"]
35+
error = ServiceError(
36+
message=error_data.get("message", ""),
37+
error_code=error_data.get("error_code"),
38+
)
39+
3040
status = StatementStatus(
31-
state=status_data.get("state", ""),
32-
error=None
33-
if "error" not in status_data
34-
else {
35-
"message": status_data["error"].get("message", ""),
36-
"error_code": status_data["error"].get("error_code"),
37-
},
41+
state=CommandState.from_sea_state(status_data.get("state", "")),
42+
error=error,
3843
sql_state=status_data.get("sql_state"),
3944
)
4045

@@ -59,14 +64,17 @@ class GetStatementResponse:
5964
def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
6065
"""Create a GetStatementResponse from a dictionary."""
6166
status_data = data.get("status", {})
67+
error = None
68+
if "error" in status_data:
69+
error_data = status_data["error"]
70+
error = ServiceError(
71+
message=error_data.get("message", ""),
72+
error_code=error_data.get("error_code"),
73+
)
74+
6275
status = StatementStatus(
63-
state=status_data.get("state", ""),
64-
error=None
65-
if "error" not in status_data
66-
else {
67-
"message": status_data["error"].get("message", ""),
68-
"error_code": status_data["error"].get("error_code"),
69-
},
76+
state=CommandState.from_sea_state(status_data.get("state", "")),
77+
error=error,
7078
sql_state=status_data.get("sql_state"),
7179
)
7280

src/databricks/sql/backend/sea_backend.py

Lines changed: 22 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -273,13 +273,17 @@ def execute_command(
273273
)
274274
)
275275

276+
# Determine format and disposition based on use_cloud_fetch
277+
format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY"
278+
disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE"
279+
276280
# Create the request model
277281
request = ExecuteStatementRequest(
278282
warehouse_id=self.warehouse_id,
279283
session_id=sea_session_id,
280284
statement=operation,
281-
disposition="EXTERNAL_LINKS" if use_cloud_fetch else "INLINE",
282-
format="ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY",
285+
disposition=disposition,
286+
format=format,
283287
wait_timeout="0s" if async_op else "30s",
284288
on_wait_timeout="CONTINUE",
285289
row_limit=max_rows if max_rows > 0 else None,
@@ -315,7 +319,7 @@ def execute_command(
315319
state = status.state
316320

317321
# Keep polling until we reach a terminal state
318-
while state in ["PENDING", "RUNNING"]:
322+
while state in [CommandState.PENDING, CommandState.RUNNING]:
319323
# Add a small delay to avoid excessive API calls
320324
time.sleep(0.5)
321325

@@ -337,12 +341,12 @@ def execute_command(
337341
state = status.state
338342

339343
# Check for errors
340-
if state == "FAILED" and status.error:
341-
error_message = status.error["message"]
344+
if state == CommandState.FAILED and status.error:
345+
error_message = status.error.message
342346
raise Error(f"Statement execution failed: {error_message}")
343347

344348
# Check for cancellation
345-
if state == "CANCELED":
349+
if state == CommandState.CANCELLED:
346350
raise Error("Statement execution was canceled")
347351

348352
# Get the final result
@@ -435,11 +439,8 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
435439
# Parse the response
436440
response = GetStatementResponse.from_dict(response_data)
437441

438-
# Extract the status
439-
state = response.status.state
440-
441-
# Map SEA state to CommandState
442-
return CommandState.from_sea_state(state)
442+
# Return the state directly since it's already a CommandState
443+
return response.status.state
443444

444445
def get_execution_result(
445446
self,
@@ -591,20 +592,20 @@ def get_columns(
591592
table_name: Optional[str] = None,
592593
column_name: Optional[str] = None,
593594
) -> "ResultSet":
594-
"""Get columns by executing 'DESCRIBE TABLE [catalog.schema.]table'."""
595-
if not table_name:
596-
raise ValueError("Table name is required for get_columns")
595+
"""Get columns by executing 'SHOW COLUMNS IN catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'."""
596+
if not catalog_name:
597+
raise ValueError("Catalog name is required for get_columns")
597598

598-
operation = "DESCRIBE TABLE "
599+
operation = f"SHOW COLUMNS IN `{catalog_name}`"
599600

600-
if catalog_name and schema_name:
601-
operation += f"`{catalog_name}`.`{schema_name}`."
602-
elif schema_name:
603-
operation += f"`{schema_name}`."
601+
if schema_name:
602+
operation += f" SCHEMA LIKE '{schema_name}'"
604603

605-
operation += f"`{table_name}`"
604+
if table_name:
605+
operation += f" TABLE LIKE '{table_name}'"
606606

607-
# Column name filtering will be done client-side
607+
if column_name:
608+
operation += f" LIKE '{column_name}'"
608609

609610
result = self.execute_command(
610611
operation=operation,

src/databricks/sql/backend/sea_result_set.py

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -10,14 +10,15 @@
1010

1111
from databricks.sql.result_set import ResultSet
1212
from databricks.sql.types import Row
13-
from databricks.sql.backend.types import CommandId
13+
from databricks.sql.backend.types import CommandId, CommandState
1414
from databricks.sql.exc import Error
1515

1616
from databricks.sql.backend.models import (
1717
StatementStatus,
1818
ResultManifest,
1919
ResultData,
2020
ColumnInfo,
21+
ServiceError,
2122
)
2223

2324
logger = logging.getLogger(__name__)
@@ -42,14 +43,17 @@ def __init__(
4243

4344
# Parse the status
4445
status_data = sea_response.get("status", {})
46+
error = None
47+
if "error" in status_data:
48+
error_data = status_data["error"]
49+
error = ServiceError(
50+
message=error_data.get("message", ""),
51+
error_code=error_data.get("error_code"),
52+
)
53+
4554
self.status = StatementStatus(
46-
state=status_data.get("state", ""),
47-
error=None
48-
if "error" not in status_data
49-
else {
50-
"message": status_data["error"].get("message", ""),
51-
"error_code": status_data["error"].get("error_code"),
52-
},
55+
state=CommandState.from_sea_state(status_data.get("state", "")),
56+
error=error,
5357
sql_state=status_data.get("sql_state"),
5458
)
5559

@@ -71,7 +75,7 @@ def __init__(
7175
)
7276
)
7377

74-
self.manifest = ResultManifest(
78+
self.manifest: Optional[ResultManifest] = ResultManifest(
7579
schema=columns,
7680
total_row_count=manifest_data.get("total_row_count", 0),
7781
total_byte_count=manifest_data.get("total_byte_count", 0),
@@ -84,7 +88,7 @@ def __init__(
8488
# Parse the result data
8589
result_data = sea_response.get("result")
8690
if result_data:
87-
self.result = ResultData(
91+
self.result: Optional[ResultData] = ResultData(
8892
data=result_data.get("data"),
8993
external_links=result_data.get("external_links"),
9094
)
@@ -112,7 +116,9 @@ def is_staging_operation(self) -> bool:
112116

113117
def _extract_description_from_manifest(
114118
self,
115-
) -> List[Tuple[str, str, None, None, None, None, bool]]:
119+
) -> Optional[
120+
List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]]
121+
]:
116122
"""
117123
Extract column descriptions from the SEA manifest.
118124
@@ -121,7 +127,7 @@ def _extract_description_from_manifest(
121127
(name, type_code, display_size, internal_size, precision, scale, null_ok)
122128
"""
123129
if not self.manifest or not self.manifest.schema:
124-
return []
130+
return None
125131

126132
description = []
127133
for col in self.manifest.schema:

src/databricks/sql/result_set.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import List, Optional, Any, Union
2+
from typing import List, Optional, Any, Union, Tuple
33

44
import logging
55
import time
@@ -32,7 +32,7 @@ def __init__(self, connection, backend, arraysize: int, buffer_size_bytes: int):
3232
self.arraysize = arraysize
3333
self.buffer_size_bytes = buffer_size_bytes
3434
self._next_row_index = 0
35-
self.description = None
35+
self.description: Optional[Any] = None
3636

3737
def __iter__(self):
3838
while True:

tests/unit/test_sea_backend.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -524,7 +524,7 @@ def test_get_execution_result(
524524

525525
# Verify basic properties of the result
526526
assert result.statement_id == "test-statement-123"
527-
assert result.status.state == "SUCCEEDED"
527+
assert result.status.state == CommandState.SUCCEEDED
528528
assert len(result.description) == 1
529529
assert result.description[0][0] == "col1" # column name
530530

@@ -657,25 +657,25 @@ def test_get_columns(self, sea_client, mock_cursor, sea_session_id):
657657
args, kwargs = mock_execute.call_args
658658
assert (
659659
kwargs["operation"]
660-
== "DESCRIBE TABLE `test_catalog`.`test_schema`.`test_table`"
660+
== "SHOW COLUMNS IN `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table'"
661661
)
662662
assert kwargs["session_id"] == sea_session_id
663663
assert kwargs["max_rows"] == 100
664664
assert kwargs["max_bytes"] == 1000
665665
assert kwargs["cursor"] == mock_cursor
666666
assert kwargs["async_op"] is False
667667

668-
def test_get_columns_no_table_name(self, sea_client, mock_cursor, sea_session_id):
669-
"""Test getting columns without a table name raises an error."""
668+
def test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id):
669+
"""Test getting columns without a catalog name raises an error."""
670670
with pytest.raises(ValueError) as excinfo:
671671
sea_client.get_columns(
672672
session_id=sea_session_id,
673673
max_rows=100,
674674
max_bytes=1000,
675675
cursor=mock_cursor,
676-
catalog_name="test_catalog",
676+
catalog_name=None, # No catalog name
677677
schema_name="test_schema",
678-
table_name=None, # No table name
678+
table_name="test_table",
679679
)
680680

681-
assert "Table name is required" in str(excinfo.value)
681+
assert "Catalog name is required" in str(excinfo.value)

0 commit comments

Comments
 (0)