
Commit 78838a0

mattdeekay authored and Jesse Whitehouse committed
Cloud Fetch download manager (databricks#146)
* Cloud Fetch download manager
* Bug fix: submit handler.run
* Type annotations
* Namedtuple -> dataclass
* Shutdown thread pool and clear handlers
* Docstrings and comments
* handler.run is the correct call
* Link expiry buffer in secs
* Adding type annotations for download_handlers and downloadable_result_settings
* Move DownloadableResultSettings to downloader.py to avoid circular import
* Black linting
* Timeout is never None

Signed-off-by: Matthew Kim <[email protected]>
1 parent 794be47 commit 78838a0

File tree

4 files changed: +399 -4 lines changed

@@ -0,0 +1,166 @@
import logging

from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import List, Union

from databricks.sql.cloudfetch.downloader import (
    ResultSetDownloadHandler,
    DownloadableResultSettings,
)
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)


@dataclass
class DownloadedFile:
    """
    Class for the result file and metadata.

    Attributes:
        file_bytes (bytes): Downloaded file in bytes.
        start_row_offset (int): The offset of the starting row in relation to the full result.
        row_count (int): Number of rows the file represents in the result.
    """

    file_bytes: bytes
    start_row_offset: int
    row_count: int


class ResultFileDownloadManager:
    def __init__(self, max_download_threads: int, lz4_compressed: bool):
        self.download_handlers: List[ResultSetDownloadHandler] = []
        self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1)
        self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
        self.fetch_need_retry = False
        self.num_consecutive_result_file_download_retries = 0

    def add_file_links(
        self, t_spark_arrow_result_links: List[TSparkArrowResultLink]
    ) -> None:
        """
        Create a download handler for each cloud fetch link.

        Args:
            t_spark_arrow_result_links: List of cloud fetch links consisting of file URL and metadata.
        """
        for link in t_spark_arrow_result_links:
            if link.rowCount <= 0:
                continue
            self.download_handlers.append(
                ResultSetDownloadHandler(self.downloadable_result_settings, link)
            )

    def get_next_downloaded_file(
        self, next_row_offset: int
    ) -> Union[DownloadedFile, None]:
        """
        Get the next file that starts at the given offset.

        This function gets the next downloaded file whose rows start at the specified next_row_offset
        in relation to the full result. File downloads are scheduled if they have not been already, and
        once the correct download handler is located, the function waits for the download status and
        returns the resulting file. If there are no more downloads, a download was not successful, or
        the correct file could not be located, this function shuts down the thread pool and returns None.

        Args:
            next_row_offset (int): The offset of the starting row of the next file we want data from.
        """
        # No more files to download from this batch of links
        if not self.download_handlers:
            self._shutdown_manager()
            return None

        # Remove handlers we don't need anymore
        self._remove_past_handlers(next_row_offset)

        # Schedule the downloads
        self._schedule_downloads()

        # Find the next file
        idx = self._find_next_file_index(next_row_offset)
        if idx is None:
            self._shutdown_manager()
            return None
        handler = self.download_handlers[idx]

        # Check (and wait) for download status
        if self._check_if_download_successful(handler):
            # Buffer should be empty, so set buffer to a new ArrowQueue with result_file
            result = DownloadedFile(
                handler.result_file,
                handler.result_link.startRowOffset,
                handler.result_link.rowCount,
            )
            self.download_handlers.pop(idx)
            # Return the downloaded file to continue the fetch loop and not force a retry
            return result
        # Download was not successful for the next download item; force a retry
        self._shutdown_manager()
        return None

    def _remove_past_handlers(self, next_row_offset: int):
        # Any link whose start-to-end range doesn't include the next row to be fetched does not need downloading
        i = 0
        while i < len(self.download_handlers):
            result_link = self.download_handlers[i].result_link
            if result_link.startRowOffset + result_link.rowCount > next_row_offset:
                i += 1
                continue
            self.download_handlers.pop(i)

    def _schedule_downloads(self):
        # Schedule downloads for all download handlers if not already scheduled.
        for handler in self.download_handlers:
            if handler.is_download_scheduled:
                continue
            try:
                self.thread_pool.submit(handler.run)
            except Exception as e:
                logger.error(e)
                break
            handler.is_download_scheduled = True

    def _find_next_file_index(self, next_row_offset: int):
        # Get the handler index of the next file in order
        next_indices = [
            i
            for i, handler in enumerate(self.download_handlers)
            if handler.is_download_scheduled
            and handler.result_link.startRowOffset == next_row_offset
        ]
        return next_indices[0] if len(next_indices) > 0 else None

    def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
        # Check (and wait until the download finishes) whether the download was successful
        if not handler.is_file_download_successful():
            if handler.is_link_expired:
                self.fetch_need_retry = True
                return False
            elif handler.is_download_timedout:
                # Consecutive file retries should not exceed the threshold in settings
                if (
                    self.num_consecutive_result_file_download_retries
                    >= self.downloadable_result_settings.max_consecutive_file_download_retries
                ):
                    self.fetch_need_retry = True
                    return False
                self.num_consecutive_result_file_download_retries += 1

                # Re-submit handler.run to the thread pool and recursively check download status
                self.thread_pool.submit(handler.run)
                return self._check_if_download_successful(handler)
            else:
                self.fetch_need_retry = True
                return False

        self.num_consecutive_result_file_download_retries = 0
        self.fetch_need_retry = False
        return True

    def _shutdown_manager(self):
        # Clear download handlers and shut down the thread pool to cancel pending futures
        self.download_handlers = []
        self.thread_pool.shutdown(wait=False, cancel_futures=True)
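
For orientation, here is a minimal, hypothetical sketch of how a caller could drive the manager above. It is not part of this commit: the real integration point (the Thrift fetch path) is outside this diff, the import path databricks.sql.cloudfetch.download_manager is an assumption about where the new file lives, and result_links stands in for the list of TSparkArrowResultLink objects returned by the server.

# Hypothetical driver loop, for illustration only (not part of this commit).
# Assumes the new file lives at databricks/sql/cloudfetch/download_manager.py
# and that `result_links` is a list of TSparkArrowResultLink objects.
from typing import List, Tuple

from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink


def collect_cloud_fetch_files(
    result_links: List[TSparkArrowResultLink],
    lz4_compressed: bool = True,
    max_download_threads: int = 10,
) -> Tuple[List[bytes], bool]:
    manager = ResultFileDownloadManager(max_download_threads, lz4_compressed)
    manager.add_file_links(result_links)

    files: List[bytes] = []
    row_offset = 0  # assumes the first link starts at row offset 0
    while True:
        downloaded_file = manager.get_next_downloaded_file(row_offset)
        if downloaded_file is None:
            # Either every link was consumed or a download failed;
            # fetch_need_retry tells the caller whether to re-request links.
            break
        files.append(downloaded_file.file_bytes)
        row_offset = downloaded_file.start_row_offset + downloaded_file.row_count
    return files, manager.fetch_need_retry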

src/databricks/sql/cloudfetch/downloader.py

+25 -3 lines changed
@@ -1,4 +1,5 @@
 import logging
+from dataclasses import dataclass
 
 import requests
 import lz4.frame
@@ -10,10 +11,28 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class DownloadableResultSettings:
+    """
+    Class for settings common to each download handler.
+
+    Attributes:
+        is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
+        link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
+        download_timeout (int): Timeout for download requests. Default 60 secs.
+        max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
+    """
+
+    is_lz4_compressed: bool
+    link_expiry_buffer_secs: int = 0
+    download_timeout: int = 60
+    max_consecutive_file_download_retries: int = 0
+
+
 class ResultSetDownloadHandler(threading.Thread):
     def __init__(
         self,
-        downloadable_result_settings,
+        downloadable_result_settings: DownloadableResultSettings,
         t_spark_arrow_result_link: TSparkArrowResultLink,
     ):
         super().__init__()
@@ -32,8 +51,11 @@ def is_file_download_successful(self) -> bool:
 
         This function will block until a file download finishes or until a timeout.
         """
-        timeout = self.settings.download_timeout
-        timeout = timeout if timeout and timeout > 0 else None
+        timeout = (
+            self.settings.download_timeout
+            if self.settings.download_timeout > 0
+            else None
+        )
         try:
             if not self.is_download_finished.wait(timeout=timeout):
                 self.is_download_timedout = True
0 commit comments
