Skip to content

Fix for cloud fetch #362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Release History

# 3.1.1 (2024-02-21)

- Fix: Cloud fetch file download errors (#356)

# 3.1.0 (2024-02-16)

- Revert retry-after behavior to be exponential backoff (#349)
Expand Down
33 changes: 33 additions & 0 deletions examples/custom_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from databricks import sql
import os
import logging


logger = logging.getLogger("databricks.sql")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler("pysqllogs.log")
fh.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(process)d %(thread)d %(message)s"))
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)

with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
use_cloud_fetch=True,
max_download_threads = 2
) as connection:

with connection.cursor(arraysize=1000, buffer_size_bytes=54857600) as cursor:
print(
"executing query: SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2"
)
cursor.execute("SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2")
try:
while True:
row = cursor.fetchone()
if row is None:
break
print(f"row: {row}")
except sql.exc.ResultSetDownloadError as e:
print(f"error: {e}")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "databricks-sql-connector"
version = "3.1.0"
version = "3.1.1"
description = "Databricks SQL Connector for Python"
authors = ["Databricks <[email protected]>"]
license = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion src/databricks/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __repr__(self):
DATE = DBAPITypeObject("date")
ROWID = DBAPITypeObject()

__version__ = "3.1.0"
__version__ = "3.1.1"
USER_AGENT_NAME = "PyDatabricksSqlConnector"

# These two functions are pyhive legacy
Expand Down
40 changes: 8 additions & 32 deletions src/databricks/sql/cloudfetch/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ResultSetDownloadHandler,
DownloadableResultSettings,
)
from databricks.sql.exc import ResultSetDownloadError
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)
Expand All @@ -34,8 +35,6 @@ def __init__(self, max_download_threads: int, lz4_compressed: bool):
self.download_handlers: List[ResultSetDownloadHandler] = []
self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1)
self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
self.fetch_need_retry = False
self.num_consecutive_result_file_download_retries = 0

def add_file_links(
self, t_spark_arrow_result_links: List[TSparkArrowResultLink]
Expand Down Expand Up @@ -81,13 +80,15 @@ def get_next_downloaded_file(

# Find next file
idx = self._find_next_file_index(next_row_offset)
# is this correct?
if idx is None:
self._shutdown_manager()
logger.debug("could not find next file index")
return None
handler = self.download_handlers[idx]

# Check (and wait) for download status
if self._check_if_download_successful(handler):
if handler.is_file_download_successful():
# Buffer should be empty so set buffer to new ArrowQueue with result_file
result = DownloadedFile(
handler.result_file,
Expand All @@ -97,9 +98,11 @@ def get_next_downloaded_file(
self.download_handlers.pop(idx)
# Return True upon successful download to continue loop and not force a retry
return result
# Download was not successful for next download item, force a retry
# Download was not successful for next download item. Fail
self._shutdown_manager()
return None
raise ResultSetDownloadError(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per the change in the comment above, there is no retry attempted?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or is it just handled by raising the exception?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the retry is done outside this function, closer to the actual http request

f"Download failed for result set starting at {next_row_offset}"
)

def _remove_past_handlers(self, next_row_offset: int):
# Any link in which its start to end range doesn't include the next row to be fetched does not need downloading
Expand Down Expand Up @@ -133,33 +136,6 @@ def _find_next_file_index(self, next_row_offset: int):
]
return next_indices[0] if len(next_indices) > 0 else None

def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
# Check (and wait until download finishes) if download was successful
if not handler.is_file_download_successful():
if handler.is_link_expired:
self.fetch_need_retry = True
return False
elif handler.is_download_timedout:
# Consecutive file retries should not exceed threshold in settings
if (
self.num_consecutive_result_file_download_retries
>= self.downloadable_result_settings.max_consecutive_file_download_retries
):
self.fetch_need_retry = True
return False
self.num_consecutive_result_file_download_retries += 1

# Re-submit handler run to thread pool and recursively check download status
self.thread_pool.submit(handler.run)
return self._check_if_download_successful(handler)
else:
self.fetch_need_retry = True
return False

self.num_consecutive_result_file_download_retries = 0
self.fetch_need_retry = False
return True

def _shutdown_manager(self):
# Clear download handlers and shutdown the thread pool
self.download_handlers = []
Expand Down
122 changes: 95 additions & 27 deletions src/databricks/sql/cloudfetch/downloader.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import logging
from dataclasses import dataclass

import requests
import lz4.frame
import threading
import time

import os
import re
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)

DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 180))


@dataclass
class DownloadableResultSettings:
Expand All @@ -20,13 +22,17 @@ class DownloadableResultSettings:
is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
download_timeout (int): Timeout for download requests. Default 60 secs.
max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
download_max_retries (int): Number of consecutive download retries before shutting down.
max_retries (int): Number of consecutive download retries before shutting down.
backoff_factor (int): Factor to increase wait time between retries.

"""

is_lz4_compressed: bool
link_expiry_buffer_secs: int = 0
download_timeout: int = 60
max_consecutive_file_download_retries: int = 0
download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT
max_retries: int = 5
backoff_factor: int = 2


class ResultSetDownloadHandler(threading.Thread):
Expand Down Expand Up @@ -57,16 +63,21 @@ def is_file_download_successful(self) -> bool:
else None
)
try:
logger.debug(
f"waiting for at most {timeout} seconds for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

if not self.is_download_finished.wait(timeout=timeout):
self.is_download_timedout = True
logger.debug(
"Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format(
self.settings.download_timeout,
self.result_link.startRowOffset,
self.result_link.startRowOffset + self.result_link.rowCount,
)
logger.error(
f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}"
)
return False
# there are some weird cases when the is_download_finished is not set, but the file is downloaded successfully
return self.is_file_downloaded_successfully

logger.debug(
f"finish waiting for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
except Exception as e:
logger.error(e)
return False
Expand All @@ -81,24 +92,36 @@ def run(self):
"""
self._reset()

# Check if link is already expired or is expiring
if ResultSetDownloadHandler.check_link_expired(
self.result_link, self.settings.link_expiry_buffer_secs
):
self.is_link_expired = True
return
try:
# Check if link is already expired or is expiring
if ResultSetDownloadHandler.check_link_expired(
self.result_link, self.settings.link_expiry_buffer_secs
):
self.is_link_expired = True
return

session = requests.Session()
session.timeout = self.settings.download_timeout
logger.debug(
f"started to download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

try:
# Get the file via HTTP request
response = session.get(self.result_link.fileLink)
response = http_get_with_retry(
url=self.result_link.fileLink,
max_retries=self.settings.max_retries,
backoff_factor=self.settings.backoff_factor,
download_timeout=self.settings.download_timeout,
)

if not response.ok:
self.is_file_downloaded_successfully = False
if not response:
logger.error(
f"failed downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
return

logger.debug(
f"success downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

# Save (and decompress if needed) the downloaded file
compressed_data = response.content
decompressed_data = (
Expand All @@ -109,15 +132,22 @@ def run(self):
self.result_file = decompressed_data

# The size of the downloaded file should match the size specified from TSparkArrowResultLink
self.is_file_downloaded_successfully = (
len(self.result_file) == self.result_link.bytesNum
success = len(self.result_file) == self.result_link.bytesNum
logger.debug(
f"download successful file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
self.is_file_downloaded_successfully = success
except Exception as e:
logger.error(
f"exception downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
logger.error(e)
self.is_file_downloaded_successfully = False

finally:
session and session.close()
logger.debug(
f"signal finished file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
# Awaken threads waiting for this to be true which signals the run is complete
self.is_download_finished.set()

Expand Down Expand Up @@ -145,6 +175,7 @@ def check_link_expired(
link.expiryTime < current_time
or link.expiryTime - current_time < expiry_buffer_secs
):
logger.debug("link expired")
return True
return False

Expand All @@ -171,3 +202,40 @@ def decompress_data(compressed_data: bytes) -> bytes:
uncompressed_data += data
start += num_bytes
return uncompressed_data


def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we implementing retry behavior here rather than using a Retry passed to the session?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed. It is in the TODO to also have connection pools

attempts = 0
pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)")
mask = r"\1\2=<REDACTED>"

# TODO: introduce connection pooling. I am seeing weird errors without it.
while attempts < max_retries:
try:
session = requests.Session()
session.timeout = download_timeout
response = session.get(url)

# Check if the response status code is in the 2xx range for success
if response.status_code == 200:
return response
else:
logger.error(response)
except Exception as e:
# if this is not redacted, it will print the pre-signed URL
logger.error(
f"request failed with exception: {re.sub(pattern, mask, str(e))}"
)
finally:
session.close()
# Exponential backoff before the next attempt
wait_time = backoff_factor**attempts
logger.info(f"retrying in {wait_time} seconds...")
time.sleep(wait_time)

attempts += 1

logger.error(
f"exceeded maximum number of retries ({max_retries}) while downloading result."
)
return None
4 changes: 4 additions & 0 deletions src/databricks/sql/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,7 @@ class SessionAlreadyClosedError(RequestError):

class CursorAlreadyClosedError(RequestError):
"""Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected."""


class ResultSetDownloadError(RequestError):
"""Thrown if there was an error during the download of a result set"""
Loading