21
21
import databricks .sql .client as client
22
22
from databricks .sql import InterfaceError , DatabaseError , Error , NotSupportedError
23
23
from databricks .sql .types import Row
24
+ from databricks .sql .result_set import ResultSet , ThriftResultSet
24
25
25
26
from tests .unit .test_fetches import FetchTests
26
27
from tests .unit .test_thrift_backend import ThriftBackendTestSuite
@@ -34,12 +35,11 @@ def new(cls):
34
35
ThriftBackendMock .return_value = ThriftBackendMock
35
36
36
37
cls .apply_property_to_mock (ThriftBackendMock , staging_allowed_local_path = None )
37
- MockTExecuteStatementResp = MagicMock ( spec = TExecuteStatementResp ())
38
-
38
+
39
+ mock_result_set = Mock ( spec = ThriftResultSet )
39
40
cls .apply_property_to_mock (
40
- MockTExecuteStatementResp ,
41
+ mock_result_set ,
41
42
description = None ,
42
- arrow_queue = None ,
43
43
is_staging_operation = False ,
44
44
command_handle = b"\x22 " ,
45
45
has_been_closed_server_side = True ,
@@ -48,7 +48,7 @@ def new(cls):
48
48
arrow_schema_bytes = b"schema" ,
49
49
)
50
50
51
- ThriftBackendMock .execute_command .return_value = MockTExecuteStatementResp
51
+ ThriftBackendMock .execute_command .return_value = mock_result_set
52
52
53
53
return ThriftBackendMock
54
54
@@ -81,21 +81,22 @@ class ClientTestSuite(unittest.TestCase):
81
81
}
82
82
83
83
@patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME , ThriftDatabricksClientMockFactory .new ())
84
- @patch ("%s.client.ResultSet" % PACKAGE_NAME )
85
- def test_closing_connection_closes_commands (self , mock_result_set_class ):
86
- # Test once with has_been_closed_server side, once without
84
+ def test_closing_connection_closes_commands (self ):
87
85
for closed in (True , False ):
88
86
with self .subTest (closed = closed ):
89
- mock_result_set_class .return_value = Mock ()
90
87
connection = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
91
88
cursor = connection .cursor ()
92
- cursor .execute ("SELECT 1;" )
89
+
90
+ # Create a mock result set and set it as the active result set
91
+ mock_result_set = Mock ()
92
+ mock_result_set .has_been_closed_server_side = closed
93
+ cursor .active_result_set = mock_result_set
94
+
95
+ # Close the connection
93
96
connection .close ()
94
-
95
- self .assertTrue (
96
- mock_result_set_class .return_value .has_been_closed_server_side
97
- )
98
- mock_result_set_class .return_value .close .assert_called_once_with ()
97
+
98
+ # After closing the connection, the close method should have been called on the result set
99
+ mock_result_set .close .assert_called_once_with ()
99
100
100
101
@patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
101
102
def test_cant_open_cursor_on_closed_connection (self , mock_client_class ):
@@ -122,10 +123,11 @@ def test_arraysize_buffer_size_passthrough(
122
123
def test_closing_result_set_with_closed_connection_soft_closes_commands (self ):
123
124
mock_connection = Mock ()
124
125
mock_backend = Mock ()
125
- result_set = client .ResultSet (
126
+
127
+ result_set = ThriftResultSet (
126
128
connection = mock_connection ,
127
- backend = mock_backend ,
128
129
execute_response = Mock (),
130
+ thrift_client = mock_backend ,
129
131
)
130
132
# Setup session mock on the mock_connection
131
133
mock_session = Mock ()
@@ -147,7 +149,7 @@ def test_closing_result_set_hard_closes_commands(self):
147
149
mock_session .open = True
148
150
type(mock_connection ).session = PropertyMock (return_value = mock_session )
149
151
150
- result_set = client . ResultSet (
152
+ result_set = ThriftResultSet (
151
153
mock_connection , mock_results_response , mock_thrift_backend
152
154
)
153
155
@@ -157,16 +159,22 @@ def test_closing_result_set_hard_closes_commands(self):
157
159
mock_results_response .command_handle
158
160
)
159
161
160
- @patch ("%s.client.ResultSet " % PACKAGE_NAME )
162
+ @patch ("%s.result_set.ThriftResultSet " % PACKAGE_NAME )
161
163
def test_executing_multiple_commands_uses_the_most_recent_command (
162
164
self , mock_result_set_class
163
165
):
164
-
165
166
mock_result_sets = [Mock (), Mock ()]
167
+ # Set is_staging_operation to False to avoid _handle_staging_operation being called
168
+ for mock_rs in mock_result_sets :
169
+ mock_rs .is_staging_operation = False
170
+
166
171
mock_result_set_class .side_effect = mock_result_sets
172
+
173
+ mock_backend = ThriftDatabricksClientMockFactory .new ()
174
+ mock_backend .execute_command .side_effect = mock_result_sets
167
175
168
176
cursor = client .Cursor (
169
- connection = Mock (), backend = ThriftDatabricksClientMockFactory . new ()
177
+ connection = Mock (), backend = mock_backend
170
178
)
171
179
cursor .execute ("SELECT 1;" )
172
180
cursor .execute ("SELECT 1;" )
@@ -192,7 +200,7 @@ def test_closed_cursor_doesnt_allow_operations(self):
192
200
self .assertIn ("closed" , e .msg )
193
201
194
202
def test_negative_fetch_throws_exception (self ):
195
- result_set = client . ResultSet (Mock (), Mock (), Mock ())
203
+ result_set = ThriftResultSet (Mock (), Mock (), Mock ())
196
204
197
205
with self .assertRaises (ValueError ) as e :
198
206
result_set .fetchmany (- 1 )
@@ -334,14 +342,19 @@ def test_execute_parameter_passthrough(self):
334
342
expected_query ,
335
343
)
336
344
337
- @patch ("%s.client.ResultSet " % PACKAGE_NAME )
345
+ @patch ("%s.result_set.ThriftResultSet " % PACKAGE_NAME )
338
346
def test_executemany_parameter_passhthrough_and_uses_last_result_set (
339
347
self , mock_result_set_class
340
348
):
341
349
# Create a new mock result set each time the class is instantiated
342
350
mock_result_set_instances = [Mock (), Mock (), Mock ()]
351
+ # Set is_staging_operation to False to avoid _handle_staging_operation being called
352
+ for mock_rs in mock_result_set_instances :
353
+ mock_rs .is_staging_operation = False
354
+
343
355
mock_result_set_class .side_effect = mock_result_set_instances
344
356
mock_backend = ThriftDatabricksClientMockFactory .new ()
357
+ mock_backend .execute_command .side_effect = mock_result_set_instances
345
358
346
359
cursor = client .Cursor (Mock (), mock_backend )
347
360
@@ -494,8 +507,9 @@ def test_staging_operation_response_is_handled(
494
507
ThriftDatabricksClientMockFactory .apply_property_to_mock (
495
508
mock_execute_response , is_staging_operation = True
496
509
)
497
- mock_client_class .execute_command .return_value = mock_execute_response
498
- mock_client_class .return_value = mock_client_class
510
+ mock_client = mock_client_class .return_value
511
+ mock_client .execute_command .return_value = Mock (is_staging_operation = True )
512
+ mock_client_class .return_value = mock_client
499
513
500
514
connection = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
501
515
cursor = connection .cursor ()
0 commit comments