Skip to content

fixes for cloud fetch - part un #356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions examples/custom_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from databricks import sql
import os
import logging


logger = logging.getLogger("databricks.sql")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler("pysqllogs.log")
fh.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(process)d %(thread)d %(message)s"))
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)

with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
use_cloud_fetch=True,
max_download_threads = 2
) as connection:

with connection.cursor(arraysize=1000, buffer_size_bytes=54857600) as cursor:
print(
"executing query: SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2"
)
cursor.execute("SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2")
try:
while True:
row = cursor.fetchone()
if row is None:
break
print(f"row: {row}")
except sql.exc.ResultSetDownloadError as e:
print(f"error: {e}")
32 changes: 32 additions & 0 deletions src/databricks/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,38 @@
threadsafety = 1 # Threads may share the module, but not connections.
paramstyle = "pyformat" # Python extended format codes, e.g. ...WHERE name=%(name)s

import re


class RedactUrlQueryParamsFilter(logging.Filter):
pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)")
mask = r"\1\2=<REDACTED>"

def __init__(self):
super().__init__()

def redact(self, string):
return re.sub(self.pattern, self.mask, str(string))

def filter(self, record):
record.msg = self.redact(str(record.msg))
if isinstance(record.args, dict):
for k in record.args.keys():
record.args[k] = (
self.redact(record.args[k])
if isinstance(record.arg[k], str)
else record.args[k]
)
else:
record.args = tuple(
(self.redact(arg) if isinstance(arg, str) else arg)
for arg in record.args
)

return True


logging.getLogger("urllib3.connectionpool").addFilter(RedactUrlQueryParamsFilter())

class DBAPITypeObject(object):
def __init__(self, *values):
Expand Down
40 changes: 8 additions & 32 deletions src/databricks/sql/cloudfetch/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ResultSetDownloadHandler,
DownloadableResultSettings,
)
from databricks.sql.exc import ResultSetDownloadError
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)
Expand All @@ -34,8 +35,6 @@ def __init__(self, max_download_threads: int, lz4_compressed: bool):
self.download_handlers: List[ResultSetDownloadHandler] = []
self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1)
self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
self.fetch_need_retry = False
self.num_consecutive_result_file_download_retries = 0

def add_file_links(
self, t_spark_arrow_result_links: List[TSparkArrowResultLink]
Expand Down Expand Up @@ -81,13 +80,15 @@ def get_next_downloaded_file(

# Find next file
idx = self._find_next_file_index(next_row_offset)
# is this correct?
if idx is None:
self._shutdown_manager()
logger.debug("could not find next file index")
return None
handler = self.download_handlers[idx]

# Check (and wait) for download status
if self._check_if_download_successful(handler):
if handler.is_file_download_successful():
# Buffer should be empty so set buffer to new ArrowQueue with result_file
result = DownloadedFile(
handler.result_file,
Expand All @@ -97,9 +98,11 @@ def get_next_downloaded_file(
self.download_handlers.pop(idx)
# Return True upon successful download to continue loop and not force a retry
return result
# Download was not successful for next download item, force a retry
# Download was not successful for next download item. Fail
self._shutdown_manager()
return None
raise ResultSetDownloadError(
f"Download failed for result set starting at {next_row_offset}"
)

def _remove_past_handlers(self, next_row_offset: int):
# Any link in which its start to end range doesn't include the next row to be fetched does not need downloading
Expand Down Expand Up @@ -133,33 +136,6 @@ def _find_next_file_index(self, next_row_offset: int):
]
return next_indices[0] if len(next_indices) > 0 else None

def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
# Check (and wait until download finishes) if download was successful
if not handler.is_file_download_successful():
if handler.is_link_expired:
self.fetch_need_retry = True
return False
elif handler.is_download_timedout:
# Consecutive file retries should not exceed threshold in settings
if (
self.num_consecutive_result_file_download_retries
>= self.downloadable_result_settings.max_consecutive_file_download_retries
):
self.fetch_need_retry = True
return False
self.num_consecutive_result_file_download_retries += 1

# Re-submit handler run to thread pool and recursively check download status
self.thread_pool.submit(handler.run)
return self._check_if_download_successful(handler)
else:
self.fetch_need_retry = True
return False

self.num_consecutive_result_file_download_retries = 0
self.fetch_need_retry = False
return True

def _shutdown_manager(self):
# Clear download handlers and shutdown the thread pool
self.download_handlers = []
Expand Down
120 changes: 93 additions & 27 deletions src/databricks/sql/cloudfetch/downloader.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import logging
from dataclasses import dataclass

import requests
import lz4.frame
import threading
import time

import os
import re
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)

DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60))


@dataclass
class DownloadableResultSettings:
Expand All @@ -20,13 +22,17 @@ class DownloadableResultSettings:
is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
download_timeout (int): Timeout for download requests. Default 60 secs.
max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
download_max_retries (int): Number of consecutive download retries before shutting down.
max_retries (int): Number of consecutive download retries before shutting down.
backoff_factor (int): Factor to increase wait time between retries.

"""

is_lz4_compressed: bool
link_expiry_buffer_secs: int = 0
download_timeout: int = 60
max_consecutive_file_download_retries: int = 0
download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT
max_retries: int = 5
backoff_factor: int = 2


class ResultSetDownloadHandler(threading.Thread):
Expand Down Expand Up @@ -57,16 +63,21 @@ def is_file_download_successful(self) -> bool:
else None
)
try:
logger.debug(
f"waiting for at most {timeout} seconds for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

if not self.is_download_finished.wait(timeout=timeout):
self.is_download_timedout = True
logger.debug(
"Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format(
self.settings.download_timeout,
self.result_link.startRowOffset,
self.result_link.startRowOffset + self.result_link.rowCount,
)
logger.error(
f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}"
)
return False
# there are some weird cases when the is_download_finished is not set, but the file is downloaded successfully
return self.is_file_downloaded_successfully

logger.debug(
f"finish waiting for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
except Exception as e:
logger.error(e)
return False
Expand All @@ -81,24 +92,36 @@ def run(self):
"""
self._reset()

# Check if link is already expired or is expiring
if ResultSetDownloadHandler.check_link_expired(
self.result_link, self.settings.link_expiry_buffer_secs
):
self.is_link_expired = True
return
try:
# Check if link is already expired or is expiring
if ResultSetDownloadHandler.check_link_expired(
self.result_link, self.settings.link_expiry_buffer_secs
):
self.is_link_expired = True
return

session = requests.Session()
session.timeout = self.settings.download_timeout
logger.debug(
f"started to download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

try:
# Get the file via HTTP request
response = session.get(self.result_link.fileLink)
response = http_get_with_retry(
url=self.result_link.fileLink,
max_retries=self.settings.max_retries,
backoff_factor=self.settings.backoff_factor,
download_timeout=self.settings.download_timeout,
)

if not response.ok:
self.is_file_downloaded_successfully = False
if not response:
logger.error(
f"failed downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
return

logger.debug(
f"success downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

# Save (and decompress if needed) the downloaded file
compressed_data = response.content
decompressed_data = (
Expand All @@ -109,15 +132,22 @@ def run(self):
self.result_file = decompressed_data

# The size of the downloaded file should match the size specified from TSparkArrowResultLink
self.is_file_downloaded_successfully = (
len(self.result_file) == self.result_link.bytesNum
success = len(self.result_file) == self.result_link.bytesNum
logger.debug(
f"download successful file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
self.is_file_downloaded_successfully = success
except Exception as e:
logger.error(
f"exception downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
logger.error(e)
self.is_file_downloaded_successfully = False

finally:
session and session.close()
logger.debug(
f"signal finished file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
# Awaken threads waiting for this to be true which signals the run is complete
self.is_download_finished.set()

Expand Down Expand Up @@ -145,6 +175,7 @@ def check_link_expired(
link.expiryTime < current_time
or link.expiryTime - current_time < expiry_buffer_secs
):
logger.debug("link expired")
return True
return False

Expand All @@ -171,3 +202,38 @@ def decompress_data(compressed_data: bytes) -> bytes:
uncompressed_data += data
start += num_bytes
return uncompressed_data


def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
attempts = 0
pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)")
mask = r"\1\2=<REDACTED>"

# TODO: introduce connection pooling. I am seeing weird errors without it.
while attempts < max_retries:
try:
session = requests.Session()
session.timeout = download_timeout
response = session.get(url)

# Check if the response status code is in the 2xx range for success
if response.status_code == 200:
return response
else:
logger.error(response)
except requests.RequestException as e:
# if this is not redacted, it will print the pre-signed URL
logger.error(f"request failed with exception: {re.sub(pattern, mask, str(e))}")
finally:
session.close()
# Exponential backoff before the next attempt
wait_time = backoff_factor**attempts
logger.info(f"retrying in {wait_time} seconds...")
time.sleep(wait_time)

attempts += 1

logger.error(
f"exceeded maximum number of retries ({max_retries}) while downloading result."
)
return None
4 changes: 4 additions & 0 deletions src/databricks/sql/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,7 @@ class SessionAlreadyClosedError(RequestError):

class CursorAlreadyClosedError(RequestError):
"""Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected."""


class ResultSetDownloadError(RequestError):
"""Thrown if there was an error during the download of a result set"""
Loading