Skip to content

Commit 308fbe0

Browse files
generalise open session, fix session tests to consider positional args
1 parent 15055c0 commit 308fbe0

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

src/databricks/sql/session.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,27 +84,25 @@ def __init__(
8484
**kwargs,
8585
)
8686

87-
self._handle = None
8887
self.protocol_version = None
8988

9089
def open(self) -> None:
91-
self._open_session_resp = self.thrift_backend.open_session(
90+
self._session_id = self.backend.open_session(
9291
self.session_configuration, self.catalog, self.schema
9392
)
94-
self._handle = self._open_session_resp.sessionHandle
95-
self.protocol_version = self.get_protocol_version(self._open_session_resp)
93+
self.protocol_version = self.get_protocol_version(self._session_id)
9694
self.is_open = True
9795
logger.info("Successfully opened session " + str(self.get_id_hex()))
9896

9997
@staticmethod
100-
def get_protocol_version(openSessionResp):
98+
def get_protocol_version(session_id: SessionId):
10199
"""
102100
Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
103101
precedence over the serverProtocolVersion defined in the OpenSessionResponse.
104102
"""
105-
if sessionId.backend_type != BackendType.THRIFT:
103+
if session_id.backend_type != BackendType.THRIFT:
106104
return None
107-
session_handle = sessionId.to_thrift_handle()
105+
session_handle = session_id.to_thrift_handle()
108106
if (
109107
session_handle
110108
and hasattr(session_handle, "serverProtocolVersion")

tests/unit/test_session.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -146,18 +146,18 @@ def test_socket_timeout_passthrough(self, mock_client_class):
146146
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
147147
def test_configuration_passthrough(self, mock_client_class):
148148
mock_session_config = Mock()
149-
149+
150150
# Create a mock SessionId that will be returned by open_session
151151
mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33")
152152
mock_client_class.return_value.open_session.return_value = mock_session_id
153-
153+
154154
databricks.sql.connect(
155155
session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS
156156
)
157157

158-
# Check that open_session was called with the correct session_configuration
159-
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
160-
self.assertEqual(call_kwargs["session_configuration"], mock_session_config)
158+
# Check that open_session was called with the correct session_configuration as first positional argument
159+
call_args = mock_client_class.return_value.open_session.call_args[0]
160+
self.assertEqual(call_args[0], mock_session_config)
161161

162162
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
163163
def test_initial_namespace_passthrough(self, mock_client_class):
@@ -171,11 +171,11 @@ def test_initial_namespace_passthrough(self, mock_client_class):
171171
databricks.sql.connect(
172172
**self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem
173173
)
174-
175-
# Check that open_session was called with the correct catalog and schema
176-
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
177-
self.assertEqual(call_kwargs["catalog"], mock_cat)
178-
self.assertEqual(call_kwargs["schema"], mock_schem)
174+
175+
# Check that open_session was called with the correct catalog and schema as positional arguments
176+
call_args = mock_client_class.return_value.open_session.call_args[0]
177+
self.assertEqual(call_args[1], mock_cat) # catalog is second positional argument
178+
self.assertEqual(call_args[2], mock_schem) # schema is third positional argument
179179

180180
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
181181
def test_finalizer_closes_abandoned_connection(self, mock_client_class):

0 commit comments

Comments
 (0)