1
1
from enum import Enum
2
- from typing import Optional , Any , Union
2
+ from typing import Dict , Optional , Any , Union
3
3
import uuid
4
4
import logging
5
5
@@ -43,6 +43,7 @@ def __init__(
43
43
backend_type : BackendType ,
44
44
guid : Any ,
45
45
secret : Optional [Any ] = None ,
46
+ info : Optional [Dict [str , Any ]] = None ,
46
47
):
47
48
"""
48
49
Initialize a SessionId.
@@ -51,13 +52,15 @@ def __init__(
51
52
backend_type: The type of backend (THRIFT or SEA)
52
53
guid: The primary identifier for the session
53
54
secret: The secret part of the identifier (only used for Thrift)
55
+ info: Additional information about the session
54
56
"""
55
57
self .backend_type = backend_type
56
58
self .guid = guid
57
59
self .secret = secret
60
+ self .info = info or {}
58
61
59
62
@classmethod
60
- def from_thrift_handle (cls , session_handle ):
63
+ def from_thrift_handle (cls , session_handle , info : Optional [ Dict [ str , Any ]] = None ):
61
64
"""
62
65
Create a SessionId from a Thrift session handle.
63
66
@@ -67,16 +70,23 @@ def from_thrift_handle(cls, session_handle):
67
70
Returns:
68
71
A SessionId instance
69
72
"""
70
- if session_handle is None or session_handle . sessionId is None :
73
+ if session_handle is None :
71
74
return None
72
75
73
76
guid_bytes = session_handle .sessionId .guid
74
77
secret_bytes = session_handle .sessionId .secret
75
78
76
- return cls (BackendType .THRIFT , guid_bytes , secret_bytes )
79
+ if session_handle .serverProtocolVersion is not None :
80
+ if info is None :
81
+ info = {}
82
+ info ["serverProtocolVersion" ] = session_handle .serverProtocolVersion
83
+
84
+ return cls (BackendType .THRIFT , guid_bytes , secret_bytes , info )
77
85
78
86
@classmethod
79
- def from_sea_session_id (cls , session_id : str ):
87
+ def from_sea_session_id (
88
+ cls , session_id : str , info : Optional [Dict [str , Any ]] = None
89
+ ):
80
90
"""
81
91
Create a SessionId from a SEA session ID.
82
92
@@ -86,7 +96,7 @@ def from_sea_session_id(cls, session_id: str):
86
96
Returns:
87
97
A SessionId instance
88
98
"""
89
- return cls (BackendType .SEA , session_id )
99
+ return cls (BackendType .SEA , session_id , info = info )
90
100
91
101
def to_thrift_handle (self ):
92
102
"""
@@ -101,7 +111,10 @@ def to_thrift_handle(self):
101
111
from databricks .sql .thrift_api .TCLIService import ttypes
102
112
103
113
handle_identifier = ttypes .THandleIdentifier (guid = self .guid , secret = self .secret )
104
- return ttypes .TSessionHandle (sessionId = handle_identifier )
114
+ server_protocol_version = self .info .get ("serverProtocolVersion" )
115
+ return ttypes .TSessionHandle (
116
+ sessionId = handle_identifier , serverProtocolVersion = server_protocol_version
117
+ )
105
118
106
119
def to_sea_session_id (self ):
107
120
"""
@@ -129,19 +142,12 @@ def to_hex_id(self) -> str:
129
142
130
143
def get_protocol_version (self ):
131
144
"""
132
- Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
133
- precedence over the serverProtocolVersion defined in the OpenSessionResponse.
145
+ Get the server protocol version for this session.
146
+
147
+ Returns:
148
+ The server protocol version or None if this is not a Thrift session ID
134
149
"""
135
- if self .backend_type != BackendType .THRIFT :
136
- return None
137
- session_handle = self .to_thrift_handle ()
138
- if (
139
- session_handle
140
- and hasattr (session_handle , "serverProtocolVersion" )
141
- and session_handle .serverProtocolVersion
142
- ):
143
- return session_handle .serverProtocolVersion
144
- return None
150
+ return self .info .get ("serverProtocolVersion" )
145
151
146
152
147
153
class CommandId :
@@ -190,7 +196,7 @@ def from_thrift_handle(cls, operation_handle):
190
196
Returns:
191
197
A CommandId instance
192
198
"""
193
- if operation_handle is None or operation_handle . operationId is None :
199
+ if operation_handle is None :
194
200
return None
195
201
196
202
guid_bytes = operation_handle .operationId .guid
0 commit comments