Skip to content

Commit 296b02f

Browse files
committed
Update existing tests
Signed-off-by: Levko Kravets <[email protected]>
1 parent 6f224b3 commit 296b02f

File tree

4 files changed

+97
-84
lines changed

4 files changed

+97
-84
lines changed

tests/unit/test_cloud_fetch_queue.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import pyarrow
22
import unittest
33
from unittest.mock import MagicMock, patch
4-
from ssl import create_default_context
54

65
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
76
import databricks.sql.utils as utils
7+
from databricks.sql.types import SSLOptions
88

99
class CloudFetchQueueSuite(unittest.TestCase):
1010

@@ -51,7 +51,7 @@ def test_initializer_adds_links(self, mock_create_next_table):
5151
schema_bytes,
5252
result_links=result_links,
5353
max_download_threads=10,
54-
ssl_context=create_default_context(),
54+
ssl_options=SSLOptions(),
5555
)
5656

5757
assert len(queue.download_manager._pending_links) == 10
@@ -65,7 +65,7 @@ def test_initializer_no_links_to_add(self):
6565
schema_bytes,
6666
result_links=result_links,
6767
max_download_threads=10,
68-
ssl_context=create_default_context(),
68+
ssl_options=SSLOptions(),
6969
)
7070

7171
assert len(queue.download_manager._pending_links) == 0
@@ -78,7 +78,7 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file):
7878
MagicMock(),
7979
result_links=[],
8080
max_download_threads=10,
81-
ssl_context=create_default_context(),
81+
ssl_options=SSLOptions(),
8282
)
8383

8484
assert queue._create_next_table() is None
@@ -95,7 +95,7 @@ def test_initializer_create_next_table_success(self, mock_get_next_downloaded_fi
9595
result_links=[],
9696
description=description,
9797
max_download_threads=10,
98-
ssl_context=create_default_context(),
98+
ssl_options=SSLOptions(),
9999
)
100100
expected_result = self.make_arrow_table()
101101

@@ -120,7 +120,7 @@ def test_next_n_rows_0_rows(self, mock_create_next_table):
120120
result_links=[],
121121
description=description,
122122
max_download_threads=10,
123-
ssl_context=create_default_context(),
123+
ssl_options=SSLOptions(),
124124
)
125125
assert queue.table == self.make_arrow_table()
126126
assert queue.table.num_rows == 4
@@ -140,7 +140,7 @@ def test_next_n_rows_partial_table(self, mock_create_next_table):
140140
result_links=[],
141141
description=description,
142142
max_download_threads=10,
143-
ssl_context=create_default_context(),
143+
ssl_options=SSLOptions(),
144144
)
145145
assert queue.table == self.make_arrow_table()
146146
assert queue.table.num_rows == 4
@@ -160,7 +160,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
160160
result_links=[],
161161
description=description,
162162
max_download_threads=10,
163-
ssl_context=create_default_context(),
163+
ssl_options=SSLOptions(),
164164
)
165165
assert queue.table == self.make_arrow_table()
166166
assert queue.table.num_rows == 4
@@ -180,7 +180,7 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
180180
result_links=[],
181181
description=description,
182182
max_download_threads=10,
183-
ssl_context=create_default_context(),
183+
ssl_options=SSLOptions(),
184184
)
185185
assert queue.table == self.make_arrow_table()
186186
assert queue.table.num_rows == 4
@@ -199,7 +199,7 @@ def test_next_n_rows_empty_table(self, mock_create_next_table):
199199
result_links=[],
200200
description=description,
201201
max_download_threads=10,
202-
ssl_context=create_default_context(),
202+
ssl_options=SSLOptions(),
203203
)
204204
assert queue.table is None
205205

@@ -216,7 +216,7 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table)
216216
result_links=[],
217217
description=description,
218218
max_download_threads=10,
219-
ssl_context=create_default_context(),
219+
ssl_options=SSLOptions(),
220220
)
221221
assert queue.table == self.make_arrow_table()
222222
assert queue.table.num_rows == 4
@@ -235,7 +235,7 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl
235235
result_links=[],
236236
description=description,
237237
max_download_threads=10,
238-
ssl_context=create_default_context(),
238+
ssl_options=SSLOptions(),
239239
)
240240
assert queue.table == self.make_arrow_table()
241241
assert queue.table.num_rows == 4
@@ -254,7 +254,7 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table):
254254
result_links=[],
255255
description=description,
256256
max_download_threads=10,
257-
ssl_context=create_default_context(),
257+
ssl_options=SSLOptions(),
258258
)
259259
assert queue.table == self.make_arrow_table()
260260
assert queue.table.num_rows == 4
@@ -273,7 +273,7 @@ def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_ta
273273
result_links=[],
274274
description=description,
275275
max_download_threads=10,
276-
ssl_context=create_default_context(),
276+
ssl_options=SSLOptions(),
277277
)
278278
assert queue.table == self.make_arrow_table()
279279
assert queue.table.num_rows == 4
@@ -293,7 +293,7 @@ def test_remaining_rows_empty_table(self, mock_create_next_table):
293293
result_links=[],
294294
description=description,
295295
max_download_threads=10,
296-
ssl_context=create_default_context(),
296+
ssl_options=SSLOptions(),
297297
)
298298
assert queue.table is None
299299

tests/unit/test_download_manager.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import unittest
22
from unittest.mock import patch, MagicMock
33

4-
from ssl import create_default_context
5-
64
import databricks.sql.cloudfetch.download_manager as download_manager
5+
from databricks.sql.types import SSLOptions
76
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
87

98

@@ -17,7 +16,7 @@ def create_download_manager(self, links, max_download_threads=10, lz4_compressed
1716
links,
1817
max_download_threads,
1918
lz4_compressed,
20-
ssl_context=create_default_context(),
19+
ssl_options=SSLOptions(),
2120
)
2221

2322
def create_result_link(

tests/unit/test_downloader.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from unittest.mock import Mock, patch, MagicMock
33

44
import requests
5-
from ssl import create_default_context
65

76
import databricks.sql.cloudfetch.downloader as downloader
87
from databricks.sql.exc import Error
8+
from databricks.sql.types import SSLOptions
99

1010

1111
def create_response(**kwargs) -> requests.Response:
@@ -26,7 +26,7 @@ def test_run_link_expired(self, mock_time):
2626
result_link = Mock()
2727
# Already expired
2828
result_link.expiryTime = 999
29-
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context())
29+
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions())
3030

3131
with self.assertRaises(Error) as context:
3232
d.run()
@@ -40,7 +40,7 @@ def test_run_link_past_expiry_buffer(self, mock_time):
4040
result_link = Mock()
4141
# Within the expiry buffer time
4242
result_link.expiryTime = 1004
43-
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context())
43+
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions())
4444

4545
with self.assertRaises(Error) as context:
4646
d.run()
@@ -58,7 +58,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session):
5858
settings.use_proxy = False
5959
result_link = Mock(expiryTime=1001)
6060

61-
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context())
61+
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions())
6262
with self.assertRaises(requests.exceptions.HTTPError) as context:
6363
d.run()
6464
self.assertTrue('404' in str(context.exception))
@@ -73,7 +73,7 @@ def test_run_uncompressed_successful(self, mock_time, mock_session):
7373
settings.is_lz4_compressed = False
7474
result_link = Mock(bytesNum=100, expiryTime=1001)
7575

76-
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context())
76+
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions())
7777
file = d.run()
7878

7979
assert file.file_bytes == b"1234567890" * 10
@@ -89,7 +89,7 @@ def test_run_compressed_successful(self, mock_time, mock_session):
8989
settings.is_lz4_compressed = True
9090
result_link = Mock(bytesNum=100, expiryTime=1001)
9191

92-
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context())
92+
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions())
9393
file = d.run()
9494

9595
assert file.file_bytes == b"1234567890" * 10
@@ -102,7 +102,7 @@ def test_download_connection_error(self, mock_time, mock_session):
102102
mock_session.return_value.get.return_value.content = \
103103
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
104104

105-
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context())
105+
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions())
106106
with self.assertRaises(ConnectionError):
107107
d.run()
108108

@@ -114,6 +114,6 @@ def test_download_timeout(self, mock_time, mock_session):
114114
mock_session.return_value.get.return_value.content = \
115115
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
116116

117-
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context())
117+
d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions())
118118
with self.assertRaises(TimeoutError):
119119
d.run()

0 commit comments

Comments
 (0)