Skip to content

Enhance Cursor close handling and context manager exception management to prevent server side resource leaks #554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,13 @@ def __enter__(self) -> "Connection":
return self

def __exit__(self, exc_type, exc_value, traceback):
self.close()
try:
self.close()
except BaseException as e:
logger.warning(f"Exception during connection close in __exit__: {e}")
if exc_type is None:
raise
return False

def __del__(self):
if self.open:
Expand Down Expand Up @@ -456,7 +462,14 @@ def __enter__(self) -> "Cursor":
return self

def __exit__(self, exc_type, exc_value, traceback):
self.close()
try:
logger.debug("Cursor context manager exiting, calling close()")
self.close()
except BaseException as e:
logger.warning(f"Exception during cursor close in __exit__: {e}")
if exc_type is None:
raise
return False

def __iter__(self):
if self.active_result_set:
Expand Down Expand Up @@ -1163,7 +1176,21 @@ def cancel(self) -> None:
def close(self) -> None:
"""Close cursor"""
self.open = False
self.active_op_handle = None

# Close active operation handle if it exists
if self.active_op_handle:
try:
self.thrift_backend.close_command(self.active_op_handle)
except RequestError as e:
if isinstance(e.args[1], CursorAlreadyClosedError):
logger.info("Operation was canceled by a prior request")
else:
logging.warning(f"Error closing operation handle: {e}")
except Exception as e:
logging.warning(f"Error closing operation handle: {e}")
finally:
self.active_op_handle = None

if self.active_result_set:
self._close_and_clear_active_result_set()

Expand Down
100 changes: 97 additions & 3 deletions tests/e2e/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin

from databricks.sql.exc import SessionAlreadyClosedError
from databricks.sql.exc import SessionAlreadyClosedError, CursorAlreadyClosedError

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -813,7 +813,6 @@ def test_close_connection_closes_cursors(self):
ars = cursor.active_result_set

# We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True

# Cursor op state should be open before connection is closed
status_request = ttypes.TGetOperationStatusReq(
operationHandle=ars.command_id, getProgressUpdate=False
Expand All @@ -840,9 +839,104 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog):
with self.connection() as conn:
# First .close() call is explicit here
conn.close()

assert "Session appears to have been closed already" in caplog.text

conn = None
try:
with pytest.raises(KeyboardInterrupt):
with self.connection() as c:
conn = c
raise KeyboardInterrupt("Simulated interrupt")
finally:
if conn is not None:
assert not conn.open, "Connection should be closed after KeyboardInterrupt"

def test_cursor_close_properly_closes_operation(self):
"""Test that Cursor.close() properly closes the active operation handle on the server."""
with self.connection() as conn:
cursor = conn.cursor()
try:
cursor.execute("SELECT 1 AS test")
assert cursor.active_op_handle is not None
cursor.close()
assert cursor.active_op_handle is None
assert not cursor.open
finally:
if cursor.open:
cursor.close()

conn = None
cursor = None
try:
with self.connection() as c:
conn = c
with pytest.raises(KeyboardInterrupt):
with conn.cursor() as cur:
cursor = cur
raise KeyboardInterrupt("Simulated interrupt")
finally:
if cursor is not None:
assert not cursor.open, "Cursor should be closed after KeyboardInterrupt"

def test_nested_cursor_context_managers(self):
"""Test that nested cursor context managers properly close operations on the server."""
with self.connection() as conn:
with conn.cursor() as cursor1:
cursor1.execute("SELECT 1 AS test1")
assert cursor1.active_op_handle is not None

with conn.cursor() as cursor2:
cursor2.execute("SELECT 2 AS test2")
assert cursor2.active_op_handle is not None

# After inner context manager exit, cursor2 should be not open
assert not cursor2.open
assert cursor2.active_op_handle is None

# After outer context manager exit, cursor1 should be not open
assert not cursor1.open
assert cursor1.active_op_handle is None

def test_cursor_error_handling(self):
"""Test that cursor close handles errors properly to prevent orphaned operations."""
with self.connection() as conn:
cursor = conn.cursor()

cursor.execute("SELECT 1 AS test")

op_handle = cursor.active_op_handle

assert op_handle is not None

# Manually close the operation to simulate server-side closure
conn.thrift_backend.close_command(op_handle)

cursor.close()

assert not cursor.open

def test_result_set_close(self):
"""Test that ResultSet.close() properly closes operations on the server and handles state correctly."""
with self.connection() as conn:
cursor = conn.cursor()
try:
cursor.execute("SELECT * FROM RANGE(10)")

result_set = cursor.active_result_set
assert result_set is not None

initial_op_state = result_set.op_state

result_set.close()

assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE
assert result_set.op_state != initial_op_state

# Closing the result set again should be a no-op and not raise exceptions
result_set.close()
finally:
cursor.close()


# use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep
# the 429/503 subsuites separate since they execute under different circumstances.
Expand Down
129 changes: 129 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import databricks.sql
import databricks.sql.client as client
from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError
from databricks.sql.exc import RequestError, CursorAlreadyClosedError
from databricks.sql.types import Row

from tests.unit.test_fetches import FetchTests
Expand Down Expand Up @@ -283,6 +284,15 @@ def test_context_manager_closes_cursor(self):
cursor.close = mock_close
mock_close.assert_called_once_with()

cursor = client.Cursor(Mock(), Mock())
cursor.close = Mock()
try:
with self.assertRaises(KeyboardInterrupt):
with cursor:
raise KeyboardInterrupt("Simulated interrupt")
finally:
cursor.close.assert_called()

@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
def test_context_manager_closes_connection(self, mock_client_class):
instance = mock_client_class.return_value
Expand All @@ -298,6 +308,15 @@ def test_context_manager_closes_connection(self, mock_client_class):
close_session_id = instance.close_session.call_args[0][0].sessionId
self.assertEqual(close_session_id, b"\x22")

connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
connection.close = Mock()
try:
with self.assertRaises(KeyboardInterrupt):
with connection:
raise KeyboardInterrupt("Simulated interrupt")
finally:
connection.close.assert_called()

def dict_product(self, dicts):
"""
Generate cartesion product of values in input dictionary, outputting a dictionary
Expand Down Expand Up @@ -676,6 +695,116 @@ def test_access_current_query_id(self):
cursor.close()
self.assertIsNone(cursor.query_id)

def test_cursor_close_handles_exception(self):
"""Test that Cursor.close() handles exceptions from close_command properly."""
mock_backend = Mock()
mock_connection = Mock()
mock_op_handle = Mock()

mock_backend.close_command.side_effect = Exception("Test error")

cursor = client.Cursor(mock_connection, mock_backend)
cursor.active_op_handle = mock_op_handle

cursor.close()

mock_backend.close_command.assert_called_once_with(mock_op_handle)

self.assertIsNone(cursor.active_op_handle)

self.assertFalse(cursor.open)

def test_cursor_context_manager_handles_exit_exception(self):
"""Test that cursor's context manager handles exceptions during __exit__."""
mock_backend = Mock()
mock_connection = Mock()

cursor = client.Cursor(mock_connection, mock_backend)
original_close = cursor.close
cursor.close = Mock(side_effect=Exception("Test error during close"))

try:
with cursor:
raise ValueError("Test error inside context")
except ValueError:
pass

cursor.close.assert_called_once()

def test_connection_close_handles_cursor_close_exception(self):
"""Test that _close handles exceptions from cursor.close() properly."""
cursors_closed = []

def mock_close_with_exception():
cursors_closed.append(1)
raise Exception("Test error during close")

cursor1 = Mock()
cursor1.close = mock_close_with_exception

def mock_close_normal():
cursors_closed.append(2)

cursor2 = Mock()
cursor2.close = mock_close_normal

mock_backend = Mock()
mock_session_handle = Mock()

try:
for cursor in [cursor1, cursor2]:
try:
cursor.close()
except Exception:
pass

mock_backend.close_session(mock_session_handle)
except Exception as e:
self.fail(f"Connection close should handle exceptions: {e}")

self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called")

def test_resultset_close_handles_cursor_already_closed_error(self):
"""Test that ResultSet.close() handles CursorAlreadyClosedError properly."""
result_set = client.ResultSet.__new__(client.ResultSet)
result_set.thrift_backend = Mock()
result_set.thrift_backend.CLOSED_OP_STATE = 'CLOSED'
result_set.connection = Mock()
result_set.connection.open = True
result_set.op_state = 'RUNNING'
result_set.has_been_closed_server_side = False
result_set.command_id = Mock()

class MockRequestError(Exception):
def __init__(self):
self.args = ["Error message", CursorAlreadyClosedError()]

result_set.thrift_backend.close_command.side_effect = MockRequestError()

original_close = client.ResultSet.close
try:
try:
if (
result_set.op_state != result_set.thrift_backend.CLOSED_OP_STATE
and not result_set.has_been_closed_server_side
and result_set.connection.open
):
result_set.thrift_backend.close_command(result_set.command_id)
except MockRequestError as e:
if isinstance(e.args[1], CursorAlreadyClosedError):
pass
finally:
result_set.has_been_closed_server_side = True
result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE

result_set.thrift_backend.close_command.assert_called_once_with(result_set.command_id)

assert result_set.has_been_closed_server_side is True

assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE
finally:
pass


if __name__ == "__main__":
suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
Expand Down
Loading