|
23 | 23 | from databricks.sql.exc import RequestError, CursorAlreadyClosedError
|
24 | 24 | from databricks.sql.types import Row
|
25 | 25 | from databricks.sql.result_set import ResultSet, ThriftResultSet
|
26 |
| -from databricks.sql.backend.types import CommandId |
| 26 | +from databricks.sql.backend.types import CommandId, CommandState |
| 27 | +from databricks.sql.utils import ExecuteResponse |
27 | 28 |
|
28 | 29 | from tests.unit.test_fetches import FetchTests
|
29 | 30 | from tests.unit.test_thrift_backend import ThriftBackendTestSuite
|
@@ -83,30 +84,95 @@ class ClientTestSuite(unittest.TestCase):
|
83 | 84 | }
|
84 | 85 |
|
85 | 86 | @patch(
|
86 |
| - "%s.session.ThriftDatabricksClient" % PACKAGE_NAME, |
87 |
| - ThriftDatabricksClientMockFactory.new(), |
| 87 | + "%s.backend.thrift_backend.ThriftDatabricksClient.close_command" % PACKAGE_NAME |
88 | 88 | )
|
89 |
| - def test_closing_connection_closes_commands(self): |
| 89 | + def test_closing_connection_closes_commands(self, mock_close_command): |
| 90 | + """Test that connection.close() properly closes result sets through the real close chain.""" |
90 | 91 | # Test once with has_been_closed_server side, once without
|
91 | 92 | for closed in (True, False):
|
92 | 93 | with self.subTest(closed=closed):
|
93 |
| - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) |
94 |
| - cursor = connection.cursor() |
| 94 | + mock_close_command.reset_mock() # Reset for each subtest |
| 95 | + |
| 96 | + # Create a real ThriftResultSet with mocked dependencies |
| 97 | + from databricks.sql.utils import ExecuteResponse |
| 98 | + from databricks.sql.backend.types import CommandId, CommandState |
| 99 | + from databricks.sql.result_set import ThriftResultSet |
95 | 100 |
|
96 |
| - # Create a mock result set and set it as the active result set |
97 |
| - mock_result_set = Mock() |
98 |
| - mock_result_set.has_been_closed_server_side = closed |
99 |
| - cursor.active_result_set = mock_result_set |
| 101 | + # Mock the execute response with controlled state |
| 102 | + mock_execute_response = Mock(spec=ExecuteResponse) |
| 103 | + mock_execute_response.command_id = Mock(spec=CommandId) |
100 | 104 |
|
101 |
| - # Close the connection |
102 |
| - connection.close() |
| 105 | + # Use actual Thrift operation states, not CommandState enums |
| 106 | + from databricks.sql.thrift_api.TCLIService import ttypes |
103 | 107 |
|
104 |
| - # Check that the manually created mock result set's close method was called |
105 |
| - self.assertEqual( |
106 |
| - mock_result_set.has_been_closed_server_side, |
107 |
| - closed, |
| 108 | + mock_execute_response.status = ( |
| 109 | + ttypes.TOperationState.FINISHED_STATE |
| 110 | + if not closed |
| 111 | + else ttypes.TOperationState.CLOSED_STATE |
108 | 112 | )
|
109 |
| - mock_result_set.close.assert_called_once_with() |
| 113 | + mock_execute_response.has_been_closed_server_side = closed |
| 114 | + mock_execute_response.is_staging_operation = False |
| 115 | + |
| 116 | + # Mock the backend that will be used by the real ThriftResultSet |
| 117 | + mock_backend = Mock(spec=ThriftDatabricksClient) |
| 118 | + mock_backend.staging_allowed_local_path = None |
| 119 | + |
| 120 | + # Create connection and cursor |
| 121 | + with patch( |
| 122 | + f"{self.PACKAGE_NAME}.session.ThriftDatabricksClient", |
| 123 | + return_value=mock_backend, |
| 124 | + ): |
| 125 | + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) |
| 126 | + cursor = connection.cursor() |
| 127 | + |
| 128 | + # Create a REAL ThriftResultSet that will be returned by execute_command |
| 129 | + real_result_set = ThriftResultSet( |
| 130 | + connection=connection, |
| 131 | + execute_response=mock_execute_response, |
| 132 | + thrift_client=mock_backend, |
| 133 | + ) |
| 134 | + |
| 135 | + # Verify initial state |
| 136 | + self.assertEqual( |
| 137 | + real_result_set.has_been_closed_server_side, closed |
| 138 | + ) |
| 139 | + expected_op_state = ( |
| 140 | + CommandState.CLOSED if closed else CommandState.SUCCEEDED |
| 141 | + ) |
| 142 | + self.assertEqual(real_result_set.op_state, expected_op_state) |
| 143 | + |
| 144 | + # Mock execute_command to return our real result set |
| 145 | + cursor.backend.execute_command = Mock(return_value=real_result_set) |
| 146 | + |
| 147 | + # Execute a command - this should set cursor.active_result_set to our real result set |
| 148 | + cursor.execute("SELECT 1") |
| 149 | + |
| 150 | + # Verify that cursor.execute() set up the result set correctly |
| 151 | + self.assertIsInstance(cursor.active_result_set, ThriftResultSet) |
| 152 | + self.assertEqual( |
| 153 | + cursor.active_result_set.has_been_closed_server_side, closed |
| 154 | + ) |
| 155 | + |
| 156 | + # Close the connection - this should trigger the real close chain: |
| 157 | + # connection.close() -> cursor.close() -> result_set.close() |
| 158 | + connection.close() |
| 159 | + |
| 160 | + # Verify the REAL close logic worked through the chain: |
| 161 | + # 1. has_been_closed_server_side should always be True after close() |
| 162 | + self.assertTrue(real_result_set.has_been_closed_server_side) |
| 163 | + |
| 164 | + # 2. op_state should always be CLOSED after close() |
| 165 | + self.assertEqual(real_result_set.op_state, CommandState.CLOSED) |
| 166 | + |
| 167 | + # 3. Backend close_command should be called appropriately |
| 168 | + if not closed: |
| 169 | + # Should have called backend.close_command during the close chain |
| 170 | + mock_backend.close_command.assert_called_once_with( |
| 171 | + mock_execute_response.command_id |
| 172 | + ) |
| 173 | + else: |
| 174 | + # Should NOT have called backend.close_command (already closed) |
| 175 | + mock_backend.close_command.assert_not_called() |
110 | 176 |
|
111 | 177 | @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
|
112 | 178 | def test_cant_open_cursor_on_closed_connection(self, mock_client_class):
|
|
0 commit comments