1
1
import logging
2
2
from dataclasses import dataclass
3
-
3
+ from datetime import datetime
4
4
import requests
5
5
import lz4 .frame
6
6
import threading
7
7
import time
8
-
8
+ import os
9
+ from threading import get_ident
9
10
from databricks .sql .thrift_api .TCLIService .ttypes import TSparkArrowResultLink
10
11
11
12
logger = logging.getLogger(__name__)

# Default per-file download timeout (seconds); overridable via the
# DATABRICKS_CLOUD_FILE_TIMEOUT environment variable.
DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60))


@dataclass
class DownloadableResultSettings:
    """
    Settings for a downloadable cloud-fetch result file.

    Attributes:
        is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
        link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
        download_timeout (int): Timeout for download requests. Default 60 secs.
        max_retries (int): Number of consecutive download retries before shutting down.
        backoff_factor (int): Factor to increase wait time between retries.
    """

    is_lz4_compressed: bool
    link_expiry_buffer_secs: int = 0
    download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT
    max_retries: int = 5
    backoff_factor: int = 2
30
37
31
38
32
39
class ResultSetDownloadHandler (threading .Thread ):
@@ -57,16 +64,21 @@ def is_file_download_successful(self) -> bool:
57
64
else None
58
65
)
59
66
try :
67
+ logger .debug (
68
+ f"waiting for at most { timeout } seconds for download file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
69
+ )
70
+
60
71
if not self .is_download_finished .wait (timeout = timeout ):
61
72
self .is_download_timedout = True
62
73
logger .debug (
63
- "Cloud fetch download timed out after {} seconds for link representing rows {} to {}" .format (
64
- self .settings .download_timeout ,
65
- self .result_link .startRowOffset ,
66
- self .result_link .startRowOffset + self .result_link .rowCount ,
67
- )
74
+ f"cloud fetch download timed out after { self .settings .download_timeout } seconds for link representing rows { self .result_link .startRowOffset } to { self .result_link .startRowOffset + self .result_link .rowCount } "
68
75
)
69
- return False
76
+ # there are some weird cases when the is_download_finished is not set, but the file is downloaded successfully
77
+ return self .is_file_downloaded_successfully
78
+
79
+ logger .debug (
80
+ f"finish waiting for download file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
81
+ )
70
82
except Exception as e :
71
83
logger .error (e )
72
84
return False
@@ -81,24 +93,36 @@ def run(self):
81
93
"""
82
94
self ._reset ()
83
95
84
- # Check if link is already expired or is expiring
85
- if ResultSetDownloadHandler .check_link_expired (
86
- self .result_link , self .settings .link_expiry_buffer_secs
87
- ):
88
- self .is_link_expired = True
89
- return
96
+ try :
97
+ # Check if link is already expired or is expiring
98
+ if ResultSetDownloadHandler .check_link_expired (
99
+ self .result_link , self .settings .link_expiry_buffer_secs
100
+ ):
101
+ self .is_link_expired = True
102
+ return
90
103
91
- session = requests .Session ()
92
- session .timeout = self .settings .download_timeout
104
+ logger .debug (
105
+ f"started to download file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
106
+ )
93
107
94
- try :
95
108
# Get the file via HTTP request
96
- response = session .get (self .result_link .fileLink )
109
+ response = http_get_with_retry (
110
+ url = self .result_link .fileLink ,
111
+ max_retries = self .settings .max_retries ,
112
+ backoff_factor = self .settings .backoff_factor ,
113
+ download_timeout = self .settings .download_timeout ,
114
+ )
97
115
98
- if not response .ok :
99
- self .is_file_downloaded_successfully = False
116
+ if not response :
117
+ logger .error (
118
+ f"failed downloading file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
119
+ )
100
120
return
101
121
122
+ logger .debug (
123
+ f"success downloading file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
124
+ )
125
+
102
126
# Save (and decompress if needed) the downloaded file
103
127
compressed_data = response .content
104
128
decompressed_data = (
@@ -109,15 +133,22 @@ def run(self):
109
133
self .result_file = decompressed_data
110
134
111
135
# The size of the downloaded file should match the size specified from TSparkArrowResultLink
112
- self .is_file_downloaded_successfully = (
113
- len (self .result_file ) == self .result_link .bytesNum
136
+ success = len (self .result_file ) == self .result_link .bytesNum
137
+ logger .debug (
138
+ f"download successful file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
114
139
)
140
+ self .is_file_downloaded_successfully = success
115
141
except Exception as e :
142
+ logger .debug (
143
+ f"exception downloading file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
144
+ )
116
145
logger .error (e )
117
146
self .is_file_downloaded_successfully = False
118
147
119
148
finally :
120
- session and session .close ()
149
+ logger .debug (
150
+ f"signal finished file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
151
+ )
121
152
# Awaken threads waiting for this to be true which signals the run is complete
122
153
self .is_download_finished .set ()
123
154
@@ -145,6 +176,7 @@ def check_link_expired(
145
176
link .expiryTime < current_time
146
177
or link .expiryTime - current_time < expiry_buffer_secs
147
178
):
179
+ logger .debug ("link expired" )
148
180
return True
149
181
return False
150
182
@@ -171,3 +203,34 @@ def decompress_data(compressed_data: bytes) -> bytes:
171
203
uncompressed_data += data
172
204
start += num_bytes
173
205
return uncompressed_data
206
+
207
+
208
def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
    """GET *url*, retrying with exponential backoff on failure.

    Args:
        url: The URL to download.
        max_retries: Maximum number of attempts before giving up. Default 5.
        backoff_factor: Base of the exponential wait between attempts
            (waits ``backoff_factor ** attempt`` seconds). Default 2.
        download_timeout: Per-request timeout in seconds. Default 60.

    Returns:
        The successful ``requests.Response`` (2xx status), or ``None`` if
        every attempt failed.
    """
    for attempt in range(max_retries):
        try:
            # Fresh session per attempt; the context manager guarantees it
            # is closed even if the request raises.
            # NOTE: the timeout must be passed per-request — assigning
            # `session.timeout = ...` is silently ignored by requests, which
            # would leave the download with no timeout at all.
            with requests.Session() as session:
                response = session.get(url, timeout=download_timeout)

            # Any 2xx status counts as success (not just exactly 200).
            if 200 <= response.status_code < 300:
                return response
            logger.error(response)
        except requests.RequestException as e:
            logger.error(f"request failed with exception: {e}")

        # Exponential backoff before the next attempt. This runs only on
        # failure: the original slept even after a successful download
        # because the sleep lived in a `finally` block that also executed
        # on the `return response` path.
        wait_time = backoff_factor ** attempt
        logger.info(f"retrying in {wait_time} seconds...")
        time.sleep(wait_time)

    logger.error(
        f"exceeded maximum number of retries ({max_retries}) while downloading result."
    )
    return None
0 commit comments