Skip to content

Commit 95b0781

Browse files
add unit tests for sea backend
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 16ff4ec commit 95b0781

File tree

2 files changed

+171
-1
lines changed

2 files changed

+171
-1
lines changed

src/databricks/sql/backend/sea_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ def _extract_warehouse_id(self, http_path: str) -> str:
102102

103103
if len(path_parts) >= 3 and path_parts[-2] in ["warehouses", "endpoints"]:
104104
warehouse_id = path_parts[-1]
105-
logger.debug(f"Extracted warehouse ID: {warehouse_id} from path: {http_path}")
105+
logger.debug(
106+
f"Extracted warehouse ID: {warehouse_id} from path: {http_path}"
107+
)
106108

107109
if not warehouse_id:
108110
error_message = (

tests/unit/test_sea_backend.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import pytest
2+
from unittest.mock import patch, MagicMock, Mock
3+
4+
from databricks.sql.backend.sea_backend import SeaDatabricksClient
5+
from databricks.sql.backend.types import SessionId, BackendType
6+
from databricks.sql.types import SSLOptions
7+
from databricks.sql.auth.authenticators import AuthProvider
8+
from databricks.sql.exc import Error, NotSupportedError
9+
10+
11+
class TestSeaBackend:
12+
"""Test suite for the SeaDatabricksClient class."""
13+
14+
@pytest.fixture
15+
def mock_http_client(self):
16+
"""Create a mock HTTP client."""
17+
with patch(
18+
"databricks.sql.backend.sea_backend.CustomHttpClient"
19+
) as mock_client_class:
20+
mock_client = mock_client_class.return_value
21+
yield mock_client
22+
23+
@pytest.fixture
24+
def sea_client(self, mock_http_client):
25+
"""Create a SeaDatabricksClient instance with mocked dependencies."""
26+
server_hostname = "test-server.databricks.com"
27+
port = 443
28+
http_path = "/sql/warehouses/abc123"
29+
http_headers = [("header1", "value1"), ("header2", "value2")]
30+
auth_provider = AuthProvider()
31+
ssl_options = SSLOptions()
32+
33+
client = SeaDatabricksClient(
34+
server_hostname=server_hostname,
35+
port=port,
36+
http_path=http_path,
37+
http_headers=http_headers,
38+
auth_provider=auth_provider,
39+
ssl_options=ssl_options,
40+
)
41+
42+
return client
43+
44+
def test_init_extracts_warehouse_id(self, mock_http_client):
45+
"""Test that the constructor properly extracts the warehouse ID from the HTTP path."""
46+
# Test with warehouses format
47+
client1 = SeaDatabricksClient(
48+
server_hostname="test-server.databricks.com",
49+
port=443,
50+
http_path="/sql/warehouses/abc123",
51+
http_headers=[],
52+
auth_provider=AuthProvider(),
53+
ssl_options=SSLOptions(),
54+
)
55+
assert client1.warehouse_id == "abc123"
56+
57+
# Test with endpoints format
58+
client2 = SeaDatabricksClient(
59+
server_hostname="test-server.databricks.com",
60+
port=443,
61+
http_path="/sql/endpoints/def456",
62+
http_headers=[],
63+
auth_provider=AuthProvider(),
64+
ssl_options=SSLOptions(),
65+
)
66+
assert client2.warehouse_id == "def456"
67+
68+
def test_init_raises_error_for_invalid_http_path(self, mock_http_client):
69+
"""Test that the constructor raises an error for invalid HTTP paths."""
70+
with pytest.raises(ValueError) as excinfo:
71+
SeaDatabricksClient(
72+
server_hostname="test-server.databricks.com",
73+
port=443,
74+
http_path="/invalid/path",
75+
http_headers=[],
76+
auth_provider=AuthProvider(),
77+
ssl_options=SSLOptions(),
78+
)
79+
assert "Could not extract warehouse ID" in str(excinfo.value)
80+
81+
def test_open_session_basic(self, sea_client, mock_http_client):
82+
"""Test the open_session method with minimal parameters."""
83+
# Set up mock response
84+
mock_http_client._make_request.return_value = {"session_id": "test-session-123"}
85+
86+
# Call the method
87+
session_id = sea_client.open_session(None, None, None)
88+
89+
# Verify the result
90+
assert isinstance(session_id, SessionId)
91+
assert session_id.backend_type == BackendType.SEA
92+
assert session_id.guid == "test-session-123"
93+
94+
# Verify the HTTP request
95+
mock_http_client._make_request.assert_called_once_with(
96+
method="POST", path=sea_client.SESSION_PATH, data={"warehouse_id": "abc123"}
97+
)
98+
99+
def test_open_session_with_all_parameters(self, sea_client, mock_http_client):
100+
"""Test the open_session method with all parameters."""
101+
# Set up mock response
102+
mock_http_client._make_request.return_value = {"session_id": "test-session-456"}
103+
104+
# Call the method with all parameters
105+
session_config = {"spark.sql.shuffle.partitions": "10"}
106+
catalog = "test_catalog"
107+
schema = "test_schema"
108+
109+
session_id = sea_client.open_session(session_config, catalog, schema)
110+
111+
# Verify the result
112+
assert isinstance(session_id, SessionId)
113+
assert session_id.backend_type == BackendType.SEA
114+
assert session_id.guid == "test-session-456"
115+
116+
# Verify the HTTP request
117+
expected_data = {
118+
"warehouse_id": "abc123",
119+
"session_confs": session_config,
120+
"catalog": catalog,
121+
"schema": schema,
122+
}
123+
mock_http_client._make_request.assert_called_once_with(
124+
method="POST", path=sea_client.SESSION_PATH, data=expected_data
125+
)
126+
127+
def test_open_session_error_handling(self, sea_client, mock_http_client):
128+
"""Test error handling in the open_session method."""
129+
# Set up mock response without session_id
130+
mock_http_client._make_request.return_value = {}
131+
132+
# Call the method and expect an error
133+
with pytest.raises(Error) as excinfo:
134+
sea_client.open_session(None, None, None)
135+
136+
assert "Failed to create session" in str(excinfo.value)
137+
138+
def test_close_session_valid_id(self, sea_client, mock_http_client):
139+
"""Test closing a session with a valid session ID."""
140+
# Create a valid SEA session ID
141+
session_id = SessionId.from_sea_session_id("test-session-789")
142+
143+
# Set up mock response
144+
mock_http_client._make_request.return_value = {}
145+
146+
# Call the method
147+
sea_client.close_session(session_id)
148+
149+
# Verify the HTTP request
150+
mock_http_client._make_request.assert_called_once_with(
151+
method="DELETE",
152+
path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"),
153+
data={"warehouse_id": "abc123"},
154+
)
155+
156+
def test_close_session_invalid_id_type(self, sea_client):
157+
"""Test closing a session with an invalid session ID type."""
158+
# Create a Thrift session ID (not SEA)
159+
mock_thrift_handle = MagicMock()
160+
mock_thrift_handle.sessionId.guid = b"guid"
161+
mock_thrift_handle.sessionId.secret = b"secret"
162+
session_id = SessionId.from_thrift_handle(mock_thrift_handle)
163+
164+
# Call the method and expect an error
165+
with pytest.raises(ValueError) as excinfo:
166+
sea_client.close_session(session_id)
167+
168+
assert "Not a valid SEA session ID" in str(excinfo.value)

0 commit comments

Comments
 (0)