Commit e3d4efe

rcypher-databricks authored and refurlan-db committed

fixes for cloud fetch

backport to version 2

Signed-off-by: Andre Furlan <[email protected]>

1 parent a737ef3 · commit e3d4efe

File tree: 8 files changed, +252 −141 lines

examples/custom_logger.py (new file, +33)

from databricks import sql
import os
import logging


logger = logging.getLogger("databricks.sql")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler("pysqllogs.log")
fh.setFormatter(logging.Formatter("%(asctime)s %(process)d %(thread)d %(message)s"))
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)

with sql.connect(
    server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
    http_path=os.getenv("DATABRICKS_HTTP_PATH"),
    access_token=os.getenv("DATABRICKS_TOKEN"),
    use_cloud_fetch=True,
    max_download_threads=2,
) as connection:

    with connection.cursor(arraysize=1000, buffer_size_bytes=54857600) as cursor:
        print(
            "executing query: SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2"
        )
        cursor.execute("SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2")
        try:
            while True:
                row = cursor.fetchone()
                if row is None:
                    break
                print(f"row: {row}")
        except sql.exc.ResultSetDownloadError as e:
            print(f"error: {e}")
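Note: this example expects the DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN environment variables to be set. It enables the cloud fetch code path (use_cloud_fetch=True), caps the download pool at two threads, and demonstrates catching the ResultSetDownloadError exception this commit introduces.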

src/databricks/sql/cloudfetch/download_manager.py (+8 −32)

@@ -8,6 +8,7 @@
     ResultSetDownloadHandler,
     DownloadableResultSettings,
 )
+from databricks.sql.exc import ResultSetDownloadError
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 
 logger = logging.getLogger(__name__)
@@ -34,8 +35,6 @@ def __init__(self, max_download_threads: int, lz4_compressed: bool):
         self.download_handlers: List[ResultSetDownloadHandler] = []
         self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1)
         self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
-        self.fetch_need_retry = False
-        self.num_consecutive_result_file_download_retries = 0
 
     def add_file_links(
         self, t_spark_arrow_result_links: List[TSparkArrowResultLink]
@@ -81,13 +80,15 @@ def get_next_downloaded_file(
 
         # Find next file
         idx = self._find_next_file_index(next_row_offset)
+        # is this correct?
         if idx is None:
             self._shutdown_manager()
+            logger.debug("could not find next file index")
             return None
         handler = self.download_handlers[idx]
 
         # Check (and wait) for download status
-        if self._check_if_download_successful(handler):
+        if handler.is_file_download_successful():
             # Buffer should be empty so set buffer to new ArrowQueue with result_file
             result = DownloadedFile(
                 handler.result_file,
@@ -97,9 +98,11 @@ def get_next_downloaded_file(
             self.download_handlers.pop(idx)
             # Return True upon successful download to continue loop and not force a retry
             return result
-        # Download was not successful for next download item, force a retry
+        # Download was not successful for next download item. Fail
         self._shutdown_manager()
-        return None
+        raise ResultSetDownloadError(
+            f"Download failed for result set starting at {next_row_offset}"
+        )
 
     def _remove_past_handlers(self, next_row_offset: int):
         # Any link in which its start to end range doesn't include the next row to be fetched does not need downloading
@@ -133,33 +136,6 @@ def _find_next_file_index(self, next_row_offset: int):
         ]
         return next_indices[0] if len(next_indices) > 0 else None
 
-    def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
-        # Check (and wait until download finishes) if download was successful
-        if not handler.is_file_download_successful():
-            if handler.is_link_expired:
-                self.fetch_need_retry = True
-                return False
-            elif handler.is_download_timedout:
-                # Consecutive file retries should not exceed threshold in settings
-                if (
-                    self.num_consecutive_result_file_download_retries
-                    >= self.downloadable_result_settings.max_consecutive_file_download_retries
-                ):
-                    self.fetch_need_retry = True
-                    return False
-                self.num_consecutive_result_file_download_retries += 1
-
-                # Re-submit handler run to thread pool and recursively check download status
-                self.thread_pool.submit(handler.run)
-                return self._check_if_download_successful(handler)
-            else:
-                self.fetch_need_retry = True
-                return False
-
-        self.num_consecutive_result_file_download_retries = 0
-        self.fetch_need_retry = False
-        return True
-
     def _shutdown_manager(self):
         # Clear download handlers and shutdown the thread pool
         self.download_handlers = []
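Net effect of this file's changes: the manager's retry bookkeeping (fetch_need_retry, num_consecutive_result_file_download_retries) moves into the downloader's own HTTP retry loop (next file), and a failed download now raises ResultSetDownloadError instead of returning None, so callers can tell "no file covers this offset" apart from "download failed". A minimal sketch of a caller-side loop under the new contract — drain is a hypothetical name, and the file_bytes/row_count fields are assumed to follow the DownloadedFile tuple constructed above:

from databricks.sql.exc import ResultSetDownloadError

def drain(manager, next_row_offset=0):
    # Hypothetical fetch loop illustrating the new failure contract.
    while True:
        try:
            downloaded_file = manager.get_next_downloaded_file(next_row_offset)
        except ResultSetDownloadError:
            # The manager has already shut down its thread pool; just surface it.
            raise
        if downloaded_file is None:
            break  # no remaining link covers this offset: result set exhausted
        # consume downloaded_file.file_bytes, then advance past the rows it covered
        next_row_offset += downloaded_file.row_count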

src/databricks/sql/cloudfetch/downloader.py (+89 −26)

@@ -1,15 +1,18 @@
 import logging
 from dataclasses import dataclass
-
+from datetime import datetime
 import requests
 import lz4.frame
 import threading
 import time
-
+import os
+from threading import get_ident
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60))
+
 
 @dataclass
 class DownloadableResultSettings:
@@ -20,13 +23,17 @@ class DownloadableResultSettings:
         is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
         link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
         download_timeout (int): Timeout for download requests. Default 60 secs.
-        max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
+        download_max_retries (int): Number of consecutive download retries before shutting down.
+        max_retries (int): Number of consecutive download retries before shutting down.
+        backoff_factor (int): Factor to increase wait time between retries.
+
     """
 
     is_lz4_compressed: bool
     link_expiry_buffer_secs: int = 0
-    download_timeout: int = 60
-    max_consecutive_file_download_retries: int = 0
+    download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT
+    max_retries: int = 5
+    backoff_factor: int = 2
 
 
 class ResultSetDownloadHandler(threading.Thread):
@@ -57,16 +64,21 @@ def is_file_download_successful(self) -> bool:
             else None
         )
         try:
+            logger.debug(
+                f"waiting for at most {timeout} seconds for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
+
             if not self.is_download_finished.wait(timeout=timeout):
                 self.is_download_timedout = True
                 logger.debug(
-                    "Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format(
-                        self.settings.download_timeout,
-                        self.result_link.startRowOffset,
-                        self.result_link.startRowOffset + self.result_link.rowCount,
-                    )
+                    f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}"
                 )
-                return False
+                # there are some weird cases when the is_download_finished is not set, but the file is downloaded successfully
+                return self.is_file_downloaded_successfully
+
+            logger.debug(
+                f"finish waiting for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
         except Exception as e:
             logger.error(e)
             return False
@@ -81,24 +93,36 @@ def run(self):
         """
         self._reset()
 
-        # Check if link is already expired or is expiring
-        if ResultSetDownloadHandler.check_link_expired(
-            self.result_link, self.settings.link_expiry_buffer_secs
-        ):
-            self.is_link_expired = True
-            return
+        try:
+            # Check if link is already expired or is expiring
+            if ResultSetDownloadHandler.check_link_expired(
+                self.result_link, self.settings.link_expiry_buffer_secs
+            ):
+                self.is_link_expired = True
+                return
 
-        session = requests.Session()
-        session.timeout = self.settings.download_timeout
+            logger.debug(
+                f"started to download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
 
-        try:
             # Get the file via HTTP request
-            response = session.get(self.result_link.fileLink)
+            response = http_get_with_retry(
+                url=self.result_link.fileLink,
+                max_retries=self.settings.max_retries,
+                backoff_factor=self.settings.backoff_factor,
+                download_timeout=self.settings.download_timeout,
+            )
 
-            if not response.ok:
-                self.is_file_downloaded_successfully = False
+            if not response:
+                logger.error(
+                    f"failed downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+                )
                 return
 
+            logger.debug(
+                f"success downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
+
             # Save (and decompress if needed) the downloaded file
             compressed_data = response.content
             decompressed_data = (
@@ -109,15 +133,22 @@ def run(self):
             self.result_file = decompressed_data
 
             # The size of the downloaded file should match the size specified from TSparkArrowResultLink
-            self.is_file_downloaded_successfully = (
-                len(self.result_file) == self.result_link.bytesNum
+            success = len(self.result_file) == self.result_link.bytesNum
+            logger.debug(
+                f"download successful file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
             )
+            self.is_file_downloaded_successfully = success
         except Exception as e:
+            logger.debug(
+                f"exception downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
             logger.error(e)
             self.is_file_downloaded_successfully = False
 
         finally:
-            session and session.close()
+            logger.debug(
+                f"signal finished file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
             # Awaken threads waiting for this to be true which signals the run is complete
             self.is_download_finished.set()
 
@@ -145,6 +176,7 @@ def check_link_expired(
             link.expiryTime < current_time
             or link.expiryTime - current_time < expiry_buffer_secs
         ):
+            logger.debug("link expired")
             return True
         return False
 
@@ -171,3 +203,34 @@ def decompress_data(compressed_data: bytes) -> bytes:
         uncompressed_data += data
         start += num_bytes
     return uncompressed_data
+
+
+def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
+    attempts = 0
+
+    while attempts < max_retries:
+        try:
+            session = requests.Session()
+            session.timeout = download_timeout
+            response = session.get(url)
+
+            # Check if the response status code is in the 2xx range for success
+            if response.status_code == 200:
+                return response
+            else:
+                logger.error(response)
+        except requests.RequestException as e:
+            print(f"request failed with exception: {e}")
+        finally:
+            session.close()
+        # Exponential backoff before the next attempt
+        wait_time = backoff_factor**attempts
+        logger.info(f"retrying in {wait_time} seconds...")
+        time.sleep(wait_time)
+
+        attempts += 1
+
+    logger.error(
+        f"exceeded maximum number of retries ({max_retries}) while downloading result."
+    )
+    return None
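The new http_get_with_retry helper is where retries now live: up to max_retries attempts, an exponentially growing pause of backoff_factor**attempts seconds between failed attempts, and None on exhaustion, which run() treats as a failed download. The default download timeout can also be overridden through the DATABRICKS_CLOUD_FILE_TIMEOUT environment variable (DEFAULT_CLOUD_FILE_TIMEOUT above). A minimal, self-contained sketch of the same retry shape, with a generic callable standing in for the HTTP GET:

import time

def retry_with_backoff(fetch, max_retries=5, backoff_factor=2):
    # Retry `fetch` until it returns a non-None result,
    # pausing 1s, 2s, 4s, ... between failed tries.
    for attempt in range(max_retries):
        try:
            result = fetch()
            if result is not None:
                return result
        except Exception as e:
            print(f"attempt {attempt + 1} failed: {e}")
        time.sleep(backoff_factor ** attempt)  # exponential backoff, as in the diff
    return None  # exhausted; the caller decides how to fail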

src/databricks/sql/exc.py (+4)

@@ -115,3 +115,7 @@ class SessionAlreadyClosedError(RequestError):
 
 class CursorAlreadyClosedError(RequestError):
     """Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected."""
+
+
+class ResultSetDownloadError(RequestError):
+    """Thrown if there was an error during the download of a result set"""

src/databricks/sql/thrift_backend.py (+14 −2)

@@ -371,7 +371,9 @@ def attempt_request(attempt):
 
             this_method_name = getattr(method, "__name__")
 
-            logger.debug("Sending request: {}(<REDACTED>)".format(this_method_name))
+            logger.debug(
+                "sending thrift request: {}(<REDACTED>)".format(this_method_name)
+            )
             unsafe_logger.debug("Sending request: {}".format(request))
 
             # These three lines are no-ops if the v3 retry policy is not in use
@@ -387,7 +389,9 @@ def attempt_request(attempt):
 
             # We need to call type(response) here because thrift doesn't implement __name__ attributes for thrift responses
             logger.debug(
-                "Received response: {}(<REDACTED>)".format(type(response).__name__)
+                "received thrift response: {}(<REDACTED>)".format(
+                    type(response).__name__
+                )
             )
             unsafe_logger.debug("Received response: {}".format(response))
             return response
@@ -741,6 +745,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         if direct_results and direct_results.resultSet:
+            logger.debug(f"received direct results")
             assert direct_results.resultSet.results.startRowOffset == 0
             assert direct_results.resultSetMetadata
 
@@ -753,6 +758,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
                 description=description,
             )
         else:
+            logger.debug(f"must fetch results")
             arrow_queue_opt = None
         return ExecuteResponse(
             arrow_queue=arrow_queue_opt,
@@ -816,6 +822,10 @@ def execute_command(
     ):
         assert session_handle is not None
 
+        logger.debug(
+            f"executing: cloud fetch: {use_cloud_fetch}, max rows: {max_rows}, max bytes: {max_bytes}"
+        )
+
         spark_arrow_types = ttypes.TSparkArrowTypes(
             timestampAsArrow=self._use_arrow_native_timestamps,
             decimalAsArrow=self._use_arrow_native_decimals,
@@ -930,6 +940,7 @@ def get_columns(
         return self._handle_execute_response(resp, cursor)
 
     def _handle_execute_response(self, resp, cursor):
+        logger.debug(f"got execute response")
         cursor.active_op_handle = resp.operationHandle
         self._check_direct_results_for_error(resp.directResults)
 
@@ -950,6 +961,7 @@ def fetch_results(
         arrow_schema_bytes,
         description,
     ):
+        logger.debug("started to fetch results")
         assert op_handle is not None
 
         req = ttypes.TFetchResultsReq(
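These debug lines are only visible when logging is configured for the databricks.sql namespace. A minimal sketch that routes them to stderr — the new examples/custom_logger.py above shows the file-based variant:

import logging

# emit the connector's thrift/cloud fetch debug messages to stderr
logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s %(message)s")
logging.getLogger("databricks.sql").setLevel(logging.DEBUG)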
