 import logging
 from dataclasses import dataclass
-
 import requests
 import lz4.frame
 import threading
 import time
-
+import os
+import re
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

 logger = logging.getLogger(__name__)

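+# Default download timeout in seconds, overridable via the DATABRICKS_CLOUD_FILE_TIMEOUT environment variable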
+DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60))
+

 @dataclass
 class DownloadableResultSettings:
@@ -20,13 +22,17 @@ class DownloadableResultSettings:
         is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
         link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
         download_timeout (int): Timeout for download requests. Default 60 secs.
-        max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
+        max_retries (int): Maximum number of consecutive download attempts before giving up. Default 5.
+        backoff_factor (int): Factor by which the wait time grows between retries. Default 2.
+
     """

     is_lz4_compressed: bool
     link_expiry_buffer_secs: int = 0
-    download_timeout: int = 60
-    max_consecutive_file_download_retries: int = 0
+    download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT
+    max_retries: int = 5
+    backoff_factor: int = 2


 class ResultSetDownloadHandler(threading.Thread):
@@ -57,16 +63,21 @@ def is_file_download_successful(self) -> bool:
             else None
         )
         try:
+            logger.debug(
+                f"waiting for at most {timeout} seconds for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
+
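+            # Block until the downloader thread sets the is_download_finished event, or until the timeout elapses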
             if not self.is_download_finished.wait(timeout=timeout):
                 self.is_download_timedout = True
-                logger.debug(
-                    "Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format(
-                        self.settings.download_timeout,
-                        self.result_link.startRowOffset,
-                        self.result_link.startRowOffset + self.result_link.rowCount,
-                    )
+                logger.error(
+                    f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}"
                 )
-                return False
+                # In some cases is_download_finished is never set even though the file downloaded successfully, so fall back to the success flag
+                return self.is_file_downloaded_successfully
+
+            logger.debug(
+                f"finish waiting for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
         except Exception as e:
             logger.error(e)
             return False
@@ -81,24 +92,36 @@ def run(self):
         """
         self._reset()

-        # Check if link is already expired or is expiring
-        if ResultSetDownloadHandler.check_link_expired(
-            self.result_link, self.settings.link_expiry_buffer_secs
-        ):
-            self.is_link_expired = True
-            return
+        try:
+            # Check if link is already expired or is expiring
+            if ResultSetDownloadHandler.check_link_expired(
+                self.result_link, self.settings.link_expiry_buffer_secs
+            ):
+                self.is_link_expired = True
+                return

-        session = requests.Session()
-        session.timeout = self.settings.download_timeout
+            logger.debug(
+                f"started to download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )

-        try:
             # Get the file via HTTP request
-            response = session.get(self.result_link.fileLink)
+            response = http_get_with_retry(
+                url=self.result_link.fileLink,
+                max_retries=self.settings.max_retries,
+                backoff_factor=self.settings.backoff_factor,
+                download_timeout=self.settings.download_timeout,
+            )

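+            # http_get_with_retry returns None once every retry attempt has failed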
-            if not response.ok:
-                self.is_file_downloaded_successfully = False
+            if not response:
+                logger.error(
+                    f"failed downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+                )
                 return

+            logger.debug(
+                f"success downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
+
             # Save (and decompress if needed) the downloaded file
             compressed_data = response.content
             decompressed_data = (
@@ -109,15 +132,22 @@ def run(self):
             self.result_file = decompressed_data

             # The size of the downloaded file should match the size specified from TSparkArrowResultLink
-            self.is_file_downloaded_successfully = (
-                len(self.result_file) == self.result_link.bytesNum
+            success = len(self.result_file) == self.result_link.bytesNum
+            logger.debug(
+                f"download successful file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
             )
+            self.is_file_downloaded_successfully = success
         except Exception as e:
+            logger.error(
+                f"exception downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
             logger.error(e)
             self.is_file_downloaded_successfully = False

         finally:
-            session and session.close()
+            logger.debug(
+                f"signal finished file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
             # Awaken threads waiting for this to be true which signals the run is complete
             self.is_download_finished.set()

@@ -145,6 +175,7 @@ def check_link_expired(
             link.expiryTime < current_time
             or link.expiryTime - current_time < expiry_buffer_secs
         ):
+            logger.debug("link expired")
             return True
         return False

@@ -171,3 +202,38 @@ def decompress_data(compressed_data: bytes) -> bytes:
             uncompressed_data += data
             start += num_bytes
         return uncompressed_data
+
+
+def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
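+    """GET the given url with exponential-backoff retries; returns the Response on HTTP 200, or None once retries are exhausted."""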
+    attempts = 0
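+    # Redact query-string values (the pre-signed URL credentials) from anything we log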
+    pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)")
+    mask = r"\1\2=<REDACTED>"
+
+    # TODO: introduce connection pooling. I am seeing weird errors without it.
+    while attempts < max_retries:
+        session = requests.Session()
+        try:
+            # Pass the timeout per request; requests does not honor a timeout attribute set on the Session
+            response = session.get(url, timeout=download_timeout)
+
+            # Treat only an HTTP 200 as success
+            if response.status_code == 200:
+                return response
+            else:
+                logger.error(response)
+        except requests.RequestException as e:
+            # if this is not redacted, it will print the pre-signed URL
+            logger.error(f"request failed with exception: {re.sub(pattern, mask, str(e))}")
+        finally:
+            session.close()
+
+        attempts += 1
+        if attempts < max_retries:
+            # Exponential backoff before the next attempt
+            wait_time = backoff_factor ** attempts
+            logger.info(f"retrying in {wait_time} seconds...")
+            time.sleep(wait_time)
+
+    logger.error(
+        f"exceeded maximum number of retries ({max_retries}) while downloading result."
+    )
+    return None