19
19
OperationalError ,
20
20
SessionAlreadyClosedError ,
21
21
CursorAlreadyClosedError ,
22
+ Error ,
23
+ NotSupportedError ,
22
24
)
23
25
from databricks .sql .thrift_api .TCLIService import ttypes
24
26
from databricks .sql .thrift_backend import ThriftBackend
45
47
from databricks .sql .types import Row , SSLOptions
46
48
from databricks .sql .auth .auth import get_python_sql_connector_auth_provider
47
49
from databricks .sql .experimental .oauth_persistence import OAuthPersistence
50
+ from databricks .sql .session import Session
48
51
49
52
from databricks .sql .thrift_api .TCLIService .ttypes import (
50
53
TSparkParameter ,
@@ -218,66 +221,24 @@ def read(self) -> Optional[OAuthToken]:
218
221
access_token_kv = {"access_token" : access_token }
219
222
kwargs = {** kwargs , ** access_token_kv }
220
223
221
- self .open = False
222
- self .host = server_hostname
223
- self .port = kwargs .get ("_port" , 443 )
224
224
self .disable_pandas = kwargs .get ("_disable_pandas" , False )
225
225
self .lz4_compression = kwargs .get ("enable_query_result_lz4_compression" , True )
226
+ self .use_cloud_fetch = kwargs .get ("use_cloud_fetch" , True )
227
+ self ._cursors = [] # type: List[Cursor]
226
228
227
- auth_provider = get_python_sql_connector_auth_provider (
228
- server_hostname , ** kwargs
229
- )
230
-
231
- user_agent_entry = kwargs .get ("user_agent_entry" )
232
- if user_agent_entry is None :
233
- user_agent_entry = kwargs .get ("_user_agent_entry" )
234
- if user_agent_entry is not None :
235
- logger .warning (
236
- "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
237
- "This parameter will be removed in the upcoming releases."
238
- )
239
-
240
- if user_agent_entry :
241
- useragent_header = "{}/{} ({})" .format (
242
- USER_AGENT_NAME , __version__ , user_agent_entry
243
- )
244
- else :
245
- useragent_header = "{}/{}" .format (USER_AGENT_NAME , __version__ )
246
-
247
- base_headers = [("User-Agent" , useragent_header )]
248
-
249
- self ._ssl_options = SSLOptions (
250
- # Double negation is generally a bad thing, but we have to keep backward compatibility
251
- tls_verify = not kwargs .get (
252
- "_tls_no_verify" , False
253
- ), # by default - verify cert and host
254
- tls_verify_hostname = kwargs .get ("_tls_verify_hostname" , True ),
255
- tls_trusted_ca_file = kwargs .get ("_tls_trusted_ca_file" ),
256
- tls_client_cert_file = kwargs .get ("_tls_client_cert_file" ),
257
- tls_client_cert_key_file = kwargs .get ("_tls_client_cert_key_file" ),
258
- tls_client_cert_key_password = kwargs .get ("_tls_client_cert_key_password" ),
259
- )
260
-
261
- self .thrift_backend = ThriftBackend (
262
- self .host ,
263
- self .port ,
229
+ # Create the session
230
+ self .session = Session (
231
+ server_hostname ,
264
232
http_path ,
265
- (http_headers or []) + base_headers ,
266
- auth_provider ,
267
- ssl_options = self ._ssl_options ,
268
- _use_arrow_native_complex_types = _use_arrow_native_complex_types ,
269
- ** kwargs ,
270
- )
271
-
272
- self ._open_session_resp = self .thrift_backend .open_session (
273
- session_configuration , catalog , schema
233
+ http_headers ,
234
+ session_configuration ,
235
+ catalog ,
236
+ schema ,
237
+ _use_arrow_native_complex_types ,
238
+ ** kwargs
274
239
)
275
- self ._session_handle = self ._open_session_resp .sessionHandle
276
- self .protocol_version = self .get_protocol_version (self ._open_session_resp )
277
- self .use_cloud_fetch = kwargs .get ("use_cloud_fetch" , True )
278
- self .open = True
279
- logger .info ("Successfully opened session " + str (self .get_session_id_hex ()))
280
- self ._cursors = [] # type: List[Cursor]
240
+
241
+ logger .info ("Successfully opened connection with session " + str (self .get_session_id_hex ()))
281
242
282
243
self .use_inline_params = self ._set_use_inline_params_with_warning (
283
244
kwargs .get ("use_inline_params" , False )
@@ -318,7 +279,7 @@ def __exit__(self, exc_type, exc_value, traceback):
318
279
self .close ()
319
280
320
281
def __del__ (self ):
321
- if self .open :
282
+ if self .session . open :
322
283
logger .debug (
323
284
"Closing unclosed connection for session "
324
285
"{}" .format (self .get_session_id_hex ())
@@ -330,34 +291,27 @@ def __del__(self):
330
291
logger .debug ("Couldn't close unclosed connection: {}" .format (e .message ))
331
292
332
293
def get_session_id (self ):
333
- return self .thrift_backend .handle_to_id (self ._session_handle )
294
+ """Get the session ID from the Session object"""
295
+ return self .session .get_session_id ()
334
296
335
- @staticmethod
336
- def get_protocol_version (openSessionResp ):
337
- """
338
- Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
339
- precedence over the serverProtocolVersion defined in the OpenSessionResponse.
340
- """
341
- if (
342
- openSessionResp .sessionHandle
343
- and hasattr (openSessionResp .sessionHandle , "serverProtocolVersion" )
344
- and openSessionResp .sessionHandle .serverProtocolVersion
345
- ):
346
- return openSessionResp .sessionHandle .serverProtocolVersion
347
- return openSessionResp .serverProtocolVersion
297
+ def get_session_id_hex (self ):
298
+ """Get the session ID in hex format from the Session object"""
299
+ return self .session .get_session_id_hex ()
348
300
349
301
@staticmethod
350
302
def server_parameterized_queries_enabled (protocolVersion ):
351
- if (
352
- protocolVersion
353
- and protocolVersion >= ttypes .TProtocolVersion .SPARK_CLI_SERVICE_PROTOCOL_V8
354
- ):
355
- return True
356
- else :
357
- return False
303
+ """Delegate to Session class static method"""
304
+ return Session .server_parameterized_queries_enabled (protocolVersion )
358
305
359
- def get_session_id_hex (self ):
360
- return self .thrift_backend .handle_to_hex_id (self ._session_handle )
306
+ @property
307
+ def protocol_version (self ):
308
+ """Get the protocol version from the Session object"""
309
+ return self .session .protocol_version
310
+
311
+ @staticmethod
312
+ def get_protocol_version (openSessionResp ):
313
+ """Delegate to Session class static method"""
314
+ return Session .get_protocol_version (openSessionResp )
361
315
362
316
def cursor (
363
317
self ,
@@ -369,12 +323,12 @@ def cursor(
369
323
370
324
Will throw an Error if the connection has been closed.
371
325
"""
372
- if not self .open :
326
+ if not self .session . open :
373
327
raise Error ("Cannot create cursor from closed connection" )
374
328
375
329
cursor = Cursor (
376
330
self ,
377
- self .thrift_backend ,
331
+ self .session . thrift_backend ,
378
332
arraysize = arraysize ,
379
333
result_buffer_size_bytes = buffer_size_bytes ,
380
334
)
@@ -390,28 +344,10 @@ def _close(self, close_cursors=True) -> None:
390
344
for cursor in self ._cursors :
391
345
cursor .close ()
392
346
393
- logger .info (f"Closing session { self .get_session_id_hex ()} " )
394
- if not self .open :
395
- logger .debug ("Session appears to have been closed already" )
396
-
397
347
try :
398
- self .thrift_backend .close_session (self ._session_handle )
399
- except RequestError as e :
400
- if isinstance (e .args [1 ], SessionAlreadyClosedError ):
401
- logger .info ("Session was closed by a prior request" )
402
- except DatabaseError as e :
403
- if "Invalid SessionHandle" in str (e ):
404
- logger .warning (
405
- f"Attempted to close session that was already closed: { e } "
406
- )
407
- else :
408
- logger .warning (
409
- f"Attempt to close session raised an exception at the server: { e } "
410
- )
348
+ self .session .close ()
411
349
except Exception as e :
412
- logger .error (f"Attempt to close session raised a local exception: { e } " )
413
-
414
- self .open = False
350
+ logger .error (f"Attempt to close session raised an exception: { e } " )
415
351
416
352
def commit (self ):
417
353
"""No-op because Databricks does not support transactions"""
@@ -811,7 +747,7 @@ def execute(
811
747
self ._close_and_clear_active_result_set ()
812
748
execute_response = self .thrift_backend .execute_command (
813
749
operation = prepared_operation ,
814
- session_handle = self .connection ._session_handle ,
750
+ session_handle = self .connection .session . _session_handle ,
815
751
max_rows = self .arraysize ,
816
752
max_bytes = self .buffer_size_bytes ,
817
753
lz4_compression = self .connection .lz4_compression ,
@@ -874,7 +810,7 @@ def execute_async(
874
810
self ._close_and_clear_active_result_set ()
875
811
self .thrift_backend .execute_command (
876
812
operation = prepared_operation ,
877
- session_handle = self .connection ._session_handle ,
813
+ session_handle = self .connection .session . _session_handle ,
878
814
max_rows = self .arraysize ,
879
815
max_bytes = self .buffer_size_bytes ,
880
816
lz4_compression = self .connection .lz4_compression ,
@@ -970,7 +906,7 @@ def catalogs(self) -> "Cursor":
970
906
self ._check_not_closed ()
971
907
self ._close_and_clear_active_result_set ()
972
908
execute_response = self .thrift_backend .get_catalogs (
973
- session_handle = self .connection ._session_handle ,
909
+ session_handle = self .connection .session . _session_handle ,
974
910
max_rows = self .arraysize ,
975
911
max_bytes = self .buffer_size_bytes ,
976
912
cursor = self ,
@@ -996,7 +932,7 @@ def schemas(
996
932
self ._check_not_closed ()
997
933
self ._close_and_clear_active_result_set ()
998
934
execute_response = self .thrift_backend .get_schemas (
999
- session_handle = self .connection ._session_handle ,
935
+ session_handle = self .connection .session . _session_handle ,
1000
936
max_rows = self .arraysize ,
1001
937
max_bytes = self .buffer_size_bytes ,
1002
938
cursor = self ,
@@ -1029,7 +965,7 @@ def tables(
1029
965
self ._close_and_clear_active_result_set ()
1030
966
1031
967
execute_response = self .thrift_backend .get_tables (
1032
- session_handle = self .connection ._session_handle ,
968
+ session_handle = self .connection .session . _session_handle ,
1033
969
max_rows = self .arraysize ,
1034
970
max_bytes = self .buffer_size_bytes ,
1035
971
cursor = self ,
@@ -1064,7 +1000,7 @@ def columns(
1064
1000
self ._close_and_clear_active_result_set ()
1065
1001
1066
1002
execute_response = self .thrift_backend .get_columns (
1067
- session_handle = self .connection ._session_handle ,
1003
+ session_handle = self .connection .session . _session_handle ,
1068
1004
max_rows = self .arraysize ,
1069
1005
max_bytes = self .buffer_size_bytes ,
1070
1006
cursor = self ,
@@ -1493,7 +1429,7 @@ def close(self) -> None:
1493
1429
if (
1494
1430
self .op_state != self .thrift_backend .CLOSED_OP_STATE
1495
1431
and not self .has_been_closed_server_side
1496
- and self .connection .open
1432
+ and self .connection .session . open
1497
1433
):
1498
1434
self .thrift_backend .close_command (self .command_id )
1499
1435
except RequestError as e :
0 commit comments