Skip to content

Separate Session related functionality from Connection class #571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f97c81d
decouple session class from existing Connection
varun-edachali-dbx May 20, 2025
fe0af87
add open property to Connection to ensure maintenance of existing API
varun-edachali-dbx May 20, 2025
18f8f67
update unit tests to address ThriftBackend through session instead of…
varun-edachali-dbx May 20, 2025
fd8decb
chore: move session specific tests from test_client to test_session
varun-edachali-dbx May 20, 2025
1a92b77
formatting (black)
varun-edachali-dbx May 20, 2025
1b9a50a
use connection open property instead of long chain through session
varun-edachali-dbx May 20, 2025
0bf2794
trigger integration workflow
varun-edachali-dbx May 20, 2025
ff35165
fix: ensure open attribute of Connection never fails
varun-edachali-dbx May 21, 2025
0df486a
fix: de-complicate earlier connection open logic
varun-edachali-dbx May 23, 2025
63b10c3
Revert "fix: de-complicate earlier connection open logic"
varun-edachali-dbx May 23, 2025
f2b3fd5
[empty commit] attempt to trigger ci e2e workflow
varun-edachali-dbx May 23, 2025
53f16ab
Update CODEOWNERS (#562)
jprakash-db May 21, 2025
a026751
Enhance Cursor close handling and context manager exception managemen…
madhav-db May 21, 2025
0d6995c
PECOBLR-86 improve logging on python driver (#556)
saishreeeee May 22, 2025
923bbb6
Revert "Merge remote-tracking branch 'upstream/sea-migration' into de…
varun-edachali-dbx May 23, 2025
8df8c33
Reapply "Merge remote-tracking branch 'upstream/sea-migration' into d…
varun-edachali-dbx May 23, 2025
bcf5994
fix: separate session opening logic from instantiation
varun-edachali-dbx May 23, 2025
500dd0b
fix: use is_open attribute to denote session availability
varun-edachali-dbx May 23, 2025
510b454
fix: access thrift backend through session
varun-edachali-dbx May 23, 2025
634faa9
chore: use get_handle() instead of private session attribute in client
varun-edachali-dbx May 24, 2025
a32862b
formatting (black)
varun-edachali-dbx May 24, 2025
88b728d
Merge remote-tracking branch 'upstream/sea-migration' into decouple-s…
varun-edachali-dbx May 24, 2025
ed04584
Merge remote-tracking branch 'upstream/sea-migration' into decouple-s…
varun-edachali-dbx May 26, 2025
ff842d7
fix: remove accidentally removed assertions
varun-edachali-dbx May 26, 2025
9190a33
Merge remote-tracking branch 'origin/sea-migration' into decouple-ses…
varun-edachali-dbx May 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 45 additions & 102 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from databricks.sql.types import Row, SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
from databricks.sql.session import Session

from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkParameter,
Expand Down Expand Up @@ -224,66 +225,28 @@ def read(self) -> Optional[OAuthToken]:
access_token_kv = {"access_token": access_token}
kwargs = {**kwargs, **access_token_kv}

self.open = False
self.host = server_hostname
self.port = kwargs.get("_port", 443)
self.disable_pandas = kwargs.get("_disable_pandas", False)
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
self._cursors = [] # type: List[Cursor]

auth_provider = get_python_sql_connector_auth_provider(
server_hostname, **kwargs
)

user_agent_entry = kwargs.get("user_agent_entry")
if user_agent_entry is None:
user_agent_entry = kwargs.get("_user_agent_entry")
if user_agent_entry is not None:
logger.warning(
"[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
"This parameter will be removed in the upcoming releases."
)

if user_agent_entry:
useragent_header = "{}/{} ({})".format(
USER_AGENT_NAME, __version__, user_agent_entry
)
else:
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)

base_headers = [("User-Agent", useragent_header)]

self._ssl_options = SSLOptions(
# Double negation is generally a bad thing, but we have to keep backward compatibility
tls_verify=not kwargs.get(
"_tls_no_verify", False
), # by default - verify cert and host
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
)

self.thrift_backend = ThriftBackend(
self.host,
self.port,
# Create the session
self.session = Session(
server_hostname,
http_path,
(http_headers or []) + base_headers,
auth_provider,
ssl_options=self._ssl_options,
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
http_headers,
session_configuration,
catalog,
schema,
_use_arrow_native_complex_types,
**kwargs,
)
self.session.open()

self._open_session_resp = self.thrift_backend.open_session(
session_configuration, catalog, schema
logger.info(
"Successfully opened connection with session "
+ str(self.get_session_id_hex())
)
self._session_handle = self._open_session_resp.sessionHandle
self.protocol_version = self.get_protocol_version(self._open_session_resp)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
self.open = True
logger.info("Successfully opened session " + str(self.get_session_id_hex()))
self._cursors = [] # type: List[Cursor]

self.use_inline_params = self._set_use_inline_params_with_warning(
kwargs.get("use_inline_params", False)
Expand Down Expand Up @@ -342,34 +305,32 @@ def __del__(self):
logger.debug("Couldn't close unclosed connection: {}".format(e.message))

def get_session_id(self):
return self.thrift_backend.handle_to_id(self._session_handle)
"""Get the session ID from the Session object"""
return self.session.get_id()

@staticmethod
def get_protocol_version(openSessionResp):
"""
Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
precedence over the serverProtocolVersion defined in the OpenSessionResponse.
"""
if (
openSessionResp.sessionHandle
and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion")
and openSessionResp.sessionHandle.serverProtocolVersion
):
return openSessionResp.sessionHandle.serverProtocolVersion
return openSessionResp.serverProtocolVersion
def get_session_id_hex(self):
"""Get the session ID in hex format from the Session object"""
return self.session.get_id_hex()

@staticmethod
def server_parameterized_queries_enabled(protocolVersion):
if (
protocolVersion
and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8
):
return True
else:
return False
"""Delegate to Session class static method"""
return Session.server_parameterized_queries_enabled(protocolVersion)

def get_session_id_hex(self):
return self.thrift_backend.handle_to_hex_id(self._session_handle)
@property
def protocol_version(self):
"""Get the protocol version from the Session object"""
return self.session.protocol_version

@staticmethod
def get_protocol_version(openSessionResp):
"""Delegate to Session class static method"""
return Session.get_protocol_version(openSessionResp)

@property
def open(self) -> bool:
"""Return whether the connection is open by checking if the session is open."""
return self.session.is_open
Comment on lines +320 to +333
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should we group together property and staticmethod? @jprakash-db any coding/lint guidelines OSS python driver follows?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no specific standard in python, because saying something as private etc has no meaning, we can access anything anytime. There are some general standards but nothing concrete


def cursor(
self,
Expand All @@ -386,7 +347,7 @@ def cursor(

cursor = Cursor(
self,
self.thrift_backend,
self.session.thrift_backend,
arraysize=arraysize,
result_buffer_size_bytes=buffer_size_bytes,
)
Expand All @@ -402,28 +363,10 @@ def _close(self, close_cursors=True) -> None:
for cursor in self._cursors:
cursor.close()

logger.info(f"Closing session {self.get_session_id_hex()}")
if not self.open:
logger.debug("Session appears to have been closed already")

try:
self.thrift_backend.close_session(self._session_handle)
except RequestError as e:
if isinstance(e.args[1], SessionAlreadyClosedError):
logger.info("Session was closed by a prior request")
except DatabaseError as e:
if "Invalid SessionHandle" in str(e):
logger.warning(
f"Attempted to close session that was already closed: {e}"
)
else:
logger.warning(
f"Attempt to close session raised an exception at the server: {e}"
)
self.session.close()
except Exception as e:
logger.error(f"Attempt to close session raised a local exception: {e}")

self.open = False
logger.error(f"Attempt to close session raised an exception: {e}")

def commit(self):
"""No-op because Databricks does not support transactions"""
Expand Down Expand Up @@ -833,7 +776,7 @@ def execute(
self._close_and_clear_active_result_set()
execute_response = self.thrift_backend.execute_command(
operation=prepared_operation,
session_handle=self.connection._session_handle,
session_handle=self.connection.session.get_handle(),
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
Expand Down Expand Up @@ -896,7 +839,7 @@ def execute_async(
self._close_and_clear_active_result_set()
self.thrift_backend.execute_command(
operation=prepared_operation,
session_handle=self.connection._session_handle,
session_handle=self.connection.session.get_handle(),
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
Expand Down Expand Up @@ -992,7 +935,7 @@ def catalogs(self) -> "Cursor":
self._check_not_closed()
self._close_and_clear_active_result_set()
execute_response = self.thrift_backend.get_catalogs(
session_handle=self.connection._session_handle,
session_handle=self.connection.session.get_handle(),
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
cursor=self,
Expand All @@ -1018,7 +961,7 @@ def schemas(
self._check_not_closed()
self._close_and_clear_active_result_set()
execute_response = self.thrift_backend.get_schemas(
session_handle=self.connection._session_handle,
session_handle=self.connection.session.get_handle(),
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
cursor=self,
Expand Down Expand Up @@ -1051,7 +994,7 @@ def tables(
self._close_and_clear_active_result_set()

execute_response = self.thrift_backend.get_tables(
session_handle=self.connection._session_handle,
session_handle=self.connection.session.get_handle(),
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
cursor=self,
Expand Down Expand Up @@ -1086,7 +1029,7 @@ def columns(
self._close_and_clear_active_result_set()

execute_response = self.thrift_backend.get_columns(
session_handle=self.connection._session_handle,
session_handle=self.connection.session.get_handle(),
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
cursor=self,
Expand Down
160 changes: 160 additions & 0 deletions src/databricks/sql/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import logging
from typing import Dict, Tuple, List, Optional, Any

from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.types import SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError
from databricks.sql import __version__
from databricks.sql import USER_AGENT_NAME
from databricks.sql.thrift_backend import ThriftBackend

logger = logging.getLogger(__name__)


class Session:
def __init__(
self,
server_hostname: str,
http_path: str,
http_headers: Optional[List[Tuple[str, str]]] = None,
session_configuration: Optional[Dict[str, Any]] = None,
catalog: Optional[str] = None,
schema: Optional[str] = None,
_use_arrow_native_complex_types: Optional[bool] = True,
**kwargs,
Comment on lines +17 to +25
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

  • these are identical to connection params. are all of them relevant to a DBSQL session?
  • should be introduce a type/namedtuple like ConnectionParams or a better name. The reason is that now for each added connection param, we would have to modify at two places

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, that makes sense.

Shouldn't we name it SessionParams though? These parameters will be passed through to the session constructor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe Session is an internal abstraction and these params originate from Connection so better to name ConnectionParams. Additionally, i think this change might break a lot of things. Let's do it completely separately (let's log a JIRA ticket for now and take it up later)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

) -> None:
"""
Create a session to a Databricks SQL endpoint or a Databricks cluster.

This class handles all session-related behavior and communication with the backend.
"""
self.is_open = False
self.host = server_hostname
self.port = kwargs.get("_port", 443)

self.session_configuration = session_configuration
self.catalog = catalog
self.schema = schema

auth_provider = get_python_sql_connector_auth_provider(
server_hostname, **kwargs
)

user_agent_entry = kwargs.get("user_agent_entry")
if user_agent_entry is None:
user_agent_entry = kwargs.get("_user_agent_entry")
if user_agent_entry is not None:
logger.warning(
"[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
"This parameter will be removed in the upcoming releases."
)

if user_agent_entry:
useragent_header = "{}/{} ({})".format(
USER_AGENT_NAME, __version__, user_agent_entry
)
else:
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)

base_headers = [("User-Agent", useragent_header)]

self._ssl_options = SSLOptions(
# Double negation is generally a bad thing, but we have to keep backward compatibility
tls_verify=not kwargs.get(
"_tls_no_verify", False
), # by default - verify cert and host
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
)

self.thrift_backend = ThriftBackend(
self.host,
self.port,
http_path,
(http_headers or []) + base_headers,
auth_provider,
ssl_options=self._ssl_options,
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
**kwargs,
)

self._handle = None
self.protocol_version = None

def open(self) -> None:
self._open_session_resp = self.thrift_backend.open_session(
self.session_configuration, self.catalog, self.schema
)
self._handle = self._open_session_resp.sessionHandle
self.protocol_version = self.get_protocol_version(self._open_session_resp)
self.is_open = True
logger.info("Successfully opened session " + str(self.get_id_hex()))

@staticmethod
def get_protocol_version(openSessionResp):
"""
Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
precedence over the serverProtocolVersion defined in the OpenSessionResponse.
"""
Comment on lines +101 to +102
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
precedence over the serverProtocolVersion defined in the OpenSessionResponse.
"""
precedence over the serverProtocolVersion defined in the OpenSessionResponse.
"""

i think there is a line gap after a multi-line pydoc. @jprakash-db do we follow any python coding guidelines like for docstring: https://peps.python.org/pep-0257/

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in databricks we follow this https://databricks.atlassian.net/wiki/spaces/UN/pages/3334538555/Python+Guidelines+go+py but this could be different for OSS repo.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use the black formatter which follows the PEP-257 style.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it. is there a linter? in the CI or do we have to run the linter manually?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently we have pylint but it needs to be run manually.

if (
openSessionResp.sessionHandle
and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion")
and openSessionResp.sessionHandle.serverProtocolVersion
):
return openSessionResp.sessionHandle.serverProtocolVersion
return openSessionResp.serverProtocolVersion

@staticmethod
def server_parameterized_queries_enabled(protocolVersion):
if (
protocolVersion
and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8
):
return True
else:
return False

def get_handle(self):
return self._handle

def get_id(self):
handle = self.get_handle()
if handle is None:
return None
return self.thrift_backend.handle_to_id(handle)

def get_id_hex(self):
handle = self.get_handle()
if handle is None:
return None
return self.thrift_backend.handle_to_hex_id(handle)

def close(self) -> None:
"""Close the underlying session."""
logger.info(f"Closing session {self.get_id_hex()}")
if not self.is_open:
logger.debug("Session appears to have been closed already")
return

try:
self.thrift_backend.close_session(self.get_handle())
except RequestError as e:
if isinstance(e.args[1], SessionAlreadyClosedError):
logger.info("Session was closed by a prior request")
except DatabaseError as e:
if "Invalid SessionHandle" in str(e):
logger.warning(
f"Attempted to close session that was already closed: {e}"
)
else:
logger.warning(
f"Attempt to close session raised an exception at the server: {e}"
)
except Exception as e:
logger.error(f"Attempt to close session raised a local exception: {e}")

self.is_open = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when opening session the flag is set at the end which makes sense. for closing session call, should we be eager to unset the flag in the very beginning?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the session close fails? Shouldn't the session remain open?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the client code has called on the Session class to close the session, then the client assumes that method will close the session. I think unsetting the flag right away makes more sense then. However, an interesting question is do we use this flag internally in Session class to make unsetting meaningful (i.e., when flag is false, do we give null or throw exception when getting session handle?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently there does not seem to exist such a dependency, but I'm still not clear on this.

If the close() call raises an exception isn't the client expected to retry?

Copy link
Collaborator Author

@varun-edachali-dbx varun-edachali-dbx May 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Loading
Loading