Skip to content

Cloud Fetch download handler #127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 109 additions & 64 deletions src/databricks/sql/cloudfetch/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@
import threading
import time

from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)


class ResultSetDownloadHandler(threading.Thread):

def __init__(self, downloadable_result_settings, t_spark_arrow_result_link):
def __init__(
self,
downloadable_result_settings,
t_spark_arrow_result_link: TSparkArrowResultLink,
):
super().__init__()
self.settings = downloadable_result_settings
self.result_link = t_spark_arrow_result_link
Expand All @@ -19,85 +24,125 @@ def __init__(self, downloadable_result_settings, t_spark_arrow_result_link):
self.is_file_downloaded_successfully = False
self.is_link_expired = False
self.is_download_timedout = False
self.http_code = None
self.result_file = None
self.check_result_file_link_expiry = True
self.download_completion_semaphore = threading.Semaphore(0)

def is_file_download_successful(self):
def is_file_download_successful(self) -> bool:
"""
Check and report if cloud fetch file downloaded successfully.

This function will block until a file download finishes or until a timeout.
"""
timeout = self.settings.download_timeout
timeout = timeout if timeout and timeout > 0 else None
try:
if not self.is_download_finished.is_set():
if self.settings.download_timeout and self.settings.download_timeout > 0:
if not self.download_completion_semaphore.acquire(timeout=self.settings.download_timeout):
self.is_download_timedout = True
logger.debug("Cloud fetch download timed out after {} seconds for url: {}"
.format(self.settings.download_timeout, self.result_link.file_link)
)
return False
else:
self.download_completion_semaphore.acquire()
except:
if not self.is_download_finished.wait(timeout=timeout):
self.is_download_timedout = True
logger.debug(
"Cloud fetch download timed out after {} seconds for url: {}".format(
self.settings.download_timeout,
self.result_link.fileLink,
)
)
return False
except Exception as e:
logger.error(e)
return False
return self.is_file_downloaded_successfully

def run(self):
self.is_file_downloaded_successfully = False
self.is_link_expired = False
self.is_download_timedout = False
self.is_download_finished = threading.Event()
"""
Download the file described in the cloud fetch link.

if self.check_result_file_link_expiry:
current_time = int(time.time() * 1000)
if (self.result_link.expiryTime < current_time) or (
self.result_link.expiryTime - current_time < (
self.settings.result_file_link_expiry_buffer * 1000)
):
self.is_link_expired = True
return

session = requests.Session()
session.timeout = self.settings.download_timeout
This function checks if the link has or is expiring, gets the file via a requests session, decompresses the
file, and signals to waiting threads that the download is finished and whether it was successful.
"""
self._reset()

if (
self.settings.use_proxy
and not self.settings.disable_proxy_for_cloud_fetch
# Check if link is already expired or is expiring
if ResultSetDownloadHandler.check_link_expired(
self.result_link, self.settings.link_expiry_buffer
):
proxy = {
"http": f"http://{self.settings.proxy_host}:{self.settings.proxy_port}",
"https": f"https://{self.settings.proxy_host}:{self.settings.proxy_port}",
}
session.proxies.update(proxy)
self.is_link_expired = True
return

# ProxyAuthentication -> static enum BASIC and NONE
if self.settings.proxy_auth == "BASIC":
session.auth = requests.auth.HTTPBasicAuth(self.settings.proxy_uid, self.settings.proxy_pwd)
session = requests.Session()
session.timeout = self.settings.download_timeout

try:
# Get the file via HTTP request
response = session.get(self.result_link.fileLink)
self.http_code = response.status_code

if self.http_code != 200:
if not response.ok:
self.is_file_downloaded_successfully = False
else:
if self.settings.is_lz4_compressed:
compressed_data = response.content
uncompressed_data, bytes_read = lz4.frame.decompress(compressed_data, return_bytes_read=True)
if bytes_read < len(compressed_data):
d_context = lz4.frame.create_decompression_context()
start = 0
uncompressed_data = bytearray()
while start < len(compressed_data):
data, num_bytes, e = lz4.frame.decompress_chunk(d_context, compressed_data[start:])
uncompressed_data += data
start += num_bytes
self.result_file = uncompressed_data
else:
self.result_file = response.content
self.is_file_downloaded_successfully = len(self.result_file) == self.result_link.bytesNum
except:
return

# Save (and decompress if needed) the downloaded file
compressed_data = response.content
decompressed_data = (
ResultSetDownloadHandler.decompress_data(compressed_data)
if self.settings.is_lz4_compressed
else compressed_data
)
self.result_file = decompressed_data

# The size of the downloaded file should match the size specified from TSparkArrowResultLink
self.is_file_downloaded_successfully = (
len(self.result_file) == self.result_link.bytesNum
)
except Exception as e:
logger.error(e)
self.is_file_downloaded_successfully = False

finally:
session.close()
session and session.close()
# Awaken threads waiting for this to be true which signals the run is complete
self.is_download_finished.set()
self.download_completion_semaphore.release()

def _reset(self):
# Reset download-related flags for every retry of run()
self.is_file_downloaded_successfully = False
self.is_link_expired = False
self.is_download_timedout = False
self.is_download_finished = threading.Event()

@staticmethod
def check_link_expired(
link: TSparkArrowResultLink, expiry_buffer_secs: int
) -> bool:
"""
Check if a link has expired or will expire.

Expiry buffer can be set to avoid downloading files that has not expired yet when the function is called,
but may expire before the file has fully downloaded.
"""
current_time = int(time.time())
if (
link.expiryTime < current_time
or link.expiryTime - current_time < expiry_buffer_secs
):
return True
return False

@staticmethod
def decompress_data(compressed_data: bytes) -> bytes:
"""
Decompress lz4 frame compressed data.

Decompresses data that has been lz4 compressed, either via the whole frame or by series of chunks.
"""
uncompressed_data, bytes_read = lz4.frame.decompress(
compressed_data, return_bytes_read=True
)
# The last cloud fetch file of the entire result is commonly punctuated by frequent end-of-frame markers.
# Full frame decompression above will short-circuit, so chunking is necessary
if bytes_read < len(compressed_data):
d_context = lz4.frame.create_decompression_context()
start = 0
uncompressed_data = bytearray()
while start < len(compressed_data):
data, num_bytes, is_end = lz4.frame.decompress_chunk(
d_context, compressed_data[start:]
)
uncompressed_data += data
start += num_bytes
return uncompressed_data
Loading