Skip to content

Commit 7383290

Browse files
decouple session class from existing Connection
ensure maintenance of current APIs of Connection while delegating responsibility Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 7b51c6e commit 7383290

File tree

2 files changed

+190
-108
lines changed

2 files changed

+190
-108
lines changed

src/databricks/sql/client.py

Lines changed: 44 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
OperationalError,
2020
SessionAlreadyClosedError,
2121
CursorAlreadyClosedError,
22+
Error,
23+
NotSupportedError,
2224
)
2325
from databricks.sql.thrift_api.TCLIService import ttypes
2426
from databricks.sql.thrift_backend import ThriftBackend
@@ -45,6 +47,7 @@
4547
from databricks.sql.types import Row, SSLOptions
4648
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
4749
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
50+
from databricks.sql.session import Session
4851

4952
from databricks.sql.thrift_api.TCLIService.ttypes import (
5053
TSparkParameter,
@@ -218,66 +221,24 @@ def read(self) -> Optional[OAuthToken]:
218221
access_token_kv = {"access_token": access_token}
219222
kwargs = {**kwargs, **access_token_kv}
220223

221-
self.open = False
222-
self.host = server_hostname
223-
self.port = kwargs.get("_port", 443)
224224
self.disable_pandas = kwargs.get("_disable_pandas", False)
225225
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
226+
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
227+
self._cursors = [] # type: List[Cursor]
226228

227-
auth_provider = get_python_sql_connector_auth_provider(
228-
server_hostname, **kwargs
229-
)
230-
231-
user_agent_entry = kwargs.get("user_agent_entry")
232-
if user_agent_entry is None:
233-
user_agent_entry = kwargs.get("_user_agent_entry")
234-
if user_agent_entry is not None:
235-
logger.warning(
236-
"[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
237-
"This parameter will be removed in the upcoming releases."
238-
)
239-
240-
if user_agent_entry:
241-
useragent_header = "{}/{} ({})".format(
242-
USER_AGENT_NAME, __version__, user_agent_entry
243-
)
244-
else:
245-
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)
246-
247-
base_headers = [("User-Agent", useragent_header)]
248-
249-
self._ssl_options = SSLOptions(
250-
# Double negation is generally a bad thing, but we have to keep backward compatibility
251-
tls_verify=not kwargs.get(
252-
"_tls_no_verify", False
253-
), # by default - verify cert and host
254-
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
255-
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
256-
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
257-
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
258-
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
259-
)
260-
261-
self.thrift_backend = ThriftBackend(
262-
self.host,
263-
self.port,
229+
# Create the session
230+
self.session = Session(
231+
server_hostname,
264232
http_path,
265-
(http_headers or []) + base_headers,
266-
auth_provider,
267-
ssl_options=self._ssl_options,
268-
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
269-
**kwargs,
270-
)
271-
272-
self._open_session_resp = self.thrift_backend.open_session(
273-
session_configuration, catalog, schema
233+
http_headers,
234+
session_configuration,
235+
catalog,
236+
schema,
237+
_use_arrow_native_complex_types,
238+
**kwargs
274239
)
275-
self._session_handle = self._open_session_resp.sessionHandle
276-
self.protocol_version = self.get_protocol_version(self._open_session_resp)
277-
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
278-
self.open = True
279-
logger.info("Successfully opened session " + str(self.get_session_id_hex()))
280-
self._cursors = [] # type: List[Cursor]
240+
241+
logger.info("Successfully opened connection with session " + str(self.get_session_id_hex()))
281242

282243
self.use_inline_params = self._set_use_inline_params_with_warning(
283244
kwargs.get("use_inline_params", False)
@@ -318,7 +279,7 @@ def __exit__(self, exc_type, exc_value, traceback):
318279
self.close()
319280

320281
def __del__(self):
321-
if self.open:
282+
if self.session.open:
322283
logger.debug(
323284
"Closing unclosed connection for session "
324285
"{}".format(self.get_session_id_hex())
@@ -330,34 +291,27 @@ def __del__(self):
330291
logger.debug("Couldn't close unclosed connection: {}".format(e.message))
331292

332293
def get_session_id(self):
333-
return self.thrift_backend.handle_to_id(self._session_handle)
294+
"""Get the session ID from the Session object"""
295+
return self.session.get_session_id()
334296

335-
@staticmethod
336-
def get_protocol_version(openSessionResp):
337-
"""
338-
Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
339-
precedence over the serverProtocolVersion defined in the OpenSessionResponse.
340-
"""
341-
if (
342-
openSessionResp.sessionHandle
343-
and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion")
344-
and openSessionResp.sessionHandle.serverProtocolVersion
345-
):
346-
return openSessionResp.sessionHandle.serverProtocolVersion
347-
return openSessionResp.serverProtocolVersion
297+
def get_session_id_hex(self):
298+
"""Get the session ID in hex format from the Session object"""
299+
return self.session.get_session_id_hex()
348300

349301
@staticmethod
350302
def server_parameterized_queries_enabled(protocolVersion):
351-
if (
352-
protocolVersion
353-
and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8
354-
):
355-
return True
356-
else:
357-
return False
303+
"""Delegate to Session class static method"""
304+
return Session.server_parameterized_queries_enabled(protocolVersion)
358305

359-
def get_session_id_hex(self):
360-
return self.thrift_backend.handle_to_hex_id(self._session_handle)
306+
@property
307+
def protocol_version(self):
308+
"""Get the protocol version from the Session object"""
309+
return self.session.protocol_version
310+
311+
@staticmethod
312+
def get_protocol_version(openSessionResp):
313+
"""Delegate to Session class static method"""
314+
return Session.get_protocol_version(openSessionResp)
361315

362316
def cursor(
363317
self,
@@ -369,12 +323,12 @@ def cursor(
369323
370324
Will throw an Error if the connection has been closed.
371325
"""
372-
if not self.open:
326+
if not self.session.open:
373327
raise Error("Cannot create cursor from closed connection")
374328

375329
cursor = Cursor(
376330
self,
377-
self.thrift_backend,
331+
self.session.thrift_backend,
378332
arraysize=arraysize,
379333
result_buffer_size_bytes=buffer_size_bytes,
380334
)
@@ -390,28 +344,10 @@ def _close(self, close_cursors=True) -> None:
390344
for cursor in self._cursors:
391345
cursor.close()
392346

393-
logger.info(f"Closing session {self.get_session_id_hex()}")
394-
if not self.open:
395-
logger.debug("Session appears to have been closed already")
396-
397347
try:
398-
self.thrift_backend.close_session(self._session_handle)
399-
except RequestError as e:
400-
if isinstance(e.args[1], SessionAlreadyClosedError):
401-
logger.info("Session was closed by a prior request")
402-
except DatabaseError as e:
403-
if "Invalid SessionHandle" in str(e):
404-
logger.warning(
405-
f"Attempted to close session that was already closed: {e}"
406-
)
407-
else:
408-
logger.warning(
409-
f"Attempt to close session raised an exception at the server: {e}"
410-
)
348+
self.session.close()
411349
except Exception as e:
412-
logger.error(f"Attempt to close session raised a local exception: {e}")
413-
414-
self.open = False
350+
logger.error(f"Attempt to close session raised an exception: {e}")
415351

416352
def commit(self):
417353
"""No-op because Databricks does not support transactions"""
@@ -811,7 +747,7 @@ def execute(
811747
self._close_and_clear_active_result_set()
812748
execute_response = self.thrift_backend.execute_command(
813749
operation=prepared_operation,
814-
session_handle=self.connection._session_handle,
750+
session_handle=self.connection.session._session_handle,
815751
max_rows=self.arraysize,
816752
max_bytes=self.buffer_size_bytes,
817753
lz4_compression=self.connection.lz4_compression,
@@ -874,7 +810,7 @@ def execute_async(
874810
self._close_and_clear_active_result_set()
875811
self.thrift_backend.execute_command(
876812
operation=prepared_operation,
877-
session_handle=self.connection._session_handle,
813+
session_handle=self.connection.session._session_handle,
878814
max_rows=self.arraysize,
879815
max_bytes=self.buffer_size_bytes,
880816
lz4_compression=self.connection.lz4_compression,
@@ -970,7 +906,7 @@ def catalogs(self) -> "Cursor":
970906
self._check_not_closed()
971907
self._close_and_clear_active_result_set()
972908
execute_response = self.thrift_backend.get_catalogs(
973-
session_handle=self.connection._session_handle,
909+
session_handle=self.connection.session._session_handle,
974910
max_rows=self.arraysize,
975911
max_bytes=self.buffer_size_bytes,
976912
cursor=self,
@@ -996,7 +932,7 @@ def schemas(
996932
self._check_not_closed()
997933
self._close_and_clear_active_result_set()
998934
execute_response = self.thrift_backend.get_schemas(
999-
session_handle=self.connection._session_handle,
935+
session_handle=self.connection.session._session_handle,
1000936
max_rows=self.arraysize,
1001937
max_bytes=self.buffer_size_bytes,
1002938
cursor=self,
@@ -1029,7 +965,7 @@ def tables(
1029965
self._close_and_clear_active_result_set()
1030966

1031967
execute_response = self.thrift_backend.get_tables(
1032-
session_handle=self.connection._session_handle,
968+
session_handle=self.connection.session._session_handle,
1033969
max_rows=self.arraysize,
1034970
max_bytes=self.buffer_size_bytes,
1035971
cursor=self,
@@ -1064,7 +1000,7 @@ def columns(
10641000
self._close_and_clear_active_result_set()
10651001

10661002
execute_response = self.thrift_backend.get_columns(
1067-
session_handle=self.connection._session_handle,
1003+
session_handle=self.connection.session._session_handle,
10681004
max_rows=self.arraysize,
10691005
max_bytes=self.buffer_size_bytes,
10701006
cursor=self,
@@ -1493,7 +1429,7 @@ def close(self) -> None:
14931429
if (
14941430
self.op_state != self.thrift_backend.CLOSED_OP_STATE
14951431
and not self.has_been_closed_server_side
1496-
and self.connection.open
1432+
and self.connection.session.open
14971433
):
14981434
self.thrift_backend.close_command(self.command_id)
14991435
except RequestError as e:

0 commit comments

Comments
 (0)