Skip to content

Commit 9a968ed

Browse files
fix: correct client tests
1 parent 53ab3c4 commit 9a968ed

File tree

1 file changed

+39
-25
lines changed

1 file changed

+39
-25
lines changed

tests/unit/test_client.py

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import databricks.sql.client as client
2222
from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError
2323
from databricks.sql.types import Row
24+
from databricks.sql.result_set import ResultSet, ThriftResultSet
2425

2526
from tests.unit.test_fetches import FetchTests
2627
from tests.unit.test_thrift_backend import ThriftBackendTestSuite
@@ -34,12 +35,11 @@ def new(cls):
3435
ThriftBackendMock.return_value = ThriftBackendMock
3536

3637
cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None)
37-
MockTExecuteStatementResp = MagicMock(spec=TExecuteStatementResp())
38-
38+
39+
mock_result_set = Mock(spec=ThriftResultSet)
3940
cls.apply_property_to_mock(
40-
MockTExecuteStatementResp,
41+
mock_result_set,
4142
description=None,
42-
arrow_queue=None,
4343
is_staging_operation=False,
4444
command_handle=b"\x22",
4545
has_been_closed_server_side=True,
@@ -48,7 +48,7 @@ def new(cls):
4848
arrow_schema_bytes=b"schema",
4949
)
5050

51-
ThriftBackendMock.execute_command.return_value = MockTExecuteStatementResp
51+
ThriftBackendMock.execute_command.return_value = mock_result_set
5252

5353
return ThriftBackendMock
5454

@@ -81,21 +81,22 @@ class ClientTestSuite(unittest.TestCase):
8181
}
8282

8383
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME, ThriftDatabricksClientMockFactory.new())
84-
@patch("%s.client.ResultSet" % PACKAGE_NAME)
85-
def test_closing_connection_closes_commands(self, mock_result_set_class):
86-
# Test once with has_been_closed_server side, once without
84+
def test_closing_connection_closes_commands(self):
8785
for closed in (True, False):
8886
with self.subTest(closed=closed):
89-
mock_result_set_class.return_value = Mock()
9087
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
9188
cursor = connection.cursor()
92-
cursor.execute("SELECT 1;")
89+
90+
# Create a mock result set and set it as the active result set
91+
mock_result_set = Mock()
92+
mock_result_set.has_been_closed_server_side = closed
93+
cursor.active_result_set = mock_result_set
94+
95+
# Close the connection
9396
connection.close()
94-
95-
self.assertTrue(
96-
mock_result_set_class.return_value.has_been_closed_server_side
97-
)
98-
mock_result_set_class.return_value.close.assert_called_once_with()
97+
98+
# After closing the connection, the close method should have been called on the result set
99+
mock_result_set.close.assert_called_once_with()
99100

100101
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
101102
def test_cant_open_cursor_on_closed_connection(self, mock_client_class):
@@ -122,10 +123,11 @@ def test_arraysize_buffer_size_passthrough(
122123
def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
123124
mock_connection = Mock()
124125
mock_backend = Mock()
125-
result_set = client.ResultSet(
126+
127+
result_set = ThriftResultSet(
126128
connection=mock_connection,
127-
backend=mock_backend,
128129
execute_response=Mock(),
130+
thrift_client=mock_backend,
129131
)
130132
# Setup session mock on the mock_connection
131133
mock_session = Mock()
@@ -147,7 +149,7 @@ def test_closing_result_set_hard_closes_commands(self):
147149
mock_session.open = True
148150
type(mock_connection).session = PropertyMock(return_value=mock_session)
149151

150-
result_set = client.ResultSet(
152+
result_set = ThriftResultSet(
151153
mock_connection, mock_results_response, mock_thrift_backend
152154
)
153155

@@ -157,16 +159,22 @@ def test_closing_result_set_hard_closes_commands(self):
157159
mock_results_response.command_handle
158160
)
159161

160-
@patch("%s.client.ResultSet" % PACKAGE_NAME)
162+
@patch("%s.result_set.ThriftResultSet" % PACKAGE_NAME)
161163
def test_executing_multiple_commands_uses_the_most_recent_command(
162164
self, mock_result_set_class
163165
):
164-
165166
mock_result_sets = [Mock(), Mock()]
167+
# Set is_staging_operation to False to avoid _handle_staging_operation being called
168+
for mock_rs in mock_result_sets:
169+
mock_rs.is_staging_operation = False
170+
166171
mock_result_set_class.side_effect = mock_result_sets
172+
173+
mock_backend = ThriftDatabricksClientMockFactory.new()
174+
mock_backend.execute_command.side_effect = mock_result_sets
167175

168176
cursor = client.Cursor(
169-
connection=Mock(), backend=ThriftDatabricksClientMockFactory.new()
177+
connection=Mock(), backend=mock_backend
170178
)
171179
cursor.execute("SELECT 1;")
172180
cursor.execute("SELECT 1;")
@@ -192,7 +200,7 @@ def test_closed_cursor_doesnt_allow_operations(self):
192200
self.assertIn("closed", e.msg)
193201

194202
def test_negative_fetch_throws_exception(self):
195-
result_set = client.ResultSet(Mock(), Mock(), Mock())
203+
result_set = ThriftResultSet(Mock(), Mock(), Mock())
196204

197205
with self.assertRaises(ValueError) as e:
198206
result_set.fetchmany(-1)
@@ -334,14 +342,19 @@ def test_execute_parameter_passthrough(self):
334342
expected_query,
335343
)
336344

337-
@patch("%s.client.ResultSet" % PACKAGE_NAME)
345+
@patch("%s.result_set.ThriftResultSet" % PACKAGE_NAME)
338346
def test_executemany_parameter_passhthrough_and_uses_last_result_set(
339347
self, mock_result_set_class
340348
):
341349
# Create a new mock result set each time the class is instantiated
342350
mock_result_set_instances = [Mock(), Mock(), Mock()]
351+
# Set is_staging_operation to False to avoid _handle_staging_operation being called
352+
for mock_rs in mock_result_set_instances:
353+
mock_rs.is_staging_operation = False
354+
343355
mock_result_set_class.side_effect = mock_result_set_instances
344356
mock_backend = ThriftDatabricksClientMockFactory.new()
357+
mock_backend.execute_command.side_effect = mock_result_set_instances
345358

346359
cursor = client.Cursor(Mock(), mock_backend)
347360

@@ -494,8 +507,9 @@ def test_staging_operation_response_is_handled(
494507
ThriftDatabricksClientMockFactory.apply_property_to_mock(
495508
mock_execute_response, is_staging_operation=True
496509
)
497-
mock_client_class.execute_command.return_value = mock_execute_response
498-
mock_client_class.return_value = mock_client_class
510+
mock_client = mock_client_class.return_value
511+
mock_client.execute_command.return_value = Mock(is_staging_operation=True)
512+
mock_client_class.return_value = mock_client
499513

500514
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
501515
cursor = connection.cursor()

0 commit comments

Comments
 (0)