-
Notifications
You must be signed in to change notification settings - Fork 105
Fix for cloud fetch #362
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Fix for cloud fetch #362
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from databricks import sql | ||
import os | ||
import logging | ||
|
||
|
||
logger = logging.getLogger("databricks.sql") | ||
logger.setLevel(logging.DEBUG) | ||
fh = logging.FileHandler("pysqllogs.log") | ||
fh.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(process)d %(thread)d %(message)s")) | ||
fh.setLevel(logging.DEBUG) | ||
logger.addHandler(fh) | ||
|
||
with sql.connect( | ||
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), | ||
http_path=os.getenv("DATABRICKS_HTTP_PATH"), | ||
access_token=os.getenv("DATABRICKS_TOKEN"), | ||
use_cloud_fetch=True, | ||
max_download_threads = 2 | ||
) as connection: | ||
|
||
with connection.cursor(arraysize=1000, buffer_size_bytes=54857600) as cursor: | ||
print( | ||
"executing query: SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2" | ||
) | ||
cursor.execute("SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2") | ||
try: | ||
while True: | ||
row = cursor.fetchone() | ||
if row is None: | ||
break | ||
print(f"row: {row}") | ||
except sql.exc.ResultSetDownloadError as e: | ||
print(f"error: {e}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "databricks-sql-connector" | ||
version = "3.1.0" | ||
version = "3.1.1" | ||
description = "Databricks SQL Connector for Python" | ||
authors = ["Databricks <[email protected]>"] | ||
license = "Apache-2.0" | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,17 @@ | ||
import logging | ||
from dataclasses import dataclass | ||
|
||
import requests | ||
import lz4.frame | ||
import threading | ||
import time | ||
|
||
import os | ||
import re | ||
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 180)) | ||
|
||
|
||
@dataclass | ||
class DownloadableResultSettings: | ||
|
@@ -20,13 +22,17 @@ class DownloadableResultSettings: | |
is_lz4_compressed (bool): Whether file is expected to be lz4 compressed. | ||
link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs. | ||
download_timeout (int): Timeout for download requests. Default 60 secs. | ||
max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down. | ||
download_max_retries (int): Number of consecutive download retries before shutting down. | ||
max_retries (int): Number of consecutive download retries before shutting down. | ||
backoff_factor (int): Factor to increase wait time between retries. | ||
|
||
""" | ||
|
||
is_lz4_compressed: bool | ||
link_expiry_buffer_secs: int = 0 | ||
download_timeout: int = 60 | ||
max_consecutive_file_download_retries: int = 0 | ||
download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT | ||
max_retries: int = 5 | ||
backoff_factor: int = 2 | ||
|
||
|
||
class ResultSetDownloadHandler(threading.Thread): | ||
|
@@ -57,16 +63,21 @@ def is_file_download_successful(self) -> bool: | |
else None | ||
) | ||
try: | ||
logger.debug( | ||
f"waiting for at most {timeout} seconds for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" | ||
) | ||
|
||
if not self.is_download_finished.wait(timeout=timeout): | ||
self.is_download_timedout = True | ||
logger.debug( | ||
"Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format( | ||
self.settings.download_timeout, | ||
self.result_link.startRowOffset, | ||
self.result_link.startRowOffset + self.result_link.rowCount, | ||
) | ||
logger.error( | ||
f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}" | ||
) | ||
return False | ||
# there are some weird cases when the is_download_finished is not set, but the file is downloaded successfully | ||
return self.is_file_downloaded_successfully | ||
|
||
logger.debug( | ||
f"finish waiting for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" | ||
) | ||
except Exception as e: | ||
logger.error(e) | ||
return False | ||
|
@@ -81,24 +92,36 @@ def run(self): | |
""" | ||
self._reset() | ||
|
||
# Check if link is already expired or is expiring | ||
if ResultSetDownloadHandler.check_link_expired( | ||
self.result_link, self.settings.link_expiry_buffer_secs | ||
): | ||
self.is_link_expired = True | ||
return | ||
try: | ||
# Check if link is already expired or is expiring | ||
if ResultSetDownloadHandler.check_link_expired( | ||
self.result_link, self.settings.link_expiry_buffer_secs | ||
): | ||
self.is_link_expired = True | ||
return | ||
|
||
session = requests.Session() | ||
session.timeout = self.settings.download_timeout | ||
logger.debug( | ||
f"started to download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" | ||
) | ||
|
||
try: | ||
# Get the file via HTTP request | ||
response = session.get(self.result_link.fileLink) | ||
response = http_get_with_retry( | ||
url=self.result_link.fileLink, | ||
max_retries=self.settings.max_retries, | ||
backoff_factor=self.settings.backoff_factor, | ||
download_timeout=self.settings.download_timeout, | ||
) | ||
|
||
if not response.ok: | ||
self.is_file_downloaded_successfully = False | ||
if not response: | ||
logger.error( | ||
f"failed downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" | ||
) | ||
return | ||
|
||
logger.debug( | ||
f"success downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" | ||
) | ||
|
||
# Save (and decompress if needed) the downloaded file | ||
compressed_data = response.content | ||
decompressed_data = ( | ||
|
@@ -109,15 +132,22 @@ def run(self): | |
self.result_file = decompressed_data | ||
|
||
# The size of the downloaded file should match the size specified from TSparkArrowResultLink | ||
self.is_file_downloaded_successfully = ( | ||
len(self.result_file) == self.result_link.bytesNum | ||
success = len(self.result_file) == self.result_link.bytesNum | ||
logger.debug( | ||
f"download successful file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" | ||
) | ||
self.is_file_downloaded_successfully = success | ||
except Exception as e: | ||
logger.error( | ||
f"exception downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" | ||
) | ||
logger.error(e) | ||
self.is_file_downloaded_successfully = False | ||
|
||
finally: | ||
session and session.close() | ||
logger.debug( | ||
f"signal finished file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" | ||
) | ||
# Awaken threads waiting for this to be true which signals the run is complete | ||
self.is_download_finished.set() | ||
|
||
|
@@ -145,6 +175,7 @@ def check_link_expired( | |
link.expiryTime < current_time | ||
or link.expiryTime - current_time < expiry_buffer_secs | ||
): | ||
logger.debug("link expired") | ||
return True | ||
return False | ||
|
||
|
@@ -171,3 +202,40 @@ def decompress_data(compressed_data: bytes) -> bytes: | |
uncompressed_data += data | ||
start += num_bytes | ||
return uncompressed_data | ||
|
||
|
||
def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60): | ||
def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Reviewer comment: why are we implementing retry behavior here rather than using a Retry passed to the session?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Author reply: agreed. It is in the TODO to also have connection pools. |
||
attempts = 0 | ||
pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)") | ||
mask = r"\1\2=<REDACTED>" | ||
|
||
# TODO: introduce connection pooling. I am seeing weird errors without it. | ||
while attempts < max_retries: | ||
try: | ||
session = requests.Session() | ||
session.timeout = download_timeout | ||
response = session.get(url) | ||
|
||
# Check if the response status code is in the 2xx range for success | ||
if response.status_code == 200: | ||
return response | ||
else: | ||
logger.error(response) | ||
except Exception as e: | ||
# if this is not redacted, it will print the pre-signed URL | ||
logger.error( | ||
f"request failed with exception: {re.sub(pattern, mask, str(e))}" | ||
) | ||
finally: | ||
session.close() | ||
# Exponential backoff before the next attempt | ||
wait_time = backoff_factor**attempts | ||
logger.info(f"retrying in {wait_time} seconds...") | ||
time.sleep(wait_time) | ||
|
||
attempts += 1 | ||
|
||
logger.error( | ||
f"exceeded maximum number of retries ({max_retries}) while downloading result." | ||
) | ||
return None |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Per the change in the comment above, is there no retry attempted?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or is it just handled by raising the exception?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the retry is done outside this function, closer to the actual http request