Skip to content

Commit 1003319

Browse files
remove model redundancies
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent b8a170e commit 1003319

File tree

5 files changed

+44
-50
lines changed

5 files changed

+44
-50
lines changed

examples/experimental/sea_connector_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_sea_query_execution():
4545
logger.info(f"backend type: {type(connection.session.backend)}")
4646

4747
# Create a cursor and execute a simple query
48-
cursor = connection.cursor(buffer_size_bytes=0)
48+
cursor = connection.cursor(arraysize=0, buffer_size_bytes=0)
4949

5050
logger.info("Executing query: SELECT 1 as test_value")
5151
cursor.execute("SELECT 1 as test_value")

src/databricks/sql/backend/models/requests.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class StatementParameter:
2121
class ExecuteStatementRequest:
2222
"""Request to execute a SQL statement."""
2323

24+
# TODO: result_compression key
2425
warehouse_id: str
2526
statement: str
2627
session_id: str
@@ -75,36 +76,33 @@ def to_dict(self) -> Dict[str, Any]:
7576
class GetStatementRequest:
7677
"""Request to get information about a statement."""
7778

78-
warehouse_id: str
7979
statement_id: str
8080

8181
def to_dict(self) -> Dict[str, Any]:
8282
"""Convert the request to a dictionary for JSON serialization."""
83-
return {"warehouse_id": self.warehouse_id, "statement_id": self.statement_id}
83+
return {"statement_id": self.statement_id}
8484

8585

8686
@dataclass
8787
class CancelStatementRequest:
8888
"""Request to cancel a statement."""
8989

90-
warehouse_id: str
9190
statement_id: str
9291

9392
def to_dict(self) -> Dict[str, Any]:
9493
"""Convert the request to a dictionary for JSON serialization."""
95-
return {"warehouse_id": self.warehouse_id, "statement_id": self.statement_id}
94+
return {"statement_id": self.statement_id}
9695

9796

9897
@dataclass
9998
class CloseStatementRequest:
10099
"""Request to close a statement."""
101100

102-
warehouse_id: str
103101
statement_id: str
104102

105103
def to_dict(self) -> Dict[str, Any]:
106104
"""Convert the request to a dictionary for JSON serialization."""
107-
return {"warehouse_id": self.warehouse_id, "statement_id": self.statement_id}
105+
return {"statement_id": self.statement_id}
108106

109107

110108
@dataclass
@@ -139,7 +137,7 @@ class DeleteSessionRequest:
139137
warehouse_id: str
140138
session_id: str
141139

142-
def to_query_params(self) -> Dict[str, str]:
140+
def to_dict(self) -> Dict[str, str]:
143141
"""
144142
Convert the request to query parameters.
145143
@@ -149,4 +147,4 @@ def to_query_params(self) -> Dict[str, str]:
149147
Returns:
150148
A dictionary containing the warehouse_id as a query parameter
151149
"""
152-
return {"warehouse_id": self.warehouse_id}
150+
return {"warehouse_id": self.warehouse_id, "session_id": self.session_id}

src/databricks/sql/backend/sea_backend.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def close_session(self, session_id: SessionId) -> None:
222222
self.http_client._make_request(
223223
method="DELETE",
224224
path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
225-
data=request.to_query_params(), # Use to_query_params to get only the warehouse_id
225+
data=request.to_dict(),
226226
)
227227

228228
def execute_command(
@@ -284,7 +284,7 @@ def execute_command(
284284
statement=operation,
285285
disposition=disposition,
286286
format=format,
287-
wait_timeout="0s" if async_op else "30s",
287+
wait_timeout="0s" if async_op else "10s",
288288
on_wait_timeout="CONTINUE",
289289
row_limit=max_rows if max_rows > 0 else None,
290290
byte_limit=max_bytes if max_bytes > 0 else None,
@@ -324,9 +324,7 @@ def execute_command(
324324
time.sleep(0.5)
325325

326326
# Create the request model
327-
get_request = GetStatementRequest(
328-
warehouse_id=self.warehouse_id, statement_id=statement_id
329-
)
327+
get_request = GetStatementRequest(statement_id=statement_id)
330328

331329
# Get the statement status
332330
poll_response_data = self.http_client._make_request(
@@ -368,9 +366,7 @@ def cancel_command(self, command_id: CommandId) -> None:
368366
sea_statement_id = command_id.to_sea_statement_id()
369367

370368
# Create the request model
371-
request = CancelStatementRequest(
372-
warehouse_id=self.warehouse_id, statement_id=sea_statement_id
373-
)
369+
request = CancelStatementRequest(statement_id=sea_statement_id)
374370

375371
# Send the cancel request
376372
self.http_client._make_request(
@@ -395,9 +391,7 @@ def close_command(self, command_id: CommandId) -> None:
395391
sea_statement_id = command_id.to_sea_statement_id()
396392

397393
# Create the request model
398-
request = CloseStatementRequest(
399-
warehouse_id=self.warehouse_id, statement_id=sea_statement_id
400-
)
394+
request = CloseStatementRequest(statement_id=sea_statement_id)
401395

402396
# Send the close request - SEA uses DELETE for closing statements
403397
self.http_client._make_request(
@@ -425,9 +419,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
425419
sea_statement_id = command_id.to_sea_statement_id()
426420

427421
# Create the request model
428-
request = GetStatementRequest(
429-
warehouse_id=self.warehouse_id, statement_id=sea_statement_id
430-
)
422+
request = GetStatementRequest(statement_id=sea_statement_id)
431423

432424
# Get the statement status
433425
response_data = self.http_client._make_request(
@@ -466,9 +458,7 @@ def get_execution_result(
466458
sea_statement_id = command_id.to_sea_statement_id()
467459

468460
# Create the request model
469-
request = GetStatementRequest(
470-
warehouse_id=self.warehouse_id, statement_id=sea_statement_id
471-
)
461+
request = GetStatementRequest(statement_id=sea_statement_id)
472462

473463
# Get the statement result
474464
response_data = self.http_client._make_request(

src/databricks/sql/backend/sea_result_set.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
which handles the result data returned by the SEA API.
66
"""
77

8+
import json
89
import logging
910
from typing import Optional, List, Any, Dict, Tuple
1011

@@ -60,9 +61,8 @@ def __init__(
6061
# Parse the manifest
6162
manifest_data = sea_response.get("manifest")
6263
if manifest_data:
63-
schema_data = manifest_data.get("schema", [])
6464
columns = []
65-
for col_data in schema_data:
65+
for col_data in manifest_data.get("schema", {}).get("columns", []):
6666
columns.append(
6767
ColumnInfo(
6868
name=col_data.get("name", ""),

tests/unit/test_sea_backend.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
the Databricks SQL connector's SEA backend functionality.
66
"""
77

8+
import json
89
import pytest
910
from unittest.mock import patch, MagicMock, Mock
1011

1112
from databricks.sql.backend.sea_backend import SeaDatabricksClient
13+
from databricks.sql.backend.sea_result_set import SeaResultSet
1214
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
1315
from databricks.sql.types import SSLOptions
1416
from databricks.sql.auth.authenticators import AuthProvider
@@ -178,7 +180,7 @@ def test_close_session_valid_id(self, sea_client, mock_http_client):
178180
mock_http_client._make_request.assert_called_once_with(
179181
method="DELETE",
180182
path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"),
181-
data={"warehouse_id": "abc123"},
183+
data={"session_id": "test-session-789", "warehouse_id": "abc123"},
182184
)
183185

184186
def test_close_session_invalid_id_type(self, sea_client):
@@ -451,8 +453,6 @@ def test_cancel_command(self, sea_client, mock_http_client, sea_command_id):
451453
assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format(
452454
"test-statement-123"
453455
)
454-
assert "warehouse_id" in kwargs["data"]
455-
assert kwargs["data"]["warehouse_id"] == "abc123"
456456

457457
def test_close_command(self, sea_client, mock_http_client, sea_command_id):
458458
"""Test closing a command."""
@@ -469,8 +469,6 @@ def test_close_command(self, sea_client, mock_http_client, sea_command_id):
469469
assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format(
470470
"test-statement-123"
471471
)
472-
assert "warehouse_id" in kwargs["data"]
473-
assert kwargs["data"]["warehouse_id"] == "abc123"
474472

475473
def test_get_query_state(self, sea_client, mock_http_client, sea_command_id):
476474
"""Test getting the state of a query."""
@@ -493,40 +491,50 @@ def test_get_query_state(self, sea_client, mock_http_client, sea_command_id):
493491
assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format(
494492
"test-statement-123"
495493
)
496-
assert "warehouse_id" in kwargs["data"]
497-
assert kwargs["data"]["warehouse_id"] == "abc123"
498494

499495
def test_get_execution_result(
500496
self, sea_client, mock_http_client, mock_cursor, sea_command_id
501497
):
502498
"""Test getting the result of a command execution."""
503499
# Set up mock response
504-
mock_http_client._make_request.return_value = {
500+
sea_response = {
505501
"statement_id": "test-statement-123",
506502
"status": {"state": "SUCCEEDED"},
507503
"manifest": {
508-
"schema": [
509-
{
510-
"name": "col1",
511-
"type_name": "STRING",
512-
"type_text": "string",
513-
"nullable": True,
514-
}
515-
],
504+
"format": "JSON_ARRAY",
505+
"schema": {
506+
"column_count": 1,
507+
"columns": [
508+
{
509+
"name": "test_value",
510+
"type_text": "INT",
511+
"type_name": "INT",
512+
"position": 0,
513+
}
514+
],
515+
},
516+
"total_chunk_count": 1,
517+
"chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}],
516518
"total_row_count": 1,
517-
"total_byte_count": 100,
519+
"truncated": False,
520+
},
521+
"result": {
522+
"chunk_index": 0,
523+
"row_offset": 0,
524+
"row_count": 1,
525+
"data_array": [["1"]],
518526
},
519-
"result": {"data": [["value1"]]},
520527
}
528+
mock_http_client._make_request.return_value = sea_response
521529

522530
# Create a real result set to verify the implementation
523531
result = sea_client.get_execution_result(sea_command_id, mock_cursor)
532+
print(result)
524533

525534
# Verify basic properties of the result
526535
assert result.statement_id == "test-statement-123"
527536
assert result.status.state == CommandState.SUCCEEDED
528-
assert len(result.description) == 1
529-
assert result.description[0][0] == "col1" # column name
537+
assert result.manifest.schema[0].name == "test_value"
530538

531539
# Verify the HTTP request
532540
mock_http_client._make_request.assert_called_once()
@@ -535,8 +543,6 @@ def test_get_execution_result(
535543
assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format(
536544
"test-statement-123"
537545
)
538-
assert "warehouse_id" in kwargs["data"]
539-
assert kwargs["data"]["warehouse_id"] == "abc123"
540546

541547
# Tests for metadata operations
542548

0 commit comments

Comments
 (0)