diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0c9a08a8..d6a9e6b0 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -45,6 +45,7 @@ from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence +from databricks.sql.session import Session from databricks.sql.thrift_api.TCLIService.ttypes import ( TSparkParameter, @@ -224,66 +225,28 @@ def read(self) -> Optional[OAuthToken]: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv} - self.open = False - self.host = server_hostname - self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) + self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) + self._cursors = [] # type: List[Cursor] - auth_provider = get_python_sql_connector_auth_provider( - server_hostname, **kwargs - ) - - user_agent_entry = kwargs.get("user_agent_entry") - if user_agent_entry is None: - user_agent_entry = kwargs.get("_user_agent_entry") - if user_agent_entry is not None: - logger.warning( - "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. " - "This parameter will be removed in the upcoming releases." - ) - - if user_agent_entry: - useragent_header = "{}/{} ({})".format( - USER_AGENT_NAME, __version__, user_agent_entry - ) - else: - useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) - - base_headers = [("User-Agent", useragent_header)] - - self._ssl_options = SSLOptions( - # Double negation is generally a bad thing, but we have to keep backward compatibility - tls_verify=not kwargs.get( - "_tls_no_verify", False - ), # by default - verify cert and host - tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), - tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), - tls_client_cert_file=kwargs.get("_tls_client_cert_file"), - tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), - tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), - ) - - self.thrift_backend = ThriftBackend( - self.host, - self.port, + # Create the session + self.session = Session( + server_hostname, http_path, - (http_headers or []) + base_headers, - auth_provider, - ssl_options=self._ssl_options, - _use_arrow_native_complex_types=_use_arrow_native_complex_types, + http_headers, + session_configuration, + catalog, + schema, + _use_arrow_native_complex_types, **kwargs, ) + self.session.open() - self._open_session_resp = self.thrift_backend.open_session( - session_configuration, catalog, schema + logger.info( + "Successfully opened connection with session " + + str(self.get_session_id_hex()) ) - self._session_handle = self._open_session_resp.sessionHandle - self.protocol_version = self.get_protocol_version(self._open_session_resp) - self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) - self.open = True - logger.info("Successfully opened session " + str(self.get_session_id_hex())) - self._cursors = [] # type: List[Cursor] self.use_inline_params = self._set_use_inline_params_with_warning( kwargs.get("use_inline_params", False) @@ -342,34 +305,32 @@ def __del__(self): logger.debug("Couldn't close unclosed connection: {}".format(e.message)) def get_session_id(self): - return self.thrift_backend.handle_to_id(self._session_handle) + """Get the session ID from the Session object""" + return self.session.get_id() - @staticmethod - def get_protocol_version(openSessionResp): - """ - Since the sessionHandle will sometimes have a serverProtocolVersion, it takes - precedence over the serverProtocolVersion defined in the OpenSessionResponse. - """ - if ( - openSessionResp.sessionHandle - and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") - and openSessionResp.sessionHandle.serverProtocolVersion - ): - return openSessionResp.sessionHandle.serverProtocolVersion - return openSessionResp.serverProtocolVersion + def get_session_id_hex(self): + """Get the session ID in hex format from the Session object""" + return self.session.get_id_hex() @staticmethod def server_parameterized_queries_enabled(protocolVersion): - if ( - protocolVersion - and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 - ): - return True - else: - return False + """Delegate to Session class static method""" + return Session.server_parameterized_queries_enabled(protocolVersion) - def get_session_id_hex(self): - return self.thrift_backend.handle_to_hex_id(self._session_handle) + @property + def protocol_version(self): + """Get the protocol version from the Session object""" + return self.session.protocol_version + + @staticmethod + def get_protocol_version(openSessionResp): + """Delegate to Session class static method""" + return Session.get_protocol_version(openSessionResp) + + @property + def open(self) -> bool: + """Return whether the connection is open by checking if the session is open.""" + return self.session.is_open def cursor( self, @@ -386,7 +347,7 @@ def cursor( cursor = Cursor( self, - self.thrift_backend, + self.session.thrift_backend, arraysize=arraysize, result_buffer_size_bytes=buffer_size_bytes, ) @@ -402,28 +363,10 @@ def _close(self, close_cursors=True) -> None: for cursor in self._cursors: cursor.close() - logger.info(f"Closing session {self.get_session_id_hex()}") - if not self.open: - logger.debug("Session appears to have been closed already") - try: - self.thrift_backend.close_session(self._session_handle) - except RequestError as e: - if isinstance(e.args[1], SessionAlreadyClosedError): - logger.info("Session was closed by a prior request") - except DatabaseError as e: - if "Invalid SessionHandle" in str(e): - logger.warning( - f"Attempted to close session that was already closed: {e}" - ) - else: - logger.warning( - f"Attempt to close session raised an exception at the server: {e}" - ) + self.session.close() except Exception as e: - logger.error(f"Attempt to close session raised a local exception: {e}") - - self.open = False + logger.error(f"Attempt to close session raised an exception: {e}") def commit(self): """No-op because Databricks does not support transactions""" @@ -833,7 +776,7 @@ def execute( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.execute_command( operation=prepared_operation, - session_handle=self.connection._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -896,7 +839,7 @@ def execute_async( self._close_and_clear_active_result_set() self.thrift_backend.execute_command( operation=prepared_operation, - session_handle=self.connection._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -992,7 +935,7 @@ def catalogs(self) -> "Cursor": self._check_not_closed() self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_catalogs( - session_handle=self.connection._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1018,7 +961,7 @@ def schemas( self._check_not_closed() self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_schemas( - session_handle=self.connection._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1051,7 +994,7 @@ def tables( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_tables( - session_handle=self.connection._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1086,7 +1029,7 @@ def columns( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_columns( - session_handle=self.connection._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py new file mode 100644 index 00000000..f2f38d57 --- /dev/null +++ b/src/databricks/sql/session.py @@ -0,0 +1,160 @@ +import logging +from typing import Dict, Tuple, List, Optional, Any + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import SSLOptions +from databricks.sql.auth.auth import get_python_sql_connector_auth_provider +from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError +from databricks.sql import __version__ +from databricks.sql import USER_AGENT_NAME +from databricks.sql.thrift_backend import ThriftBackend + +logger = logging.getLogger(__name__) + + +class Session: + def __init__( + self, + server_hostname: str, + http_path: str, + http_headers: Optional[List[Tuple[str, str]]] = None, + session_configuration: Optional[Dict[str, Any]] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + _use_arrow_native_complex_types: Optional[bool] = True, + **kwargs, + ) -> None: + """ + Create a session to a Databricks SQL endpoint or a Databricks cluster. + + This class handles all session-related behavior and communication with the backend. + """ + self.is_open = False + self.host = server_hostname + self.port = kwargs.get("_port", 443) + + self.session_configuration = session_configuration + self.catalog = catalog + self.schema = schema + + auth_provider = get_python_sql_connector_auth_provider( + server_hostname, **kwargs + ) + + user_agent_entry = kwargs.get("user_agent_entry") + if user_agent_entry is None: + user_agent_entry = kwargs.get("_user_agent_entry") + if user_agent_entry is not None: + logger.warning( + "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. " + "This parameter will be removed in the upcoming releases." + ) + + if user_agent_entry: + useragent_header = "{}/{} ({})".format( + USER_AGENT_NAME, __version__, user_agent_entry + ) + else: + useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) + + base_headers = [("User-Agent", useragent_header)] + + self._ssl_options = SSLOptions( + # Double negation is generally a bad thing, but we have to keep backward compatibility + tls_verify=not kwargs.get( + "_tls_no_verify", False + ), # by default - verify cert and host + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + + self.thrift_backend = ThriftBackend( + self.host, + self.port, + http_path, + (http_headers or []) + base_headers, + auth_provider, + ssl_options=self._ssl_options, + _use_arrow_native_complex_types=_use_arrow_native_complex_types, + **kwargs, + ) + + self._handle = None + self.protocol_version = None + + def open(self) -> None: + self._open_session_resp = self.thrift_backend.open_session( + self.session_configuration, self.catalog, self.schema + ) + self._handle = self._open_session_resp.sessionHandle + self.protocol_version = self.get_protocol_version(self._open_session_resp) + self.is_open = True + logger.info("Successfully opened session " + str(self.get_id_hex())) + + @staticmethod + def get_protocol_version(openSessionResp): + """ + Since the sessionHandle will sometimes have a serverProtocolVersion, it takes + precedence over the serverProtocolVersion defined in the OpenSessionResponse. + """ + if ( + openSessionResp.sessionHandle + and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") + and openSessionResp.sessionHandle.serverProtocolVersion + ): + return openSessionResp.sessionHandle.serverProtocolVersion + return openSessionResp.serverProtocolVersion + + @staticmethod + def server_parameterized_queries_enabled(protocolVersion): + if ( + protocolVersion + and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 + ): + return True + else: + return False + + def get_handle(self): + return self._handle + + def get_id(self): + handle = self.get_handle() + if handle is None: + return None + return self.thrift_backend.handle_to_id(handle) + + def get_id_hex(self): + handle = self.get_handle() + if handle is None: + return None + return self.thrift_backend.handle_to_hex_id(handle) + + def close(self) -> None: + """Close the underlying session.""" + logger.info(f"Closing session {self.get_id_hex()}") + if not self.is_open: + logger.debug("Session appears to have been closed already") + return + + try: + self.thrift_backend.close_session(self.get_handle()) + except RequestError as e: + if isinstance(e.args[1], SessionAlreadyClosedError): + logger.info("Session was closed by a prior request") + except DatabaseError as e: + if "Invalid SessionHandle" in str(e): + logger.warning( + f"Attempted to close session that was already closed: {e}" + ) + else: + logger.warning( + f"Attempt to close session raised an exception at the server: {e}" + ) + except Exception as e: + logger.error(f"Attempt to close session raised a local exception: {e}") + + self.is_open = False diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 440d4efb..abe0e22d 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -856,7 +856,9 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog): raise KeyboardInterrupt("Simulated interrupt") finally: if conn is not None: - assert not conn.open, "Connection should be closed after KeyboardInterrupt" + assert ( + not conn.open + ), "Connection should be closed after KeyboardInterrupt" def test_cursor_close_properly_closes_operation(self): """Test that Cursor.close() properly closes the active operation handle on the server.""" @@ -883,7 +885,9 @@ def test_cursor_close_properly_closes_operation(self): raise KeyboardInterrupt("Simulated interrupt") finally: if cursor is not None: - assert not cursor.open, "Cursor should be closed after KeyboardInterrupt" + assert ( + not cursor.open + ), "Cursor should be closed after KeyboardInterrupt" def test_nested_cursor_context_managers(self): """Test that nested cursor context managers properly close operations on the server.""" @@ -916,7 +920,7 @@ def test_cursor_error_handling(self): assert op_handle is not None # Manually close the operation to simulate server-side closure - conn.thrift_backend.close_command(op_handle) + conn.session.thrift_backend.close_command(op_handle) cursor.close() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 5271baa7..a9c7a43a 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -81,94 +81,7 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_close_uses_the_correct_session_id(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - connection.close() - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_auth_args(self, mock_client_class): - # Test that the following auth args work: - # token = foo, - # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True - connection_args = [ - { - "server_hostname": "foo", - "http_path": None, - "access_token": "tok", - }, - { - "server_hostname": "foo", - "http_path": None, - "_tls_client_cert_file": "something", - "_use_cert_as_auth": True, - "access_token": None, - }, - ] - - for args in connection_args: - connection = databricks.sql.connect(**args) - host, port, http_path, *_ = mock_client_class.call_args[0] - self.assertEqual(args["server_hostname"], host) - self.assertEqual(args["http_path"], http_path) - connection.close() - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_http_header_passthrough(self, mock_client_class): - http_headers = [("foo", "bar")] - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) - - call_args = mock_client_class.call_args[0][3] - self.assertIn(("foo", "bar"), call_args) - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_tls_arg_passthrough(self, mock_client_class): - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, - _tls_verify_hostname="hostname", - _tls_trusted_ca_file="trusted ca file", - _tls_client_cert_key_file="trusted client cert", - _tls_client_cert_key_password="key password", - ) - - kwargs = mock_client_class.call_args[1] - self.assertEqual(kwargs["_tls_verify_hostname"], "hostname") - self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file") - self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") - self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_useragent_header(self, mock_client_class): - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - - http_headers = mock_client_class.call_args[0][3] - user_agent_header = ( - "User-Agent", - "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), - ) - self.assertIn(user_agent_header, http_headers) - - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") - user_agent_header_with_entry = ( - "User-Agent", - "{}/{} ({})".format( - databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" - ), - ) - http_headers = mock_client_class.call_args[0][3] - self.assertIn(user_agent_header_with_entry, http_headers) - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_closing_connection_closes_commands(self, mock_result_set_class): # Test once with has_been_closed_server side, once without @@ -185,7 +98,7 @@ def test_closing_connection_closes_commands(self, mock_result_set_class): ) mock_result_set_class.return_value.close.assert_called_once_with() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) self.assertTrue(connection.open) @@ -195,7 +108,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection.cursor() self.assertIn("closed", str(cm.exception)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) def test_arraysize_buffer_size_passthrough( self, mock_cursor_class, mock_client_class @@ -215,7 +128,10 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self): thrift_backend=mock_backend, execute_response=Mock(), ) - mock_connection.open = False + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.open = False + type(mock_connection).session = PropertyMock(return_value=mock_session) result_set.close() @@ -227,7 +143,11 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response.has_been_closed_server_side = False mock_connection = Mock() mock_thrift_backend = Mock() - mock_connection.open = True + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.open = True + type(mock_connection).session = PropertyMock(return_value=mock_session) + result_set = client.ResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -285,37 +205,14 @@ def test_context_manager_closes_cursor(self): mock_close.assert_called_once_with() cursor = client.Cursor(Mock(), Mock()) - cursor.close = Mock() - try: - with self.assertRaises(KeyboardInterrupt): - with cursor: - raise KeyboardInterrupt("Simulated interrupt") - finally: - cursor.close.assert_called() + cursor.close = Mock() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_context_manager_closes_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: - pass - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - connection.close = Mock() try: with self.assertRaises(KeyboardInterrupt): - with connection: + with cursor: raise KeyboardInterrupt("Simulated interrupt") finally: - connection.close.assert_called() + cursor.close.assert_called() def dict_product(self, dicts): """ @@ -415,21 +312,6 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( self.assertTrue(logger_instance.warning.called) self.assertFalse(mock_thrift_backend.cancel_command.called) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_max_number_of_retries_passthrough(self, mock_client_class): - databricks.sql.connect( - _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 - ) - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_socket_timeout_passthrough(self, mock_client_class): - databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) - self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) - def test_version_is_canonical(self): version = databricks.sql.__version__ canonical_version_re = ( @@ -438,33 +320,6 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_configuration_passthrough(self, mock_client_class): - mock_session_config = Mock() - databricks.sql.connect( - session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][0], - mock_session_config, - ) - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_initial_namespace_passthrough(self, mock_client_class): - mock_cat = Mock() - mock_schem = Mock() - - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][1], mock_cat - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][2], mock_schem - ) - def test_execute_parameter_passthrough(self): mock_thrift_backend = ThriftBackendMockFactory.new() cursor = client.Cursor(Mock(), mock_thrift_backend) @@ -524,7 +379,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( "last operation", ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_commit_a_noop(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) c.commit() @@ -537,7 +392,7 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_rollback_not_supported(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) with self.assertRaises(NotSupportedError): @@ -622,24 +477,7 @@ def test_column_name_api(self): }, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_finalizer_closes_abandoned_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - - # not strictly necessary as the refcount is 0, but just to be sure - gc.collect() - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value @@ -658,7 +496,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_staging_operation_response_is_handled( self, mock_client_class, mock_handle_staging_operation, mock_execute_response ): @@ -677,7 +515,7 @@ def test_staging_operation_response_is_handled( mock_handle_staging_operation.call_count == 1 - @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) def test_access_current_query_id(self): operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821" @@ -700,7 +538,7 @@ def test_cursor_close_handles_exception(self): mock_backend = Mock() mock_connection = Mock() mock_op_handle = Mock() - + mock_backend.close_command.side_effect = Exception("Test error") cursor = client.Cursor(mock_connection, mock_backend) @@ -709,78 +547,80 @@ def test_cursor_close_handles_exception(self): cursor.close() mock_backend.close_command.assert_called_once_with(mock_op_handle) - + self.assertIsNone(cursor.active_op_handle) - + self.assertFalse(cursor.open) def test_cursor_context_manager_handles_exit_exception(self): """Test that cursor's context manager handles exceptions during __exit__.""" mock_backend = Mock() mock_connection = Mock() - + cursor = client.Cursor(mock_connection, mock_backend) original_close = cursor.close cursor.close = Mock(side_effect=Exception("Test error during close")) - + try: with cursor: raise ValueError("Test error inside context") except ValueError: pass - + cursor.close.assert_called_once() def test_connection_close_handles_cursor_close_exception(self): """Test that _close handles exceptions from cursor.close() properly.""" cursors_closed = [] - + def mock_close_with_exception(): cursors_closed.append(1) raise Exception("Test error during close") - + cursor1 = Mock() cursor1.close = mock_close_with_exception - + def mock_close_normal(): cursors_closed.append(2) - + cursor2 = Mock() cursor2.close = mock_close_normal - + mock_backend = Mock() mock_session_handle = Mock() - + try: for cursor in [cursor1, cursor2]: try: cursor.close() except Exception: pass - + mock_backend.close_session(mock_session_handle) except Exception as e: self.fail(f"Connection close should handle exceptions: {e}") - - self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called") + + self.assertEqual( + cursors_closed, [1, 2], "Both cursors should have close called" + ) def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" result_set = client.ResultSet.__new__(client.ResultSet) result_set.thrift_backend = Mock() - result_set.thrift_backend.CLOSED_OP_STATE = 'CLOSED' + result_set.thrift_backend.CLOSED_OP_STATE = "CLOSED" result_set.connection = Mock() result_set.connection.open = True - result_set.op_state = 'RUNNING' + result_set.op_state = "RUNNING" result_set.has_been_closed_server_side = False result_set.command_id = Mock() class MockRequestError(Exception): def __init__(self): self.args = ["Error message", CursorAlreadyClosedError()] - + result_set.thrift_backend.close_command.side_effect = MockRequestError() - + original_close = client.ResultSet.close try: try: @@ -796,11 +636,13 @@ def __init__(self): finally: result_set.has_been_closed_server_side = True result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE - - result_set.thrift_backend.close_command.assert_called_once_with(result_set.command_id) - + + result_set.thrift_backend.close_command.assert_called_once_with( + result_set.command_id + ) + assert result_set.has_been_closed_server_side is True - + assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE finally: pass diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py new file mode 100644 index 00000000..eb392a22 --- /dev/null +++ b/tests/unit/test_session.py @@ -0,0 +1,187 @@ +import unittest +from unittest.mock import patch, MagicMock, Mock, PropertyMock +import gc + +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, +) + +import databricks.sql + + +class SessionTestSuite(unittest.TestCase): + """ + Unit tests for Session functionality + """ + + PACKAGE_NAME = "databricks.sql" + DUMMY_CONNECTION_ARGS = { + "server_hostname": "foo", + "http_path": "dummy_path", + "access_token": "tok", + } + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_close_uses_the_correct_session_id(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection.close() + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_auth_args(self, mock_client_class): + # Test that the following auth args work: + # token = foo, + # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True + connection_args = [ + { + "server_hostname": "foo", + "http_path": None, + "access_token": "tok", + }, + { + "server_hostname": "foo", + "http_path": None, + "_tls_client_cert_file": "something", + "_use_cert_as_auth": True, + "access_token": None, + }, + ] + + for args in connection_args: + connection = databricks.sql.connect(**args) + host, port, http_path, *_ = mock_client_class.call_args[0] + self.assertEqual(args["server_hostname"], host) + self.assertEqual(args["http_path"], http_path) + connection.close() + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_http_header_passthrough(self, mock_client_class): + http_headers = [("foo", "bar")] + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) + + call_args = mock_client_class.call_args[0][3] + self.assertIn(("foo", "bar"), call_args) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_tls_arg_passthrough(self, mock_client_class): + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, + _tls_verify_hostname="hostname", + _tls_trusted_ca_file="trusted ca file", + _tls_client_cert_key_file="trusted client cert", + _tls_client_cert_key_password="key password", + ) + + kwargs = mock_client_class.call_args[1] + self.assertEqual(kwargs["_tls_verify_hostname"], "hostname") + self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file") + self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") + self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_useragent_header(self, mock_client_class): + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + + http_headers = mock_client_class.call_args[0][3] + user_agent_header = ( + "User-Agent", + "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), + ) + self.assertIn(user_agent_header, http_headers) + + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") + user_agent_header_with_entry = ( + "User-Agent", + "{}/{} ({})".format( + databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" + ), + ) + http_headers = mock_client_class.call_args[0][3] + self.assertIn(user_agent_header_with_entry, http_headers) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_context_manager_closes_connection(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: + pass + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_max_number_of_retries_passthrough(self, mock_client_class): + databricks.sql.connect( + _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS + ) + + self.assertEqual( + mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_socket_timeout_passthrough(self, mock_client_class): + databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) + self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_configuration_passthrough(self, mock_client_class): + mock_session_config = Mock() + databricks.sql.connect( + session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS + ) + + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][0], + mock_session_config, + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_initial_namespace_passthrough(self, mock_client_class): + mock_cat = Mock() + mock_schem = Mock() + + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][1], mock_cat + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][2], mock_schem + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_finalizer_closes_abandoned_connection(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + + # not strictly necessary as the refcount is 0, but just to be sure + gc.collect() + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 7fe31844..458ea9a8 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -86,7 +86,9 @@ def test_make_request_checks_thrift_status_code(self): def _make_type_desc(self, type): return ttypes.TTypeDesc( - types=[ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type))] + types=[ + ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type)) + ] ) def _make_fake_thrift_backend(self):