1
1
import logging
2
2
3
- from concurrent .futures import ThreadPoolExecutor
4
- from dataclasses import dataclass
3
+ from concurrent .futures import ThreadPoolExecutor , Future
5
4
from typing import List , Union
6
5
7
6
from databricks .sql .cloudfetch .downloader import (
8
7
ResultSetDownloadHandler ,
9
8
DownloadableResultSettings ,
9
+ DownloadedFile ,
10
10
)
11
11
from databricks .sql .thrift_api .TCLIService .ttypes import TSparkArrowResultLink
12
12
13
13
logger = logging .getLogger (__name__ )
14
14
15
15
16
- @dataclass
17
- class DownloadedFile :
18
- """
19
- Class for the result file and metadata.
20
-
21
- Attributes:
22
- file_bytes (bytes): Downloaded file in bytes.
23
- start_row_offset (int): The offset of the starting row in relation to the full result.
24
- row_count (int): Number of rows the file represents in the result.
25
- """
26
-
27
- file_bytes : bytes
28
- start_row_offset : int
29
- row_count : int
30
-
31
-
32
16
class ResultFileDownloadManager :
33
- def __init__ (self , max_download_threads : int , lz4_compressed : bool ):
34
- self .download_handlers : List [ResultSetDownloadHandler ] = []
35
- self .thread_pool = ThreadPoolExecutor (max_workers = max_download_threads + 1 )
36
- self .downloadable_result_settings = DownloadableResultSettings (lz4_compressed )
37
- self .fetch_need_retry = False
38
- self .num_consecutive_result_file_download_retries = 0
39
-
40
- def add_file_links (
41
- self , t_spark_arrow_result_links : List [TSparkArrowResultLink ]
42
- ) -> None :
43
- """
44
- Create download handler for each cloud fetch link.
45
-
46
- Args:
47
- t_spark_arrow_result_links: List of cloud fetch links consisting of file URL and metadata.
48
- """
49
- for link in t_spark_arrow_result_links :
17
def __init__(
    self,
    links: List[TSparkArrowResultLink],
    max_download_threads: int,
    lz4_compressed: bool,
):
    """
    Manage pre-fetching of cloud fetch result files via a thread pool.

    Args:
        links: Cloud fetch links (file URL plus metadata) for this batch of results.
        max_download_threads: Requested download parallelism; the pool is sized
            one larger (see note below).
        lz4_compressed: Whether the result files are LZ4-compressed.
    """
    # Keep only links that actually carry rows; zero-row files need no download.
    self._pending_links: List[TSparkArrowResultLink] = []
    for link in links:
        if link.rowCount <= 0:
            continue
        # Fixed: message previously referenced `add_file_links`, a method that
        # no longer exists — links are now filtered here in the constructor.
        logger.debug(
            "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format(
                link.startRowOffset, link.rowCount
            )
        )
        self._pending_links.append(link)

    # FIFO queue of in-flight download futures, oldest first.
    self._download_tasks: List[Future[DownloadedFile]] = []
    # NOTE(review): +1 keeps one worker beyond the requested thread count,
    # preserved from the previous implementation — confirm this is intentional.
    self._max_download_threads: int = max_download_threads + 1
    self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)

    self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
60
39
61
40
def get_next_downloaded_file(
    self, next_row_offset: int
) -> Union[DownloadedFile, None]:
    """
    Return the next downloaded file in order, or None when the batch is exhausted.

    Tops up the download queue, then blocks on the oldest pending download
    task. When nothing is queued and no links remain, the manager shuts
    itself down and returns None.

    Args:
        next_row_offset (int): The offset of the starting row of the next file we want data from.
    """

    # Make sure the download queue is always full
    self._schedule_downloads()

    # No more files to download from this batch of links
    if len(self._download_tasks) == 0:
        self._shutdown_manager()
        return None

    task = self._download_tasks.pop(0)
    # Future's `result()` method will wait for the call to complete, and return
    # the value returned by the call. If the call throws an exception - `result()`
    # will throw the same exception
    file = task.result()
    # A file covers the half-open row range [start_row_offset,
    # start_row_offset + row_count): a row exactly at start + count belongs to
    # the NEXT file, so the upper bound must be `>=` (was `>`, an off-by-one
    # that missed the boundary row). This check only logs; the file is
    # returned either way.
    if (next_row_offset < file.start_row_offset) or (
        next_row_offset >= file.start_row_offset + file.row_count
    ):
        logger.debug(
            "ResultFileDownloadManager: file does not contain row {}, start {}, row count {}".format(
                next_row_offset, file.start_row_offset, file.row_count
            )
        )

    return file
141
79
142
80
def _schedule_downloads(self):
    """
    Top up the download queue: submit pending links to the thread pool until
    the queue reaches capacity or no pending links remain.
    """
    logger.debug("ResultFileDownloadManager: schedule downloads")
    capacity = self._max_download_threads
    while self._pending_links and len(self._download_tasks) < capacity:
        next_link = self._pending_links.pop(0)
        logger.debug(
            "- start: {}, row count: {}".format(
                next_link.startRowOffset, next_link.rowCount
            )
        )
        download_handler = ResultSetDownloadHandler(
            self._downloadable_result_settings, next_link
        )
        self._download_tasks.append(self._thread_pool.submit(download_handler.run))
211
95
212
96
def _shutdown_manager (self ):
213
97
# Clear download handlers and shutdown the thread pool
214
- self .download_handlers = []
215
- self .thread_pool .shutdown (wait = False )
98
+ self ._pending_links = []
99
+ self ._download_tasks = []
100
+ self ._thread_pool .shutdown (wait = False )
0 commit comments