Skip to content

Commit aa7207e

Browse files
more robust close check
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 1f0c81f commit aa7207e

File tree

1 file changed

+83
-17
lines changed

1 file changed

+83
-17
lines changed

tests/unit/test_client.py

Lines changed: 83 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from databricks.sql.exc import RequestError, CursorAlreadyClosedError
2424
from databricks.sql.types import Row
2525
from databricks.sql.result_set import ResultSet, ThriftResultSet
26-
from databricks.sql.backend.types import CommandId
26+
from databricks.sql.backend.types import CommandId, CommandState
27+
from databricks.sql.utils import ExecuteResponse
2728

2829
from tests.unit.test_fetches import FetchTests
2930
from tests.unit.test_thrift_backend import ThriftBackendTestSuite
@@ -83,30 +84,95 @@ class ClientTestSuite(unittest.TestCase):
8384
}
8485

8586
@patch(
86-
"%s.session.ThriftDatabricksClient" % PACKAGE_NAME,
87-
ThriftDatabricksClientMockFactory.new(),
87+
"%s.backend.thrift_backend.ThriftDatabricksClient.close_command" % PACKAGE_NAME
8888
)
89-
def test_closing_connection_closes_commands(self):
89+
def test_closing_connection_closes_commands(self, mock_close_command):
90+
"""Test that connection.close() properly closes result sets through the real close chain."""
9091
# Test once with has_been_closed_server side, once without
9192
for closed in (True, False):
9293
with self.subTest(closed=closed):
93-
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
94-
cursor = connection.cursor()
94+
mock_close_command.reset_mock() # Reset for each subtest
95+
96+
# Create a real ThriftResultSet with mocked dependencies
97+
from databricks.sql.utils import ExecuteResponse
98+
from databricks.sql.backend.types import CommandId, CommandState
99+
from databricks.sql.result_set import ThriftResultSet
95100

96-
# Create a mock result set and set it as the active result set
97-
mock_result_set = Mock()
98-
mock_result_set.has_been_closed_server_side = closed
99-
cursor.active_result_set = mock_result_set
101+
# Mock the execute response with controlled state
102+
mock_execute_response = Mock(spec=ExecuteResponse)
103+
mock_execute_response.command_id = Mock(spec=CommandId)
100104

101-
# Close the connection
102-
connection.close()
105+
# Use actual Thrift operation states, not CommandState enums
106+
from databricks.sql.thrift_api.TCLIService import ttypes
103107

104-
# Check that the manually created mock result set's close method was called
105-
self.assertEqual(
106-
mock_result_set.has_been_closed_server_side,
107-
closed,
108+
mock_execute_response.status = (
109+
ttypes.TOperationState.FINISHED_STATE
110+
if not closed
111+
else ttypes.TOperationState.CLOSED_STATE
108112
)
109-
mock_result_set.close.assert_called_once_with()
113+
mock_execute_response.has_been_closed_server_side = closed
114+
mock_execute_response.is_staging_operation = False
115+
116+
# Mock the backend that will be used by the real ThriftResultSet
117+
mock_backend = Mock(spec=ThriftDatabricksClient)
118+
mock_backend.staging_allowed_local_path = None
119+
120+
# Create connection and cursor
121+
with patch(
122+
f"{self.PACKAGE_NAME}.session.ThriftDatabricksClient",
123+
return_value=mock_backend,
124+
):
125+
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
126+
cursor = connection.cursor()
127+
128+
# Create a REAL ThriftResultSet that will be returned by execute_command
129+
real_result_set = ThriftResultSet(
130+
connection=connection,
131+
execute_response=mock_execute_response,
132+
thrift_client=mock_backend,
133+
)
134+
135+
# Verify initial state
136+
self.assertEqual(
137+
real_result_set.has_been_closed_server_side, closed
138+
)
139+
expected_op_state = (
140+
CommandState.CLOSED if closed else CommandState.SUCCEEDED
141+
)
142+
self.assertEqual(real_result_set.op_state, expected_op_state)
143+
144+
# Mock execute_command to return our real result set
145+
cursor.backend.execute_command = Mock(return_value=real_result_set)
146+
147+
# Execute a command - this should set cursor.active_result_set to our real result set
148+
cursor.execute("SELECT 1")
149+
150+
# Verify that cursor.execute() set up the result set correctly
151+
self.assertIsInstance(cursor.active_result_set, ThriftResultSet)
152+
self.assertEqual(
153+
cursor.active_result_set.has_been_closed_server_side, closed
154+
)
155+
156+
# Close the connection - this should trigger the real close chain:
157+
# connection.close() -> cursor.close() -> result_set.close()
158+
connection.close()
159+
160+
# Verify the REAL close logic worked through the chain:
161+
# 1. has_been_closed_server_side should always be True after close()
162+
self.assertTrue(real_result_set.has_been_closed_server_side)
163+
164+
# 2. op_state should always be CLOSED after close()
165+
self.assertEqual(real_result_set.op_state, CommandState.CLOSED)
166+
167+
# 3. Backend close_command should be called appropriately
168+
if not closed:
169+
# Should have called backend.close_command during the close chain
170+
mock_backend.close_command.assert_called_once_with(
171+
mock_execute_response.command_id
172+
)
173+
else:
174+
# Should NOT have called backend.close_command (already closed)
175+
mock_backend.close_command.assert_not_called()
110176

111177
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
112178
def test_cant_open_cursor_on_closed_connection(self, mock_client_class):

0 commit comments

Comments
 (0)