Skip to content

Commit c63f6fd

Browse files
chore: move session specific tests from test_client to test_session
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent e7ebe2b commit c63f6fd

File tree

2 files changed

+187
-161
lines changed

2 files changed

+187
-161
lines changed

tests/unit/test_client.py

Lines changed: 0 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -80,93 +80,6 @@ class ClientTestSuite(unittest.TestCase):
8080
"access_token": "tok",
8181
}
8282

83-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
84-
def test_close_uses_the_correct_session_id(self, mock_client_class):
85-
instance = mock_client_class.return_value
86-
87-
mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
88-
mock_open_session_resp.sessionHandle.sessionId = b"\x22"
89-
instance.open_session.return_value = mock_open_session_resp
90-
91-
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
92-
connection.close()
93-
94-
# Check the close session request has an id of x22
95-
close_session_id = instance.close_session.call_args[0][0].sessionId
96-
self.assertEqual(close_session_id, b"\x22")
97-
98-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
99-
def test_auth_args(self, mock_client_class):
100-
# Test that the following auth args work:
101-
# token = foo,
102-
# token = None, _tls_client_cert_file = something, _use_cert_as_auth = True
103-
connection_args = [
104-
{
105-
"server_hostname": "foo",
106-
"http_path": None,
107-
"access_token": "tok",
108-
},
109-
{
110-
"server_hostname": "foo",
111-
"http_path": None,
112-
"_tls_client_cert_file": "something",
113-
"_use_cert_as_auth": True,
114-
"access_token": None,
115-
},
116-
]
117-
118-
for args in connection_args:
119-
connection = databricks.sql.connect(**args)
120-
host, port, http_path, *_ = mock_client_class.call_args[0]
121-
self.assertEqual(args["server_hostname"], host)
122-
self.assertEqual(args["http_path"], http_path)
123-
connection.close()
124-
125-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
126-
def test_http_header_passthrough(self, mock_client_class):
127-
http_headers = [("foo", "bar")]
128-
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers)
129-
130-
call_args = mock_client_class.call_args[0][3]
131-
self.assertIn(("foo", "bar"), call_args)
132-
133-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
134-
def test_tls_arg_passthrough(self, mock_client_class):
135-
databricks.sql.connect(
136-
**self.DUMMY_CONNECTION_ARGS,
137-
_tls_verify_hostname="hostname",
138-
_tls_trusted_ca_file="trusted ca file",
139-
_tls_client_cert_key_file="trusted client cert",
140-
_tls_client_cert_key_password="key password",
141-
)
142-
143-
kwargs = mock_client_class.call_args[1]
144-
self.assertEqual(kwargs["_tls_verify_hostname"], "hostname")
145-
self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file")
146-
self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert")
147-
self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password")
148-
149-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
150-
def test_useragent_header(self, mock_client_class):
151-
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
152-
153-
http_headers = mock_client_class.call_args[0][3]
154-
user_agent_header = (
155-
"User-Agent",
156-
"{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__),
157-
)
158-
self.assertIn(user_agent_header, http_headers)
159-
160-
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar")
161-
user_agent_header_with_entry = (
162-
"User-Agent",
163-
"{}/{} ({})".format(
164-
databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar"
165-
),
166-
)
167-
http_headers = mock_client_class.call_args[0][3]
168-
self.assertIn(user_agent_header_with_entry, http_headers)
169-
17083
@patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new())
17184
@patch("%s.client.ResultSet" % PACKAGE_NAME)
17285
def test_closing_connection_closes_commands(self, mock_result_set_class):
@@ -290,21 +203,6 @@ def test_context_manager_closes_cursor(self):
290203
cursor.close = mock_close
291204
mock_close.assert_called_once_with()
292205

293-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
294-
def test_context_manager_closes_connection(self, mock_client_class):
295-
instance = mock_client_class.return_value
296-
297-
mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
298-
mock_open_session_resp.sessionHandle.sessionId = b"\x22"
299-
instance.open_session.return_value = mock_open_session_resp
300-
301-
with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection:
302-
pass
303-
304-
# Check the close session request has an id of x22
305-
close_session_id = instance.close_session.call_args[0][0].sessionId
306-
self.assertEqual(close_session_id, b"\x22")
307-
308206
def dict_product(self, dicts):
309207
"""
310208
Generate cartesion product of values in input dictionary, outputting a dictionary
@@ -403,21 +301,6 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command(
403301
self.assertTrue(logger_instance.warning.called)
404302
self.assertFalse(mock_thrift_backend.cancel_command.called)
405303

406-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
407-
def test_max_number_of_retries_passthrough(self, mock_client_class):
408-
databricks.sql.connect(
409-
_retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS
410-
)
411-
412-
self.assertEqual(
413-
mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54
414-
)
415-
416-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
417-
def test_socket_timeout_passthrough(self, mock_client_class):
418-
databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS)
419-
self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234)
420-
421304
def test_version_is_canonical(self):
422305
version = databricks.sql.__version__
423306
canonical_version_re = (
@@ -426,33 +309,6 @@ def test_version_is_canonical(self):
426309
)
427310
self.assertIsNotNone(re.match(canonical_version_re, version))
428311

429-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
430-
def test_configuration_passthrough(self, mock_client_class):
431-
mock_session_config = Mock()
432-
databricks.sql.connect(
433-
session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS
434-
)
435-
436-
self.assertEqual(
437-
mock_client_class.return_value.open_session.call_args[0][0],
438-
mock_session_config,
439-
)
440-
441-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
442-
def test_initial_namespace_passthrough(self, mock_client_class):
443-
mock_cat = Mock()
444-
mock_schem = Mock()
445-
446-
databricks.sql.connect(
447-
**self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem
448-
)
449-
self.assertEqual(
450-
mock_client_class.return_value.open_session.call_args[0][1], mock_cat
451-
)
452-
self.assertEqual(
453-
mock_client_class.return_value.open_session.call_args[0][2], mock_schem
454-
)
455-
456312
def test_execute_parameter_passthrough(self):
457313
mock_thrift_backend = ThriftBackendMockFactory.new()
458314
cursor = client.Cursor(Mock(), mock_thrift_backend)
@@ -610,23 +466,6 @@ def test_column_name_api(self):
610466
},
611467
)
612468

613-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
614-
def test_finalizer_closes_abandoned_connection(self, mock_client_class):
615-
instance = mock_client_class.return_value
616-
617-
mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
618-
mock_open_session_resp.sessionHandle.sessionId = b"\x22"
619-
instance.open_session.return_value = mock_open_session_resp
620-
621-
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
622-
623-
# not strictly necessary as the refcount is 0, but just to be sure
624-
gc.collect()
625-
626-
# Check the close session request has an id of x22
627-
close_session_id = instance.close_session.call_args[0][0].sessionId
628-
self.assertEqual(close_session_id, b"\x22")
629-
630469
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
631470
def test_cursor_keeps_connection_alive(self, mock_client_class):
632471
instance = mock_client_class.return_value

tests/unit/test_session.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
import unittest
2+
from unittest.mock import patch, MagicMock, Mock, PropertyMock
3+
import gc
4+
5+
from databricks.sql.thrift_api.TCLIService.ttypes import (
6+
TOpenSessionResp,
7+
)
8+
9+
import databricks.sql
10+
11+
12+
class SessionTestSuite(unittest.TestCase):
13+
"""
14+
Unit tests for Session functionality
15+
"""
16+
17+
PACKAGE_NAME = "databricks.sql"
18+
DUMMY_CONNECTION_ARGS = {
19+
"server_hostname": "foo",
20+
"http_path": "dummy_path",
21+
"access_token": "tok",
22+
}
23+
24+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
25+
def test_close_uses_the_correct_session_id(self, mock_client_class):
26+
instance = mock_client_class.return_value
27+
28+
mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
29+
mock_open_session_resp.sessionHandle.sessionId = b"\x22"
30+
instance.open_session.return_value = mock_open_session_resp
31+
32+
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
33+
connection.close()
34+
35+
# Check the close session request has an id of x22
36+
close_session_id = instance.close_session.call_args[0][0].sessionId
37+
self.assertEqual(close_session_id, b"\x22")
38+
39+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
40+
def test_auth_args(self, mock_client_class):
41+
# Test that the following auth args work:
42+
# token = foo,
43+
# token = None, _tls_client_cert_file = something, _use_cert_as_auth = True
44+
connection_args = [
45+
{
46+
"server_hostname": "foo",
47+
"http_path": None,
48+
"access_token": "tok",
49+
},
50+
{
51+
"server_hostname": "foo",
52+
"http_path": None,
53+
"_tls_client_cert_file": "something",
54+
"_use_cert_as_auth": True,
55+
"access_token": None,
56+
},
57+
]
58+
59+
for args in connection_args:
60+
connection = databricks.sql.connect(**args)
61+
host, port, http_path, *_ = mock_client_class.call_args[0]
62+
self.assertEqual(args["server_hostname"], host)
63+
self.assertEqual(args["http_path"], http_path)
64+
connection.close()
65+
66+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
67+
def test_http_header_passthrough(self, mock_client_class):
68+
http_headers = [("foo", "bar")]
69+
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers)
70+
71+
call_args = mock_client_class.call_args[0][3]
72+
self.assertIn(("foo", "bar"), call_args)
73+
74+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
75+
def test_tls_arg_passthrough(self, mock_client_class):
76+
databricks.sql.connect(
77+
**self.DUMMY_CONNECTION_ARGS,
78+
_tls_verify_hostname="hostname",
79+
_tls_trusted_ca_file="trusted ca file",
80+
_tls_client_cert_key_file="trusted client cert",
81+
_tls_client_cert_key_password="key password",
82+
)
83+
84+
kwargs = mock_client_class.call_args[1]
85+
self.assertEqual(kwargs["_tls_verify_hostname"], "hostname")
86+
self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file")
87+
self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert")
88+
self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password")
89+
90+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
91+
def test_useragent_header(self, mock_client_class):
92+
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
93+
94+
http_headers = mock_client_class.call_args[0][3]
95+
user_agent_header = (
96+
"User-Agent",
97+
"{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__),
98+
)
99+
self.assertIn(user_agent_header, http_headers)
100+
101+
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar")
102+
user_agent_header_with_entry = (
103+
"User-Agent",
104+
"{}/{} ({})".format(
105+
databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar"
106+
),
107+
)
108+
http_headers = mock_client_class.call_args[0][3]
109+
self.assertIn(user_agent_header_with_entry, http_headers)
110+
111+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
112+
def test_context_manager_closes_connection(self, mock_client_class):
113+
instance = mock_client_class.return_value
114+
115+
mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
116+
mock_open_session_resp.sessionHandle.sessionId = b"\x22"
117+
instance.open_session.return_value = mock_open_session_resp
118+
119+
with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection:
120+
pass
121+
122+
# Check the close session request has an id of x22
123+
close_session_id = instance.close_session.call_args[0][0].sessionId
124+
self.assertEqual(close_session_id, b"\x22")
125+
126+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
127+
def test_max_number_of_retries_passthrough(self, mock_client_class):
128+
databricks.sql.connect(
129+
_retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS
130+
)
131+
132+
self.assertEqual(
133+
mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54
134+
)
135+
136+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
137+
def test_socket_timeout_passthrough(self, mock_client_class):
138+
databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS)
139+
self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234)
140+
141+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
142+
def test_configuration_passthrough(self, mock_client_class):
143+
mock_session_config = Mock()
144+
databricks.sql.connect(
145+
session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS
146+
)
147+
148+
self.assertEqual(
149+
mock_client_class.return_value.open_session.call_args[0][0],
150+
mock_session_config,
151+
)
152+
153+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
154+
def test_initial_namespace_passthrough(self, mock_client_class):
155+
mock_cat = Mock()
156+
mock_schem = Mock()
157+
158+
databricks.sql.connect(
159+
**self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem
160+
)
161+
self.assertEqual(
162+
mock_client_class.return_value.open_session.call_args[0][1], mock_cat
163+
)
164+
self.assertEqual(
165+
mock_client_class.return_value.open_session.call_args[0][2], mock_schem
166+
)
167+
168+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
169+
def test_finalizer_closes_abandoned_connection(self, mock_client_class):
170+
instance = mock_client_class.return_value
171+
172+
mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
173+
mock_open_session_resp.sessionHandle.sessionId = b"\x22"
174+
instance.open_session.return_value = mock_open_session_resp
175+
176+
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
177+
178+
# not strictly necessary as the refcount is 0, but just to be sure
179+
gc.collect()
180+
181+
# Check the close session request has an id of x22
182+
close_session_id = instance.close_session.call_args[0][0].sessionId
183+
self.assertEqual(close_session_id, b"\x22")
184+
185+
186+
if __name__ == "__main__":
187+
unittest.main()

0 commit comments

Comments
 (0)