Skip to content

Commit 312b3c8

Browse files
update open session with normalised SessionId
1 parent d688272 commit 312b3c8

File tree

3 files changed

+52
-25
lines changed

3 files changed

+52
-25
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
from typing import List, Union, Any
99

1010
from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
11-
from databricks.sql.backend.types import SessionId, CommandId, BackendType, guid_to_hex_id
11+
from databricks.sql.backend.types import (
12+
SessionId,
13+
CommandId,
14+
BackendType,
15+
guid_to_hex_id,
16+
)
1217

1318
try:
1419
import pyarrow
@@ -580,7 +585,12 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId:
580585
self._check_protocol_version(response)
581586
if response.sessionHandle is None:
582587
return None
583-
return SessionId.from_thrift_handle(response.sessionHandle)
588+
info = (
589+
{"serverProtocolVersion": response.serverProtocolVersion}
590+
if response.serverProtocolVersion
591+
else {}
592+
)
593+
return SessionId.from_thrift_handle(response.sessionHandle, info)
584594
except:
585595
self._transport.close()
586596
raise

src/databricks/sql/backend/types.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import Optional, Any, Union
2+
from typing import Dict, Optional, Any, Union
33
import uuid
44
import logging
55

@@ -43,6 +43,7 @@ def __init__(
4343
backend_type: BackendType,
4444
guid: Any,
4545
secret: Optional[Any] = None,
46+
info: Optional[Dict[str, Any]] = None,
4647
):
4748
"""
4849
Initialize a SessionId.
@@ -51,13 +52,15 @@ def __init__(
5152
backend_type: The type of backend (THRIFT or SEA)
5253
guid: The primary identifier for the session
5354
secret: The secret part of the identifier (only used for Thrift)
55+
info: Additional information about the session
5456
"""
5557
self.backend_type = backend_type
5658
self.guid = guid
5759
self.secret = secret
60+
self.info = info or {}
5861

5962
@classmethod
60-
def from_thrift_handle(cls, session_handle):
63+
def from_thrift_handle(cls, session_handle, info: Optional[Dict[str, Any]] = None):
6164
"""
6265
Create a SessionId from a Thrift session handle.
6366
@@ -67,16 +70,23 @@ def from_thrift_handle(cls, session_handle):
6770
Returns:
6871
A SessionId instance
6972
"""
70-
if session_handle is None or session_handle.sessionId is None:
73+
if session_handle is None:
7174
return None
7275

7376
guid_bytes = session_handle.sessionId.guid
7477
secret_bytes = session_handle.sessionId.secret
7578

76-
return cls(BackendType.THRIFT, guid_bytes, secret_bytes)
79+
if session_handle.serverProtocolVersion is not None:
80+
if info is None:
81+
info = {}
82+
info["serverProtocolVersion"] = session_handle.serverProtocolVersion
83+
84+
return cls(BackendType.THRIFT, guid_bytes, secret_bytes, info)
7785

7886
@classmethod
79-
def from_sea_session_id(cls, session_id: str):
87+
def from_sea_session_id(
88+
cls, session_id: str, info: Optional[Dict[str, Any]] = None
89+
):
8090
"""
8191
Create a SessionId from a SEA session ID.
8292
@@ -86,7 +96,7 @@ def from_sea_session_id(cls, session_id: str):
8696
Returns:
8797
A SessionId instance
8898
"""
89-
return cls(BackendType.SEA, session_id)
99+
return cls(BackendType.SEA, session_id, info=info)
90100

91101
def to_thrift_handle(self):
92102
"""
@@ -101,7 +111,10 @@ def to_thrift_handle(self):
101111
from databricks.sql.thrift_api.TCLIService import ttypes
102112

103113
handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret)
104-
return ttypes.TSessionHandle(sessionId=handle_identifier)
114+
server_protocol_version = self.info.get("serverProtocolVersion")
115+
return ttypes.TSessionHandle(
116+
sessionId=handle_identifier, serverProtocolVersion=server_protocol_version
117+
)
105118

106119
def to_sea_session_id(self):
107120
"""
@@ -129,19 +142,12 @@ def to_hex_id(self) -> str:
129142

130143
def get_protocol_version(self):
131144
"""
132-
Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
133-
precedence over the serverProtocolVersion defined in the OpenSessionResponse.
145+
Get the server protocol version for this session.
146+
147+
Returns:
148+
The server protocol version or None if this is not a Thrift session ID
134149
"""
135-
if self.backend_type != BackendType.THRIFT:
136-
return None
137-
session_handle = self.to_thrift_handle()
138-
if (
139-
session_handle
140-
and hasattr(session_handle, "serverProtocolVersion")
141-
and session_handle.serverProtocolVersion
142-
):
143-
return session_handle.serverProtocolVersion
144-
return None
150+
return self.info.get("serverProtocolVersion")
145151

146152

147153
class CommandId:
@@ -190,7 +196,7 @@ def from_thrift_handle(cls, operation_handle):
190196
Returns:
191197
A CommandId instance
192198
"""
193-
if operation_handle is None or operation_handle.operationId is None:
199+
if operation_handle is None:
194200
return None
195201

196202
guid_bytes = operation_handle.operationId.guid

tests/unit/test_parameters.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
TinyIntParameter,
2323
VoidParameter,
2424
)
25+
from databricks.sql.backend.types import SessionId
2526
from databricks.sql.parameters.native import (
2627
TDbsqlParameter,
2728
TSparkParameterValue,
@@ -42,7 +43,10 @@ class TestSessionHandleChecks(object):
4243
(
4344
TOpenSessionResp(
4445
serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7,
45-
sessionHandle=TSessionHandle(1, None),
46+
sessionHandle=TSessionHandle(
47+
sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37),
48+
serverProtocolVersion=None,
49+
),
4650
),
4751
ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7,
4852
),
@@ -51,15 +55,22 @@ class TestSessionHandleChecks(object):
5155
TOpenSessionResp(
5256
serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7,
5357
sessionHandle=TSessionHandle(
54-
1, ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8
58+
sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37),
59+
serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8,
5560
),
5661
),
5762
ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8,
5863
),
5964
],
6065
)
6166
def test_get_protocol_version_fallback_behavior(self, test_input, expected):
62-
assert Connection.get_protocol_version(test_input) == expected
67+
info = (
68+
{"serverProtocolVersion": test_input.serverProtocolVersion}
69+
if test_input.serverProtocolVersion
70+
else {}
71+
)
72+
session_id = SessionId.from_thrift_handle(test_input.sessionHandle, info)
73+
assert Connection.get_protocol_version(session_id) == expected
6374

6475
@pytest.mark.parametrize(
6576
"test_input,expected",

0 commit comments

Comments
 (0)