Skip to content

Commit e7ebe2b

Browse files
update unit tests to address ThriftBackend through session instead of through Connection
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 0e6efd8 commit e7ebe2b

File tree

1 file changed

+28
-21
lines changed

1 file changed

+28
-21
lines changed

tests/unit/test_client.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class ClientTestSuite(unittest.TestCase):
8080
"access_token": "tok",
8181
}
8282

83-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
83+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
8484
def test_close_uses_the_correct_session_id(self, mock_client_class):
8585
instance = mock_client_class.return_value
8686

@@ -95,7 +95,7 @@ def test_close_uses_the_correct_session_id(self, mock_client_class):
9595
close_session_id = instance.close_session.call_args[0][0].sessionId
9696
self.assertEqual(close_session_id, b"\x22")
9797

98-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
98+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
9999
def test_auth_args(self, mock_client_class):
100100
# Test that the following auth args work:
101101
# token = foo,
@@ -122,15 +122,15 @@ def test_auth_args(self, mock_client_class):
122122
self.assertEqual(args["http_path"], http_path)
123123
connection.close()
124124

125-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
125+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
126126
def test_http_header_passthrough(self, mock_client_class):
127127
http_headers = [("foo", "bar")]
128128
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers)
129129

130130
call_args = mock_client_class.call_args[0][3]
131131
self.assertIn(("foo", "bar"), call_args)
132132

133-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
133+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
134134
def test_tls_arg_passthrough(self, mock_client_class):
135135
databricks.sql.connect(
136136
**self.DUMMY_CONNECTION_ARGS,
@@ -146,7 +146,7 @@ def test_tls_arg_passthrough(self, mock_client_class):
146146
self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert")
147147
self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password")
148148

149-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
149+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
150150
def test_useragent_header(self, mock_client_class):
151151
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
152152

@@ -167,7 +167,7 @@ def test_useragent_header(self, mock_client_class):
167167
http_headers = mock_client_class.call_args[0][3]
168168
self.assertIn(user_agent_header_with_entry, http_headers)
169169

170-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new())
170+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new())
171171
@patch("%s.client.ResultSet" % PACKAGE_NAME)
172172
def test_closing_connection_closes_commands(self, mock_result_set_class):
173173
# Test once with has_been_closed_server side, once without
@@ -184,7 +184,7 @@ def test_closing_connection_closes_commands(self, mock_result_set_class):
184184
)
185185
mock_result_set_class.return_value.close.assert_called_once_with()
186186

187-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
187+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
188188
def test_cant_open_cursor_on_closed_connection(self, mock_client_class):
189189
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
190190
self.assertTrue(connection.open)
@@ -194,7 +194,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class):
194194
connection.cursor()
195195
self.assertIn("closed", str(cm.exception))
196196

197-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
197+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
198198
@patch("%s.client.Cursor" % PACKAGE_NAME)
199199
def test_arraysize_buffer_size_passthrough(
200200
self, mock_cursor_class, mock_client_class
@@ -214,7 +214,10 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
214214
thrift_backend=mock_backend,
215215
execute_response=Mock(),
216216
)
217-
mock_connection.open = False
217+
# Setup session mock on the mock_connection
218+
mock_session = Mock()
219+
mock_session.open = False
220+
type(mock_connection).session = PropertyMock(return_value=mock_session)
218221

219222
result_set.close()
220223

@@ -226,7 +229,11 @@ def test_closing_result_set_hard_closes_commands(self):
226229
mock_results_response.has_been_closed_server_side = False
227230
mock_connection = Mock()
228231
mock_thrift_backend = Mock()
229-
mock_connection.open = True
232+
# Setup session mock on the mock_connection
233+
mock_session = Mock()
234+
mock_session.open = True
235+
type(mock_connection).session = PropertyMock(return_value=mock_session)
236+
230237
result_set = client.ResultSet(
231238
mock_connection, mock_results_response, mock_thrift_backend
232239
)
@@ -283,7 +290,7 @@ def test_context_manager_closes_cursor(self):
283290
cursor.close = mock_close
284291
mock_close.assert_called_once_with()
285292

286-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
293+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
287294
def test_context_manager_closes_connection(self, mock_client_class):
288295
instance = mock_client_class.return_value
289296

@@ -396,7 +403,7 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command(
396403
self.assertTrue(logger_instance.warning.called)
397404
self.assertFalse(mock_thrift_backend.cancel_command.called)
398405

399-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
406+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
400407
def test_max_number_of_retries_passthrough(self, mock_client_class):
401408
databricks.sql.connect(
402409
_retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS
@@ -406,7 +413,7 @@ def test_max_number_of_retries_passthrough(self, mock_client_class):
406413
mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54
407414
)
408415

409-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
416+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
410417
def test_socket_timeout_passthrough(self, mock_client_class):
411418
databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS)
412419
self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234)
@@ -419,7 +426,7 @@ def test_version_is_canonical(self):
419426
)
420427
self.assertIsNotNone(re.match(canonical_version_re, version))
421428

422-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
429+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
423430
def test_configuration_passthrough(self, mock_client_class):
424431
mock_session_config = Mock()
425432
databricks.sql.connect(
@@ -431,7 +438,7 @@ def test_configuration_passthrough(self, mock_client_class):
431438
mock_session_config,
432439
)
433440

434-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
441+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
435442
def test_initial_namespace_passthrough(self, mock_client_class):
436443
mock_cat = Mock()
437444
mock_schem = Mock()
@@ -505,7 +512,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set(
505512
"last operation",
506513
)
507514

508-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
515+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
509516
def test_commit_a_noop(self, mock_thrift_backend_class):
510517
c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
511518
c.commit()
@@ -518,7 +525,7 @@ def test_setoutputsizes_a_noop(self):
518525
cursor = client.Cursor(Mock(), Mock())
519526
cursor.setoutputsize(1)
520527

521-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
528+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
522529
def test_rollback_not_supported(self, mock_thrift_backend_class):
523530
c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
524531
with self.assertRaises(NotSupportedError):
@@ -603,7 +610,7 @@ def test_column_name_api(self):
603610
},
604611
)
605612

606-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
613+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
607614
def test_finalizer_closes_abandoned_connection(self, mock_client_class):
608615
instance = mock_client_class.return_value
609616

@@ -620,7 +627,7 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class):
620627
close_session_id = instance.close_session.call_args[0][0].sessionId
621628
self.assertEqual(close_session_id, b"\x22")
622629

623-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
630+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
624631
def test_cursor_keeps_connection_alive(self, mock_client_class):
625632
instance = mock_client_class.return_value
626633

@@ -639,7 +646,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class):
639646

640647
@patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True)
641648
@patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME)
642-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
649+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
643650
def test_staging_operation_response_is_handled(
644651
self, mock_client_class, mock_handle_staging_operation, mock_execute_response
645652
):
@@ -658,7 +665,7 @@ def test_staging_operation_response_is_handled(
658665

659666
mock_handle_staging_operation.call_count == 1
660667

661-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new())
668+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new())
662669
def test_access_current_query_id(self):
663670
operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821"
664671

0 commit comments

Comments
 (0)