Skip to content

Commit d2ffe1e

Browse files
committed
Cloud Fetch download handler
Signed-off-by: Matthew Kim <[email protected]>
1 parent 5379803 commit d2ffe1e

File tree

2 files changed

+334
-0
lines changed

2 files changed

+334
-0
lines changed
+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import logging
2+
3+
import requests
4+
import lz4.frame
5+
import threading
6+
import time
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
class ResultSetDownloadHandler(threading.Thread):
12+
13+
def __init__(self, downloadable_result_settings, t_spark_arrow_result_link):
14+
super().__init__()
15+
self.settings = downloadable_result_settings
16+
self.result_link = t_spark_arrow_result_link
17+
self.is_download_scheduled = False
18+
self.is_download_finished = threading.Event()
19+
self.is_file_downloaded_successfully = False
20+
self.is_link_expired = False
21+
self.is_download_timedout = False
22+
self.http_code = None
23+
self.result_file = None
24+
self.check_result_file_link_expiry = True
25+
self.download_completion_semaphore = threading.Semaphore(0)
26+
27+
def is_file_download_successful(self):
28+
try:
29+
if not self.is_download_finished.is_set():
30+
if self.settings.download_timeout and self.settings.download_timeout > 0:
31+
if not self.download_completion_semaphore.acquire(timeout=self.settings.download_timeout):
32+
self.is_download_timedout = True
33+
logger.debug("Cloud fetch download timed out after {} seconds for url: {}"
34+
.format(self.settings.download_timeout, self.result_link.file_link)
35+
)
36+
return False
37+
else:
38+
self.download_completion_semaphore.acquire()
39+
except:
40+
return False
41+
return self.is_file_downloaded_successfully
42+
43+
def run(self):
44+
self.is_file_downloaded_successfully = False
45+
self.is_link_expired = False
46+
self.is_download_timedout = False
47+
self.is_download_finished = threading.Event()
48+
49+
if self.check_result_file_link_expiry:
50+
current_time = int(time.time() * 1000)
51+
if (self.result_link.expiryTime < current_time) or (
52+
self.result_link.expiryTime - current_time < (
53+
self.settings.result_file_link_expiry_buffer * 1000)
54+
):
55+
self.is_link_expired = True
56+
return
57+
58+
session = requests.Session()
59+
session.timeout = self.settings.download_timeout
60+
61+
if (
62+
self.settings.use_proxy
63+
and not self.settings.disable_proxy_for_cloud_fetch
64+
):
65+
proxy = {
66+
"http": f"http://{self.settings.proxy_host}:{self.settings.proxy_port}",
67+
"https": f"https://{self.settings.proxy_host}:{self.settings.proxy_port}",
68+
}
69+
session.proxies.update(proxy)
70+
71+
# ProxyAuthentication -> static enum BASIC and NONE
72+
if self.settings.proxy_auth == "BASIC":
73+
session.auth = requests.auth.HTTPBasicAuth(self.settings.proxy_uid, self.settings.proxy_pwd)
74+
75+
try:
76+
response = session.get(self.result_link.fileLink)
77+
self.http_code = response.status_code
78+
79+
if self.http_code != 200:
80+
self.is_file_downloaded_successfully = False
81+
else:
82+
if self.settings.is_lz4_compressed:
83+
compressed_data = response.content
84+
uncompressed_data = lz4.frame.decompress(compressed_data)
85+
self.result_file = uncompressed_data
86+
87+
if len(uncompressed_data) != self.result_link.bytesNum:
88+
self.is_file_downloaded_successfully = False
89+
else:
90+
self.is_file_downloaded_successfully = True
91+
92+
else:
93+
self.result_file = response.content
94+
if len(self.result_file) != self.result_link.bytesNum:
95+
self.is_file_downloaded_successfully = False
96+
else:
97+
self.is_file_downloaded_successfully = True
98+
except:
99+
self.is_file_downloaded_successfully = False
100+
101+
finally:
102+
session.close()
103+
self.is_download_finished.set()
104+
self.download_completion_semaphore.release()

tests/unit/test_downloader.py

+230
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
import unittest
2+
from unittest.mock import Mock, patch, MagicMock
3+
4+
import databricks.sql.cloudfetch.downloader as downloader
5+
6+
7+
class DownloaderTests(unittest.TestCase):
8+
"""
9+
Unit tests for checking downloader logic.
10+
"""
11+
12+
@patch('time.time', return_value=1000)
13+
def test_run_link_expired(self, mock_time):
14+
settings = Mock()
15+
result_link = Mock()
16+
# Already expired
17+
result_link.expiryTime = 999999
18+
d = downloader.ResultSetDownloadHandler(settings, result_link)
19+
assert not d.is_link_expired
20+
d.run()
21+
assert d.is_link_expired
22+
mock_time.assert_called_once()
23+
24+
@patch('time.time', return_value=1000)
25+
def test_run_link_past_expiry_buffer(self, mock_time):
26+
settings = Mock()
27+
settings.result_file_link_expiry_buffer = 0.005
28+
result_link = Mock()
29+
# Within the expiry buffer time
30+
result_link.expiryTime = 1000004
31+
d = downloader.ResultSetDownloadHandler(settings, result_link)
32+
assert not d.is_link_expired
33+
d.run()
34+
assert d.is_link_expired
35+
mock_time.assert_called_once()
36+
37+
@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(status_code=500))))
38+
@patch('time.time', return_value=1000)
39+
def test_run_get_response_not_200(self, mock_time, mock_session):
40+
settings = Mock()
41+
settings.result_file_link_expiry_buffer = 0
42+
settings.download_timeout = 0
43+
settings.use_proxy = False
44+
result_link = Mock()
45+
result_link.expiryTime = 1000001
46+
47+
d = downloader.ResultSetDownloadHandler(settings, result_link)
48+
d.run()
49+
50+
assert not d.is_file_downloaded_successfully
51+
assert d.is_download_finished.is_set()
52+
53+
@patch('requests.Session',
54+
return_value=MagicMock(get=MagicMock(return_value=MagicMock(status_code=200, content=b"1234567890" * 9))))
55+
@patch('time.time', return_value=1000)
56+
def test_run_uncompressed_data_length_incorrect(self, mock_time, mock_session):
57+
settings = Mock()
58+
settings.result_file_link_expiry_buffer = 0
59+
settings.download_timeout = 0
60+
settings.use_proxy = False
61+
settings.is_lz4_compressed = False
62+
result_link = Mock()
63+
result_link.bytesNum = 100
64+
result_link.expiryTime = 1000001
65+
66+
d = downloader.ResultSetDownloadHandler(settings, result_link)
67+
d.run()
68+
69+
assert not d.is_file_downloaded_successfully
70+
assert d.is_download_finished.is_set()
71+
72+
@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(status_code=200))))
73+
@patch('time.time', return_value=1000)
74+
def test_run_compressed_data_length_incorrect(self, mock_time, mock_session):
75+
settings = Mock()
76+
settings.result_file_link_expiry_buffer = 0
77+
settings.download_timeout = 0
78+
settings.use_proxy = False
79+
settings.is_lz4_compressed = True
80+
result_link = Mock()
81+
result_link.bytesNum = 100
82+
result_link.expiryTime = 1000001
83+
mock_session.return_value.get.return_value.content = \
84+
b'\x04"M\x18h@Z\x00\x00\x00\x00\x00\x00\x00\xec\x14\x00\x00\x00\xaf1234567890\n\x008P67890\x00\x00\x00\x00'
85+
86+
d = downloader.ResultSetDownloadHandler(settings, result_link)
87+
d.run()
88+
89+
assert not d.is_file_downloaded_successfully
90+
assert d.is_download_finished.is_set()
91+
92+
@patch('requests.Session',
93+
return_value=MagicMock(get=MagicMock(return_value=MagicMock(status_code=200, content=b"1234567890" * 10))))
94+
@patch('time.time', return_value=1000)
95+
def test_run_uncompressed_successful(self, mock_time, mock_session):
96+
settings = Mock()
97+
settings.result_file_link_expiry_buffer = 0
98+
settings.download_timeout = 0
99+
settings.use_proxy = False
100+
settings.is_lz4_compressed = False
101+
result_link = Mock()
102+
result_link.bytesNum = 100
103+
result_link.expiryTime = 1000001
104+
105+
d = downloader.ResultSetDownloadHandler(settings, result_link)
106+
d.run()
107+
108+
assert d.result_file == b"1234567890" * 10
109+
assert d.is_file_downloaded_successfully
110+
assert d.is_download_finished.is_set()
111+
112+
@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(status_code=200))))
113+
@patch('time.time', return_value=1000)
114+
def test_run_compressed_successful(self, mock_time, mock_session):
115+
settings = Mock()
116+
settings.result_file_link_expiry_buffer = 0
117+
settings.download_timeout = 0
118+
settings.use_proxy = False
119+
settings.is_lz4_compressed = True
120+
result_link = Mock()
121+
result_link.bytesNum = 100
122+
result_link.expiryTime = 1000001
123+
mock_session.return_value.get.return_value.content = \
124+
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
125+
126+
d = downloader.ResultSetDownloadHandler(settings, result_link)
127+
d.run()
128+
129+
assert d.result_file == b"1234567890" * 10
130+
assert d.is_file_downloaded_successfully
131+
assert d.is_download_finished.is_set()
132+
133+
@patch('requests.Session.get', side_effect=ConnectionError('foo'))
134+
@patch('time.time', return_value=1000)
135+
def test_download_connection_error(self, mock_time, mock_session):
136+
settings = Mock()
137+
settings.result_file_link_expiry_buffer = 0
138+
settings.use_proxy = False
139+
settings.is_lz4_compressed = True
140+
result_link = Mock()
141+
result_link.bytesNum = 100
142+
result_link.expiryTime = 1000001
143+
mock_session.return_value.get.return_value.content = \
144+
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
145+
146+
d = downloader.ResultSetDownloadHandler(settings, result_link)
147+
d.run()
148+
149+
assert not d.is_file_downloaded_successfully
150+
assert d.is_download_finished.is_set()
151+
152+
@patch('requests.Session.get', side_effect=TimeoutError('foo'))
153+
@patch('time.time', return_value=1000)
154+
def test_download_timeout(self, mock_time, mock_session):
155+
settings = Mock()
156+
settings.result_file_link_expiry_buffer = 0
157+
settings.use_proxy = False
158+
settings.is_lz4_compressed = True
159+
result_link = Mock()
160+
result_link.bytesNum = 100
161+
result_link.expiryTime = 1000001
162+
mock_session.return_value.get.return_value.content = \
163+
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
164+
165+
d = downloader.ResultSetDownloadHandler(settings, result_link)
166+
d.run()
167+
168+
assert not d.is_file_downloaded_successfully
169+
assert d.is_download_finished.is_set()
170+
171+
@patch("threading.Event.is_set", return_value=True)
172+
def test_is_download_successful_with_event_set(self, mock_is_set):
173+
settings = Mock()
174+
result_link = Mock()
175+
handler = downloader.ResultSetDownloadHandler(settings, result_link)
176+
177+
status = handler.is_file_download_successful()
178+
assert status == handler.is_file_downloaded_successfully
179+
180+
@patch("threading.Event.is_set", return_value=False)
181+
@patch("threading.Semaphore.acquire", return_value=True)
182+
def test_is_file_download_successful_null_timeout_download_completes(self, mock_acquire, mock_is_set):
183+
settings = Mock()
184+
settings.download_timeout = None
185+
result_link = Mock()
186+
handler = downloader.ResultSetDownloadHandler(settings, result_link)
187+
188+
status = handler.is_file_download_successful()
189+
assert status == handler.is_file_downloaded_successfully
190+
mock_acquire.assert_called()
191+
192+
@patch("threading.Event.is_set", return_value=False)
193+
@patch("threading.Semaphore.acquire", return_value=True)
194+
def test_is_file_download_successful_zero_timeout_download_completes(self, mock_acquire, mock_is_set):
195+
settings = Mock()
196+
settings.download_timeout = 0
197+
result_link = Mock()
198+
handler = downloader.ResultSetDownloadHandler(settings, result_link)
199+
200+
status = handler.is_file_download_successful()
201+
assert status == handler.is_file_downloaded_successfully
202+
mock_acquire.assert_called()
203+
assert not handler.is_download_timedout
204+
205+
@patch("threading.Event.is_set", return_value=False)
206+
@patch("threading.Semaphore.acquire", return_value=True)
207+
def test_is_file_download_successful_with_timeout_download_completes(self, mock_acquire, mock_is_set):
208+
settings = Mock()
209+
settings.download_timeout = 10
210+
result_link = Mock()
211+
handler = downloader.ResultSetDownloadHandler(settings, result_link)
212+
213+
status = handler.is_file_download_successful()
214+
assert status == handler.is_file_downloaded_successfully
215+
mock_acquire.assert_called_with(timeout=settings.download_timeout)
216+
assert not handler.is_download_timedout
217+
218+
@patch("threading.Event.is_set", return_value=False)
219+
@patch("threading.Semaphore.acquire", return_value=False)
220+
def test_is_file_download_successful_download_times_out(self, mock_acquire, mock_is_set):
221+
settings = Mock()
222+
settings.download_timeout = 10
223+
result_link = Mock()
224+
result_link.fileLink = "foo"
225+
handler = downloader.ResultSetDownloadHandler(settings, result_link)
226+
227+
status = handler.is_file_download_successful()
228+
assert not status
229+
mock_acquire.assert_called_with(timeout=settings.download_timeout)
230+
assert handler.is_download_timedout

0 commit comments

Comments
 (0)