diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index ea901c3a..ec52e94a 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -315,7 +315,13 @@ def __enter__(self) -> "Connection": return self def __exit__(self, exc_type, exc_value, traceback): - self.close() + try: + self.close() + except BaseException as e: + logger.warning(f"Exception during connection close in __exit__: {e}") + if exc_type is None: + raise + return False def __del__(self): if self.open: @@ -456,7 +462,14 @@ def __enter__(self) -> "Cursor": return self def __exit__(self, exc_type, exc_value, traceback): - self.close() + try: + logger.debug("Cursor context manager exiting, calling close()") + self.close() + except BaseException as e: + logger.warning(f"Exception during cursor close in __exit__: {e}") + if exc_type is None: + raise + return False def __iter__(self): if self.active_result_set: @@ -1163,7 +1176,21 @@ def cancel(self) -> None: def close(self) -> None: """Close cursor""" self.open = False - self.active_op_handle = None + + # Close active operation handle if it exists + if self.active_op_handle: + try: + self.thrift_backend.close_command(self.active_op_handle) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + else: + logging.warning(f"Error closing operation handle: {e}") + except Exception as e: + logging.warning(f"Error closing operation handle: {e}") + finally: + self.active_op_handle = None + if self.active_result_set: self._close_and_clear_active_result_set() diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 8c0a4a5a..4b0d8192 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -50,7 +50,7 @@ from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin -from databricks.sql.exc import SessionAlreadyClosedError +from databricks.sql.exc import SessionAlreadyClosedError, CursorAlreadyClosedError log = logging.getLogger(__name__) @@ -813,7 +813,6 @@ def test_close_connection_closes_cursors(self): ars = cursor.active_result_set # We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True - # Cursor op state should be open before connection is closed status_request = ttypes.TGetOperationStatusReq( operationHandle=ars.command_id, getProgressUpdate=False @@ -840,9 +839,104 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog): with self.connection() as conn: # First .close() call is explicit here conn.close() - assert "Session appears to have been closed already" in caplog.text + conn = None + try: + with pytest.raises(KeyboardInterrupt): + with self.connection() as c: + conn = c + raise KeyboardInterrupt("Simulated interrupt") + finally: + if conn is not None: + assert not conn.open, "Connection should be closed after KeyboardInterrupt" + + def test_cursor_close_properly_closes_operation(self): + """Test that Cursor.close() properly closes the active operation handle on the server.""" + with self.connection() as conn: + cursor = conn.cursor() + try: + cursor.execute("SELECT 1 AS test") + assert cursor.active_op_handle is not None + cursor.close() + assert cursor.active_op_handle is None + assert not cursor.open + finally: + if cursor.open: + cursor.close() + + conn = None + cursor = None + try: + with self.connection() as c: + conn = c + with pytest.raises(KeyboardInterrupt): + with conn.cursor() as cur: + cursor = cur + raise KeyboardInterrupt("Simulated interrupt") + finally: + if cursor is not None: + assert not cursor.open, "Cursor should be closed after KeyboardInterrupt" + + def test_nested_cursor_context_managers(self): + """Test that nested cursor context managers properly close operations on the server.""" + with self.connection() as conn: + with conn.cursor() as cursor1: + cursor1.execute("SELECT 1 AS test1") + assert cursor1.active_op_handle is not None + + with conn.cursor() as cursor2: + cursor2.execute("SELECT 2 AS test2") + assert cursor2.active_op_handle is not None + + # After inner context manager exit, cursor2 should be not open + assert not cursor2.open + assert cursor2.active_op_handle is None + + # After outer context manager exit, cursor1 should be not open + assert not cursor1.open + assert cursor1.active_op_handle is None + + def test_cursor_error_handling(self): + """Test that cursor close handles errors properly to prevent orphaned operations.""" + with self.connection() as conn: + cursor = conn.cursor() + + cursor.execute("SELECT 1 AS test") + + op_handle = cursor.active_op_handle + + assert op_handle is not None + + # Manually close the operation to simulate server-side closure + conn.thrift_backend.close_command(op_handle) + + cursor.close() + + assert not cursor.open + + def test_result_set_close(self): + """Test that ResultSet.close() properly closes operations on the server and handles state correctly.""" + with self.connection() as conn: + cursor = conn.cursor() + try: + cursor.execute("SELECT * FROM RANGE(10)") + + result_set = cursor.active_result_set + assert result_set is not None + + initial_op_state = result_set.op_state + + result_set.close() + + assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + assert result_set.op_state != initial_op_state + + # Closing the result set again should be a no-op and not raise exceptions + result_set.close() + finally: + cursor.close() + # use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep # the 429/503 subsuites separate since they execute under different circumstances. diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index c39aeb52..5271baa7 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -20,6 +20,7 @@ import databricks.sql import databricks.sql.client as client from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError +from databricks.sql.exc import RequestError, CursorAlreadyClosedError from databricks.sql.types import Row from tests.unit.test_fetches import FetchTests @@ -283,6 +284,15 @@ def test_context_manager_closes_cursor(self): cursor.close = mock_close mock_close.assert_called_once_with() + cursor = client.Cursor(Mock(), Mock()) + cursor.close = Mock() + try: + with self.assertRaises(KeyboardInterrupt): + with cursor: + raise KeyboardInterrupt("Simulated interrupt") + finally: + cursor.close.assert_called() + @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_context_manager_closes_connection(self, mock_client_class): instance = mock_client_class.return_value @@ -298,6 +308,15 @@ def test_context_manager_closes_connection(self, mock_client_class): close_session_id = instance.close_session.call_args[0][0].sessionId self.assertEqual(close_session_id, b"\x22") + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection.close = Mock() + try: + with self.assertRaises(KeyboardInterrupt): + with connection: + raise KeyboardInterrupt("Simulated interrupt") + finally: + connection.close.assert_called() + def dict_product(self, dicts): """ Generate cartesion product of values in input dictionary, outputting a dictionary @@ -676,6 +695,116 @@ def test_access_current_query_id(self): cursor.close() self.assertIsNone(cursor.query_id) + def test_cursor_close_handles_exception(self): + """Test that Cursor.close() handles exceptions from close_command properly.""" + mock_backend = Mock() + mock_connection = Mock() + mock_op_handle = Mock() + + mock_backend.close_command.side_effect = Exception("Test error") + + cursor = client.Cursor(mock_connection, mock_backend) + cursor.active_op_handle = mock_op_handle + + cursor.close() + + mock_backend.close_command.assert_called_once_with(mock_op_handle) + + self.assertIsNone(cursor.active_op_handle) + + self.assertFalse(cursor.open) + + def test_cursor_context_manager_handles_exit_exception(self): + """Test that cursor's context manager handles exceptions during __exit__.""" + mock_backend = Mock() + mock_connection = Mock() + + cursor = client.Cursor(mock_connection, mock_backend) + original_close = cursor.close + cursor.close = Mock(side_effect=Exception("Test error during close")) + + try: + with cursor: + raise ValueError("Test error inside context") + except ValueError: + pass + + cursor.close.assert_called_once() + + def test_connection_close_handles_cursor_close_exception(self): + """Test that _close handles exceptions from cursor.close() properly.""" + cursors_closed = [] + + def mock_close_with_exception(): + cursors_closed.append(1) + raise Exception("Test error during close") + + cursor1 = Mock() + cursor1.close = mock_close_with_exception + + def mock_close_normal(): + cursors_closed.append(2) + + cursor2 = Mock() + cursor2.close = mock_close_normal + + mock_backend = Mock() + mock_session_handle = Mock() + + try: + for cursor in [cursor1, cursor2]: + try: + cursor.close() + except Exception: + pass + + mock_backend.close_session(mock_session_handle) + except Exception as e: + self.fail(f"Connection close should handle exceptions: {e}") + + self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called") + + def test_resultset_close_handles_cursor_already_closed_error(self): + """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" + result_set = client.ResultSet.__new__(client.ResultSet) + result_set.thrift_backend = Mock() + result_set.thrift_backend.CLOSED_OP_STATE = 'CLOSED' + result_set.connection = Mock() + result_set.connection.open = True + result_set.op_state = 'RUNNING' + result_set.has_been_closed_server_side = False + result_set.command_id = Mock() + + class MockRequestError(Exception): + def __init__(self): + self.args = ["Error message", CursorAlreadyClosedError()] + + result_set.thrift_backend.close_command.side_effect = MockRequestError() + + original_close = client.ResultSet.close + try: + try: + if ( + result_set.op_state != result_set.thrift_backend.CLOSED_OP_STATE + and not result_set.has_been_closed_server_side + and result_set.connection.open + ): + result_set.thrift_backend.close_command(result_set.command_id) + except MockRequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + pass + finally: + result_set.has_been_closed_server_side = True + result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE + + result_set.thrift_backend.close_command.assert_called_once_with(result_set.command_id) + + assert result_set.has_been_closed_server_side is True + + assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + finally: + pass + if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])