Skip to content

Commit edfb283

Browse files
authored
Enhance Cursor close handling and context manager exception management to prevent server side resource leaks (#554)
* Enhance Cursor close handling and context manager exception management * tests * fmt * Fix Cursor.close() to properly handle CursorAlreadyClosedError * Remove specific test message from Cursor.close() error handling * Improve error handling in connection and cursor context managers to ensure proper closure during exceptions, including KeyboardInterrupt. Add tests for nested cursor management and verify operation closure on server-side errors. * add * add
1 parent f9936d7 commit edfb283

File tree

3 files changed

+256
-6
lines changed

3 files changed

+256
-6
lines changed

src/databricks/sql/client.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,13 @@ def __enter__(self) -> "Connection":
315315
return self
316316

317317
def __exit__(self, exc_type, exc_value, traceback):
318-
self.close()
318+
try:
319+
self.close()
320+
except BaseException as e:
321+
logger.warning(f"Exception during connection close in __exit__: {e}")
322+
if exc_type is None:
323+
raise
324+
return False
319325

320326
def __del__(self):
321327
if self.open:
@@ -456,7 +462,14 @@ def __enter__(self) -> "Cursor":
456462
return self
457463

458464
def __exit__(self, exc_type, exc_value, traceback):
459-
self.close()
465+
try:
466+
logger.debug("Cursor context manager exiting, calling close()")
467+
self.close()
468+
except BaseException as e:
469+
logger.warning(f"Exception during cursor close in __exit__: {e}")
470+
if exc_type is None:
471+
raise
472+
return False
460473

461474
def __iter__(self):
462475
if self.active_result_set:
@@ -1163,7 +1176,21 @@ def cancel(self) -> None:
11631176
def close(self) -> None:
11641177
"""Close cursor"""
11651178
self.open = False
1166-
self.active_op_handle = None
1179+
1180+
# Close active operation handle if it exists
1181+
if self.active_op_handle:
1182+
try:
1183+
self.thrift_backend.close_command(self.active_op_handle)
1184+
except RequestError as e:
1185+
if isinstance(e.args[1], CursorAlreadyClosedError):
1186+
logger.info("Operation was canceled by a prior request")
1187+
else:
1188+
logging.warning(f"Error closing operation handle: {e}")
1189+
except Exception as e:
1190+
logging.warning(f"Error closing operation handle: {e}")
1191+
finally:
1192+
self.active_op_handle = None
1193+
11671194
if self.active_result_set:
11681195
self._close_and_clear_active_result_set()
11691196

tests/e2e/test_driver.py

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151
from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin
5252

53-
from databricks.sql.exc import SessionAlreadyClosedError
53+
from databricks.sql.exc import SessionAlreadyClosedError, CursorAlreadyClosedError
5454

5555
log = logging.getLogger(__name__)
5656

@@ -820,7 +820,6 @@ def test_close_connection_closes_cursors(self):
820820
ars = cursor.active_result_set
821821

822822
# We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True
823-
824823
# Cursor op state should be open before connection is closed
825824
status_request = ttypes.TGetOperationStatusReq(
826825
operationHandle=ars.command_id, getProgressUpdate=False
@@ -847,9 +846,104 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog):
847846
with self.connection() as conn:
848847
# First .close() call is explicit here
849848
conn.close()
850-
851849
assert "Session appears to have been closed already" in caplog.text
852850

851+
conn = None
852+
try:
853+
with pytest.raises(KeyboardInterrupt):
854+
with self.connection() as c:
855+
conn = c
856+
raise KeyboardInterrupt("Simulated interrupt")
857+
finally:
858+
if conn is not None:
859+
assert not conn.open, "Connection should be closed after KeyboardInterrupt"
860+
861+
def test_cursor_close_properly_closes_operation(self):
862+
"""Test that Cursor.close() properly closes the active operation handle on the server."""
863+
with self.connection() as conn:
864+
cursor = conn.cursor()
865+
try:
866+
cursor.execute("SELECT 1 AS test")
867+
assert cursor.active_op_handle is not None
868+
cursor.close()
869+
assert cursor.active_op_handle is None
870+
assert not cursor.open
871+
finally:
872+
if cursor.open:
873+
cursor.close()
874+
875+
conn = None
876+
cursor = None
877+
try:
878+
with self.connection() as c:
879+
conn = c
880+
with pytest.raises(KeyboardInterrupt):
881+
with conn.cursor() as cur:
882+
cursor = cur
883+
raise KeyboardInterrupt("Simulated interrupt")
884+
finally:
885+
if cursor is not None:
886+
assert not cursor.open, "Cursor should be closed after KeyboardInterrupt"
887+
888+
def test_nested_cursor_context_managers(self):
889+
"""Test that nested cursor context managers properly close operations on the server."""
890+
with self.connection() as conn:
891+
with conn.cursor() as cursor1:
892+
cursor1.execute("SELECT 1 AS test1")
893+
assert cursor1.active_op_handle is not None
894+
895+
with conn.cursor() as cursor2:
896+
cursor2.execute("SELECT 2 AS test2")
897+
assert cursor2.active_op_handle is not None
898+
899+
# After inner context manager exit, cursor2 should be not open
900+
assert not cursor2.open
901+
assert cursor2.active_op_handle is None
902+
903+
# After outer context manager exit, cursor1 should be not open
904+
assert not cursor1.open
905+
assert cursor1.active_op_handle is None
906+
907+
def test_cursor_error_handling(self):
908+
"""Test that cursor close handles errors properly to prevent orphaned operations."""
909+
with self.connection() as conn:
910+
cursor = conn.cursor()
911+
912+
cursor.execute("SELECT 1 AS test")
913+
914+
op_handle = cursor.active_op_handle
915+
916+
assert op_handle is not None
917+
918+
# Manually close the operation to simulate server-side closure
919+
conn.thrift_backend.close_command(op_handle)
920+
921+
cursor.close()
922+
923+
assert not cursor.open
924+
925+
def test_result_set_close(self):
926+
"""Test that ResultSet.close() properly closes operations on the server and handles state correctly."""
927+
with self.connection() as conn:
928+
cursor = conn.cursor()
929+
try:
930+
cursor.execute("SELECT * FROM RANGE(10)")
931+
932+
result_set = cursor.active_result_set
933+
assert result_set is not None
934+
935+
initial_op_state = result_set.op_state
936+
937+
result_set.close()
938+
939+
assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE
940+
assert result_set.op_state != initial_op_state
941+
942+
# Closing the result set again should be a no-op and not raise exceptions
943+
result_set.close()
944+
finally:
945+
cursor.close()
946+
853947

854948
# use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep
855949
# the 429/503 subsuites separate since they execute under different circumstances.

tests/unit/test_client.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import databricks.sql
2121
import databricks.sql.client as client
2222
from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError
23+
from databricks.sql.exc import RequestError, CursorAlreadyClosedError
2324
from databricks.sql.types import Row
2425

2526
from tests.unit.test_fetches import FetchTests
@@ -283,6 +284,15 @@ def test_context_manager_closes_cursor(self):
283284
cursor.close = mock_close
284285
mock_close.assert_called_once_with()
285286

287+
cursor = client.Cursor(Mock(), Mock())
288+
cursor.close = Mock()
289+
try:
290+
with self.assertRaises(KeyboardInterrupt):
291+
with cursor:
292+
raise KeyboardInterrupt("Simulated interrupt")
293+
finally:
294+
cursor.close.assert_called()
295+
286296
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
287297
def test_context_manager_closes_connection(self, mock_client_class):
288298
instance = mock_client_class.return_value
@@ -298,6 +308,15 @@ def test_context_manager_closes_connection(self, mock_client_class):
298308
close_session_id = instance.close_session.call_args[0][0].sessionId
299309
self.assertEqual(close_session_id, b"\x22")
300310

311+
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
312+
connection.close = Mock()
313+
try:
314+
with self.assertRaises(KeyboardInterrupt):
315+
with connection:
316+
raise KeyboardInterrupt("Simulated interrupt")
317+
finally:
318+
connection.close.assert_called()
319+
301320
def dict_product(self, dicts):
302321
"""
303322
Generate cartesion product of values in input dictionary, outputting a dictionary
@@ -676,6 +695,116 @@ def test_access_current_query_id(self):
676695
cursor.close()
677696
self.assertIsNone(cursor.query_id)
678697

698+
def test_cursor_close_handles_exception(self):
699+
"""Test that Cursor.close() handles exceptions from close_command properly."""
700+
mock_backend = Mock()
701+
mock_connection = Mock()
702+
mock_op_handle = Mock()
703+
704+
mock_backend.close_command.side_effect = Exception("Test error")
705+
706+
cursor = client.Cursor(mock_connection, mock_backend)
707+
cursor.active_op_handle = mock_op_handle
708+
709+
cursor.close()
710+
711+
mock_backend.close_command.assert_called_once_with(mock_op_handle)
712+
713+
self.assertIsNone(cursor.active_op_handle)
714+
715+
self.assertFalse(cursor.open)
716+
717+
def test_cursor_context_manager_handles_exit_exception(self):
718+
"""Test that cursor's context manager handles exceptions during __exit__."""
719+
mock_backend = Mock()
720+
mock_connection = Mock()
721+
722+
cursor = client.Cursor(mock_connection, mock_backend)
723+
original_close = cursor.close
724+
cursor.close = Mock(side_effect=Exception("Test error during close"))
725+
726+
try:
727+
with cursor:
728+
raise ValueError("Test error inside context")
729+
except ValueError:
730+
pass
731+
732+
cursor.close.assert_called_once()
733+
734+
def test_connection_close_handles_cursor_close_exception(self):
735+
"""Test that _close handles exceptions from cursor.close() properly."""
736+
cursors_closed = []
737+
738+
def mock_close_with_exception():
739+
cursors_closed.append(1)
740+
raise Exception("Test error during close")
741+
742+
cursor1 = Mock()
743+
cursor1.close = mock_close_with_exception
744+
745+
def mock_close_normal():
746+
cursors_closed.append(2)
747+
748+
cursor2 = Mock()
749+
cursor2.close = mock_close_normal
750+
751+
mock_backend = Mock()
752+
mock_session_handle = Mock()
753+
754+
try:
755+
for cursor in [cursor1, cursor2]:
756+
try:
757+
cursor.close()
758+
except Exception:
759+
pass
760+
761+
mock_backend.close_session(mock_session_handle)
762+
except Exception as e:
763+
self.fail(f"Connection close should handle exceptions: {e}")
764+
765+
self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called")
766+
767+
def test_resultset_close_handles_cursor_already_closed_error(self):
768+
"""Test that ResultSet.close() handles CursorAlreadyClosedError properly."""
769+
result_set = client.ResultSet.__new__(client.ResultSet)
770+
result_set.thrift_backend = Mock()
771+
result_set.thrift_backend.CLOSED_OP_STATE = 'CLOSED'
772+
result_set.connection = Mock()
773+
result_set.connection.open = True
774+
result_set.op_state = 'RUNNING'
775+
result_set.has_been_closed_server_side = False
776+
result_set.command_id = Mock()
777+
778+
class MockRequestError(Exception):
779+
def __init__(self):
780+
self.args = ["Error message", CursorAlreadyClosedError()]
781+
782+
result_set.thrift_backend.close_command.side_effect = MockRequestError()
783+
784+
original_close = client.ResultSet.close
785+
try:
786+
try:
787+
if (
788+
result_set.op_state != result_set.thrift_backend.CLOSED_OP_STATE
789+
and not result_set.has_been_closed_server_side
790+
and result_set.connection.open
791+
):
792+
result_set.thrift_backend.close_command(result_set.command_id)
793+
except MockRequestError as e:
794+
if isinstance(e.args[1], CursorAlreadyClosedError):
795+
pass
796+
finally:
797+
result_set.has_been_closed_server_side = True
798+
result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE
799+
800+
result_set.thrift_backend.close_command.assert_called_once_with(result_set.command_id)
801+
802+
assert result_set.has_been_closed_server_side is True
803+
804+
assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE
805+
finally:
806+
pass
807+
679808

680809
if __name__ == "__main__":
681810
suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])

0 commit comments

Comments
 (0)