Skip to content

Commit 6a348ec

Browse files
fixes for cloud fetch - part un (#356)
* fixes for cloud fetch Signed-off-by: Andre Furlan <[email protected]> --------- Signed-off-by: Andre Furlan <[email protected]> Co-authored-by: Raymond Cypher <[email protected]>
1 parent a737ef3 commit 6a348ec

File tree

9 files changed

+229
-143
lines changed

9 files changed

+229
-143
lines changed

examples/custom_logger.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import logging
import os

from databricks import sql

# Route the connector's "databricks.sql" logger to a local file at DEBUG level
# so cloud-fetch activity can be inspected.
logger = logging.getLogger("databricks.sql")
logger.setLevel(logging.DEBUG)

file_handler = logging.FileHandler("pysqllogs.log")
file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(process)d %(thread)d %(message)s"))
file_handler.setLevel(logging.DEBUG)
logger.addHandler(file_handler)

# Connection parameters come from the environment; cloud fetch is enabled with
# a deliberately small download-thread pool for the demo.
with sql.connect(
    server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
    http_path=os.getenv("DATABRICKS_HTTP_PATH"),
    access_token=os.getenv("DATABRICKS_TOKEN"),
    use_cloud_fetch=True,
    max_download_threads=2,
) as connection:

    with connection.cursor(arraysize=1000, buffer_size_bytes=54857600) as cursor:
        print(
            "executing query: SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2"
        )
        cursor.execute("SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2")
        try:
            # Drain the result set one row at a time; fetchone() returns None
            # once the result set is exhausted.
            while True:
                row = cursor.fetchone()
                if row is None:
                    break
                print(f"row: {row}")
        except sql.exc.ResultSetDownloadError as e:
            # Raised by the connector when a cloud-fetch download fails.
            print(f"error: {e}")

src/databricks/sql/__init__.py

+32
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,38 @@
77
threadsafety = 1 # Threads may share the module, but not connections.
88
paramstyle = "pyformat" # Python extended format codes, e.g. ...WHERE name=%(name)s
99

10+
import re
11+
12+
13+
class RedactUrlQueryParamsFilter(logging.Filter):
    """Logging filter that masks query-parameter values in log records.

    Cloud-fetch file links are pre-signed URLs whose query parameters carry
    credentials, so both the record message and any string formatting args
    have every ``key=value`` pair's value replaced with ``<REDACTED>``
    before the record is emitted. The record is sanitized, never dropped:
    :meth:`filter` always returns ``True``.
    """

    # Matches "?key=value" / "&key=value"; the value (group 3) is what gets hidden.
    pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)")
    mask = r"\1\2=<REDACTED>"

    def __init__(self):
        super().__init__()

    def redact(self, string):
        """Return *string* with every query-parameter value replaced."""
        return re.sub(self.pattern, self.mask, str(string))

    def filter(self, record):
        """Sanitize ``record.msg`` and ``record.args`` in place; never drop the record."""
        record.msg = self.redact(str(record.msg))
        if isinstance(record.args, dict):
            # Bug fix: the original read ``record.arg[k]`` — the LogRecord
            # attribute is ``record.args``, so every dict-args record raised
            # AttributeError.
            record.args = {
                k: self.redact(v) if isinstance(v, str) else v
                for k, v in record.args.items()
            }
        elif record.args is not None:
            # Bug fix: ``record.args`` is None for log calls with no format
            # args; iterating it unconditionally raised TypeError.
            record.args = tuple(
                self.redact(arg) if isinstance(arg, str) else arg
                for arg in record.args
            )

        return True


logging.getLogger("urllib3.connectionpool").addFilter(RedactUrlQueryParamsFilter())
1042

1143
class DBAPITypeObject(object):
1244
def __init__(self, *values):

src/databricks/sql/cloudfetch/download_manager.py

+8-32
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ResultSetDownloadHandler,
99
DownloadableResultSettings,
1010
)
11+
from databricks.sql.exc import ResultSetDownloadError
1112
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
1213

1314
logger = logging.getLogger(__name__)
@@ -34,8 +35,6 @@ def __init__(self, max_download_threads: int, lz4_compressed: bool):
3435
self.download_handlers: List[ResultSetDownloadHandler] = []
3536
self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1)
3637
self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
37-
self.fetch_need_retry = False
38-
self.num_consecutive_result_file_download_retries = 0
3938

4039
def add_file_links(
4140
self, t_spark_arrow_result_links: List[TSparkArrowResultLink]
@@ -81,13 +80,15 @@ def get_next_downloaded_file(
8180

8281
# Find next file
8382
idx = self._find_next_file_index(next_row_offset)
83+
# No remaining download handler covers next_row_offset, so there is
# nothing left to download for this result set.
8484
if idx is None:
8585
self._shutdown_manager()
86+
logger.debug("could not find next file index")
8687
return None
8788
handler = self.download_handlers[idx]
8889

8990
# Check (and wait) for download status
90-
if self._check_if_download_successful(handler):
91+
if handler.is_file_download_successful():
9192
# Buffer should be empty so set buffer to new ArrowQueue with result_file
9293
result = DownloadedFile(
9394
handler.result_file,
@@ -97,9 +98,11 @@ def get_next_downloaded_file(
9798
self.download_handlers.pop(idx)
9899
# Return True upon successful download to continue loop and not force a retry
99100
return result
100-
# Download was not successful for next download item, force a retry
101+
# Download was not successful for next download item. Fail
101102
self._shutdown_manager()
102-
return None
103+
raise ResultSetDownloadError(
104+
f"Download failed for result set starting at {next_row_offset}"
105+
)
103106

104107
def _remove_past_handlers(self, next_row_offset: int):
105108
# Any link in which its start to end range doesn't include the next row to be fetched does not need downloading
@@ -133,33 +136,6 @@ def _find_next_file_index(self, next_row_offset: int):
133136
]
134137
return next_indices[0] if len(next_indices) > 0 else None
135138

136-
def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
137-
# Check (and wait until download finishes) if download was successful
138-
if not handler.is_file_download_successful():
139-
if handler.is_link_expired:
140-
self.fetch_need_retry = True
141-
return False
142-
elif handler.is_download_timedout:
143-
# Consecutive file retries should not exceed threshold in settings
144-
if (
145-
self.num_consecutive_result_file_download_retries
146-
>= self.downloadable_result_settings.max_consecutive_file_download_retries
147-
):
148-
self.fetch_need_retry = True
149-
return False
150-
self.num_consecutive_result_file_download_retries += 1
151-
152-
# Re-submit handler run to thread pool and recursively check download status
153-
self.thread_pool.submit(handler.run)
154-
return self._check_if_download_successful(handler)
155-
else:
156-
self.fetch_need_retry = True
157-
return False
158-
159-
self.num_consecutive_result_file_download_retries = 0
160-
self.fetch_need_retry = False
161-
return True
162-
163139
def _shutdown_manager(self):
164140
# Clear download handlers and shutdown the thread pool
165141
self.download_handlers = []

src/databricks/sql/cloudfetch/downloader.py

+93-27
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import logging
22
from dataclasses import dataclass
3-
43
import requests
54
import lz4.frame
65
import threading
76
import time
8-
7+
import os
8+
import re
99
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
1010

1111
logger = logging.getLogger(__name__)
1212

13+
DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60))
14+
1315

1416
@dataclass
1517
class DownloadableResultSettings:
@@ -20,13 +22,17 @@ class DownloadableResultSettings:
2022
is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
2123
link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
2224
download_timeout (int): Timeout for download requests. Default 60 secs.
23-
max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
25+
max_retries (int): Number of consecutive download retries before shutting down. Default 5.
26+
backoff_factor (int): Factor to increase wait time between retries. Default 2.
28+
2429
"""
2530

2631
is_lz4_compressed: bool
2732
link_expiry_buffer_secs: int = 0
28-
download_timeout: int = 60
29-
max_consecutive_file_download_retries: int = 0
33+
download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT
34+
max_retries: int = 5
35+
backoff_factor: int = 2
3036

3137

3238
class ResultSetDownloadHandler(threading.Thread):
@@ -57,16 +63,21 @@ def is_file_download_successful(self) -> bool:
5763
else None
5864
)
5965
try:
66+
logger.debug(
67+
f"waiting for at most {timeout} seconds for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
68+
)
69+
6070
if not self.is_download_finished.wait(timeout=timeout):
6171
self.is_download_timedout = True
62-
logger.debug(
63-
"Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format(
64-
self.settings.download_timeout,
65-
self.result_link.startRowOffset,
66-
self.result_link.startRowOffset + self.result_link.rowCount,
67-
)
72+
logger.error(
73+
f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}"
6874
)
69-
return False
75+
# there are some weird cases when the is_download_finished is not set, but the file is downloaded successfully
76+
return self.is_file_downloaded_successfully
77+
78+
logger.debug(
79+
f"finish waiting for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
80+
)
7081
except Exception as e:
7182
logger.error(e)
7283
return False
@@ -81,24 +92,36 @@ def run(self):
8192
"""
8293
self._reset()
8394

84-
# Check if link is already expired or is expiring
85-
if ResultSetDownloadHandler.check_link_expired(
86-
self.result_link, self.settings.link_expiry_buffer_secs
87-
):
88-
self.is_link_expired = True
89-
return
95+
try:
96+
# Check if link is already expired or is expiring
97+
if ResultSetDownloadHandler.check_link_expired(
98+
self.result_link, self.settings.link_expiry_buffer_secs
99+
):
100+
self.is_link_expired = True
101+
return
90102

91-
session = requests.Session()
92-
session.timeout = self.settings.download_timeout
103+
logger.debug(
104+
f"started to download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
105+
)
93106

94-
try:
95107
# Get the file via HTTP request
96-
response = session.get(self.result_link.fileLink)
108+
response = http_get_with_retry(
109+
url=self.result_link.fileLink,
110+
max_retries=self.settings.max_retries,
111+
backoff_factor=self.settings.backoff_factor,
112+
download_timeout=self.settings.download_timeout,
113+
)
97114

98-
if not response.ok:
99-
self.is_file_downloaded_successfully = False
115+
if not response:
116+
logger.error(
117+
f"failed downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
118+
)
100119
return
101120

121+
logger.debug(
122+
f"success downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
123+
)
124+
102125
# Save (and decompress if needed) the downloaded file
103126
compressed_data = response.content
104127
decompressed_data = (
@@ -109,15 +132,22 @@ def run(self):
109132
self.result_file = decompressed_data
110133

111134
# The size of the downloaded file should match the size specified from TSparkArrowResultLink
112-
self.is_file_downloaded_successfully = (
113-
len(self.result_file) == self.result_link.bytesNum
135+
success = len(self.result_file) == self.result_link.bytesNum
136+
logger.debug(
137+
f"download successful file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
114138
)
139+
self.is_file_downloaded_successfully = success
115140
except Exception as e:
141+
logger.error(
142+
f"exception downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
143+
)
116144
logger.error(e)
117145
self.is_file_downloaded_successfully = False
118146

119147
finally:
120-
session and session.close()
148+
logger.debug(
149+
f"signal finished file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
150+
)
121151
# Awaken threads waiting for this to be true which signals the run is complete
122152
self.is_download_finished.set()
123153

@@ -145,6 +175,7 @@ def check_link_expired(
145175
link.expiryTime < current_time
146176
or link.expiryTime - current_time < expiry_buffer_secs
147177
):
178+
logger.debug("link expired")
148179
return True
149180
return False
150181

@@ -171,3 +202,38 @@ def decompress_data(compressed_data: bytes) -> bytes:
171202
uncompressed_data += data
172203
start += num_bytes
173204
return uncompressed_data
205+
206+
207+
def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
    """GET *url*, retrying with exponential backoff.

    Args:
        url: Pre-signed download URL; query parameters are redacted in logs.
        max_retries: Maximum number of attempts before giving up.
        backoff_factor: Base of the exponential wait (1, factor, factor**2, ...).
        download_timeout: Per-request timeout in seconds.

    Returns:
        The successful ``requests.Response`` (2xx), or ``None`` after
        exhausting all attempts.
    """
    attempts = 0
    pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)")
    mask = r"\1\2=<REDACTED>"

    # TODO: introduce connection pooling. I am seeing weird errors without it.
    while attempts < max_retries:
        # Create the session outside the try so `finally` can never see an
        # unbound name if Session() itself were to fail.
        session = requests.Session()
        try:
            # Bug fix: requests.Session has no `timeout` attribute — assigning
            # one is silently ignored. The timeout must be passed per request.
            response = session.get(url, timeout=download_timeout)

            # Check if the response status code is in the 2xx range for success
            if 200 <= response.status_code < 300:
                return response
            logger.error(response)
        except requests.RequestException as e:
            # if this is not redacted, it will print the pre-signed URL
            logger.error(f"request failed with exception: {re.sub(pattern, mask, str(e))}")
        finally:
            session.close()

        attempts += 1
        if attempts < max_retries:
            # Exponential backoff before the next attempt; do not sleep after
            # the final failed attempt.
            wait_time = backoff_factor ** (attempts - 1)
            logger.info(f"retrying in {wait_time} seconds...")
            time.sleep(wait_time)

    logger.error(
        f"exceeded maximum number of retries ({max_retries}) while downloading result."
    )
    return None

src/databricks/sql/exc.py

+4
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,7 @@ class SessionAlreadyClosedError(RequestError):
115115

116116
class CursorAlreadyClosedError(RequestError):
117117
"""Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected."""
118+
119+
120+
class ResultSetDownloadError(RequestError):
    """Thrown if there was an error during the download of a result set"""

0 commit comments

Comments
 (0)