From 7bc38d6e75ebe7ccb9a770c11c94ef9ee10952f8 Mon Sep 17 00:00:00 2001 From: Andre Furlan Date: Fri, 16 Feb 2024 10:32:29 -0800 Subject: [PATCH 1/2] fixes for cloud fetch - part un (#356) * fixes for cloud fetch Signed-off-by: Andre Furlan --------- Signed-off-by: Andre Furlan Co-authored-by: Raymond Cypher --- examples/custom_logger.py | 33 +++++ src/databricks/sql/__init__.py | 32 +++++ .../sql/cloudfetch/download_manager.py | 40 ++---- src/databricks/sql/cloudfetch/downloader.py | 120 ++++++++++++++---- src/databricks/sql/exc.py | 4 + src/databricks/sql/thrift_backend.py | 16 ++- src/databricks/sql/utils.py | 8 ++ tests/unit/test_download_manager.py | 69 +--------- tests/unit/test_downloader.py | 50 ++++++-- 9 files changed, 229 insertions(+), 143 deletions(-) create mode 100644 examples/custom_logger.py diff --git a/examples/custom_logger.py b/examples/custom_logger.py new file mode 100644 index 00000000..8475c118 --- /dev/null +++ b/examples/custom_logger.py @@ -0,0 +1,33 @@ +from databricks import sql +import os +import logging + + +logger = logging.getLogger("databricks.sql") +logger.setLevel(logging.DEBUG) +fh = logging.FileHandler("pysqllogs.log") +fh.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(process)d %(thread)d %(message)s")) +fh.setLevel(logging.DEBUG) +logger.addHandler(fh) + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), + use_cloud_fetch=True, + max_download_threads = 2 +) as connection: + + with connection.cursor(arraysize=1000, buffer_size_bytes=54857600) as cursor: + print( + "executing query: SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2" + ) + cursor.execute("SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2") + try: + while True: + row = cursor.fetchone() + if row is None: + break + print(f"row: {row}") + except sql.exc.ResultSetDownloadError as e: + print(f"error: {e}") diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 004250f6..e1bc2e91 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -41,6 +41,38 @@ def filter(self, record): logging.getLogger("urllib3.connectionpool").addFilter(RedactUrlQueryParamsFilter()) +import re + + +class RedactUrlQueryParamsFilter(logging.Filter): + pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)") + mask = r"\1\2=" + + def __init__(self): + super().__init__() + + def redact(self, string): + return re.sub(self.pattern, self.mask, str(string)) + + def filter(self, record): + record.msg = self.redact(str(record.msg)) + if isinstance(record.args, dict): + for k in record.args.keys(): + record.args[k] = ( + self.redact(record.args[k]) + if isinstance(record.arg[k], str) + else record.args[k] + ) + else: + record.args = tuple( + (self.redact(arg) if isinstance(arg, str) else arg) + for arg in record.args + ) + + return True + + +logging.getLogger("urllib3.connectionpool").addFilter(RedactUrlQueryParamsFilter()) class DBAPITypeObject(object): def __init__(self, *values): diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 9a997f39..015d00bf 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -8,6 +8,7 @@ ResultSetDownloadHandler, DownloadableResultSettings, ) +from databricks.sql.exc import ResultSetDownloadError from databricks.sql.thrift_api.TCLIService.ttypes import 
TSparkArrowResultLink logger = logging.getLogger(__name__) @@ -34,8 +35,6 @@ def __init__(self, max_download_threads: int, lz4_compressed: bool): self.download_handlers: List[ResultSetDownloadHandler] = [] self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1) self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed) - self.fetch_need_retry = False - self.num_consecutive_result_file_download_retries = 0 def add_file_links( self, t_spark_arrow_result_links: List[TSparkArrowResultLink] @@ -81,13 +80,15 @@ def get_next_downloaded_file( # Find next file idx = self._find_next_file_index(next_row_offset) + # is this correct? if idx is None: self._shutdown_manager() + logger.debug("could not find next file index") return None handler = self.download_handlers[idx] # Check (and wait) for download status - if self._check_if_download_successful(handler): + if handler.is_file_download_successful(): # Buffer should be empty so set buffer to new ArrowQueue with result_file result = DownloadedFile( handler.result_file, @@ -97,9 +98,11 @@ def get_next_downloaded_file( self.download_handlers.pop(idx) # Return True upon successful download to continue loop and not force a retry return result - # Download was not successful for next download item, force a retry + # Download was not successful for next download item. Fail self._shutdown_manager() - return None + raise ResultSetDownloadError( + f"Download failed for result set starting at {next_row_offset}" + ) def _remove_past_handlers(self, next_row_offset: int): # Any link in which its start to end range doesn't include the next row to be fetched does not need downloading @@ -133,33 +136,6 @@ def _find_next_file_index(self, next_row_offset: int): ] return next_indices[0] if len(next_indices) > 0 else None - def _check_if_download_successful(self, handler: ResultSetDownloadHandler): - # Check (and wait until download finishes) if download was successful - if not handler.is_file_download_successful(): - if handler.is_link_expired: - self.fetch_need_retry = True - return False - elif handler.is_download_timedout: - # Consecutive file retries should not exceed threshold in settings - if ( - self.num_consecutive_result_file_download_retries - >= self.downloadable_result_settings.max_consecutive_file_download_retries - ): - self.fetch_need_retry = True - return False - self.num_consecutive_result_file_download_retries += 1 - - # Re-submit handler run to thread pool and recursively check download status - self.thread_pool.submit(handler.run) - return self._check_if_download_successful(handler) - else: - self.fetch_need_retry = True - return False - - self.num_consecutive_result_file_download_retries = 0 - self.fetch_need_retry = False - return True - def _shutdown_manager(self): # Clear download handlers and shutdown the thread pool self.download_handlers = [] diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 019c4ef9..acfe0bc2 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -1,15 +1,17 @@ import logging from dataclasses import dataclass - import requests import lz4.frame import threading import time - +import os +import re from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink logger = logging.getLogger(__name__) +DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60)) + @dataclass class DownloadableResultSettings: @@ -20,13 +22,17 @@ class 
DownloadableResultSettings: is_lz4_compressed (bool): Whether file is expected to be lz4 compressed. link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs. download_timeout (int): Timeout for download requests. Default 60 secs. - max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down. + download_max_retries (int): Number of consecutive download retries before shutting down. + max_retries (int): Number of consecutive download retries before shutting down. + backoff_factor (int): Factor to increase wait time between retries. + """ is_lz4_compressed: bool link_expiry_buffer_secs: int = 0 - download_timeout: int = 60 - max_consecutive_file_download_retries: int = 0 + download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT + max_retries: int = 5 + backoff_factor: int = 2 class ResultSetDownloadHandler(threading.Thread): @@ -57,16 +63,21 @@ def is_file_download_successful(self) -> bool: else None ) try: + logger.debug( + f"waiting for at most {timeout} seconds for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" + ) + if not self.is_download_finished.wait(timeout=timeout): self.is_download_timedout = True - logger.debug( - "Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format( - self.settings.download_timeout, - self.result_link.startRowOffset, - self.result_link.startRowOffset + self.result_link.rowCount, - ) + logger.error( + f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}" ) - return False + # there are some weird cases when the is_download_finished is not set, but the file is downloaded successfully + return self.is_file_downloaded_successfully + + logger.debug( + f"finish waiting for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" + ) except Exception as e: logger.error(e) return False @@ -81,24 +92,36 @@ def run(self): """ self._reset() - # Check if link is already expired or is expiring - if ResultSetDownloadHandler.check_link_expired( - self.result_link, self.settings.link_expiry_buffer_secs - ): - self.is_link_expired = True - return + try: + # Check if link is already expired or is expiring + if ResultSetDownloadHandler.check_link_expired( + self.result_link, self.settings.link_expiry_buffer_secs + ): + self.is_link_expired = True + return - session = requests.Session() - session.timeout = self.settings.download_timeout + logger.debug( + f"started to download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" + ) - try: # Get the file via HTTP request - response = session.get(self.result_link.fileLink) + response = http_get_with_retry( + url=self.result_link.fileLink, + max_retries=self.settings.max_retries, + backoff_factor=self.settings.backoff_factor, + download_timeout=self.settings.download_timeout, + ) - if not response.ok: - self.is_file_downloaded_successfully = False + if not response: + logger.error( + f"failed downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow 
{self.result_link.startRowOffset + self.result_link.rowCount}" + ) return + logger.debug( + f"success downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" + ) + # Save (and decompress if needed) the downloaded file compressed_data = response.content decompressed_data = ( @@ -109,15 +132,22 @@ def run(self): self.result_file = decompressed_data # The size of the downloaded file should match the size specified from TSparkArrowResultLink - self.is_file_downloaded_successfully = ( - len(self.result_file) == self.result_link.bytesNum + success = len(self.result_file) == self.result_link.bytesNum + logger.debug( + f"download successful file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" ) + self.is_file_downloaded_successfully = success except Exception as e: + logger.error( + f"exception downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" + ) logger.error(e) self.is_file_downloaded_successfully = False finally: - session and session.close() + logger.debug( + f"signal finished file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" + ) # Awaken threads waiting for this to be true which signals the run is complete self.is_download_finished.set() @@ -145,6 +175,7 @@ def check_link_expired( link.expiryTime < current_time or link.expiryTime - current_time < expiry_buffer_secs ): + logger.debug("link expired") return True return False @@ -171,3 +202,38 @@ def decompress_data(compressed_data: bytes) -> bytes: uncompressed_data += data start += num_bytes return uncompressed_data + + +def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60): + attempts = 0 + pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)") + mask = r"\1\2=" + + # TODO: introduce connection pooling. I am seeing weird errors without it. + while attempts < max_retries: + try: + session = requests.Session() + session.timeout = download_timeout + response = session.get(url) + + # Check if the response status code is in the 2xx range for success + if response.status_code == 200: + return response + else: + logger.error(response) + except requests.RequestException as e: + # if this is not redacted, it will print the pre-signed URL + logger.error(f"request failed with exception: {re.sub(pattern, mask, str(e))}") + finally: + session.close() + # Exponential backoff before the next attempt + wait_time = backoff_factor**attempts + logger.info(f"retrying in {wait_time} seconds...") + time.sleep(wait_time) + + attempts += 1 + + logger.error( + f"exceeded maximum number of retries ({max_retries}) while downloading result." + ) + return None diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 3b27283a..5a0086ce 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -115,3 +115,7 @@ class SessionAlreadyClosedError(RequestError): class CursorAlreadyClosedError(RequestError): """Thrown if CancelOperation receives a code 404. 
ThriftBackend should gracefully proceed as this is expected.""" + + +class ResultSetDownloadError(RequestError): + """Thrown if there was an error during the download of a result set""" diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 69ac760a..2ed7d56e 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -393,7 +393,9 @@ def attempt_request(attempt): try: this_method_name = getattr(method, "__name__") - logger.debug("Sending request: {}()".format(this_method_name)) + logger.debug( + "sending thrift request: {}()".format(this_method_name) + ) unsafe_logger.debug("Sending request: {}".format(request)) # These three lines are no-ops if the v3 retry policy is not in use @@ -406,7 +408,9 @@ def attempt_request(attempt): # We need to call type(response) here because thrift doesn't implement __name__ attributes for thrift responses logger.debug( - "Received response: {}()".format(type(response).__name__) + "received thrift response: {}()".format( + type(response).__name__ + ) ) unsafe_logger.debug("Received response: {}".format(response)) return response @@ -764,6 +768,7 @@ def _results_message_to_execute_response(self, resp, operation_state): lz4_compressed = t_result_set_metadata_resp.lz4Compressed is_staging_operation = t_result_set_metadata_resp.isStagingOperation if direct_results and direct_results.resultSet: + logger.debug(f"received direct results") assert direct_results.resultSet.results.startRowOffset == 0 assert direct_results.resultSetMetadata @@ -776,6 +781,7 @@ def _results_message_to_execute_response(self, resp, operation_state): description=description, ) else: + logger.debug(f"must fetch results") arrow_queue_opt = None return ExecuteResponse( arrow_queue=arrow_queue_opt, @@ -840,6 +846,10 @@ def execute_command( ): assert session_handle is not None + logger.debug( + f"executing: cloud fetch: {use_cloud_fetch}, max rows: {max_rows}, max bytes: {max_bytes}" + ) + spark_arrow_types = ttypes.TSparkArrowTypes( timestampAsArrow=self._use_arrow_native_timestamps, decimalAsArrow=self._use_arrow_native_decimals, @@ -955,6 +965,7 @@ def get_columns( return self._handle_execute_response(resp, cursor) def _handle_execute_response(self, resp, cursor): + logger.debug(f"got execute response") cursor.active_op_handle = resp.operationHandle self._check_direct_results_for_error(resp.directResults) @@ -975,6 +986,7 @@ def fetch_results( arrow_schema_bytes, description, ): + logger.debug("started to fetch results") assert op_handle is not None req = ttypes.TFetchResultsReq( diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 7c3a014b..0d0caf87 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -10,6 +10,7 @@ from typing import Any, Dict, List, Union import re +import logging import lz4.frame import pyarrow @@ -24,6 +25,7 @@ from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] +logger = logging.getLogger(__name__) import logging @@ -81,6 +83,9 @@ def build_queue( ) return ArrowQueue(converted_arrow_table, n_valid_rows) elif row_set_type == TSparkRowSetType.URL_BASED_SET: + logger.debug( + f"built cloud fetch queue for {len(t_row_set.resultLinks)} links." 
+ ) return CloudFetchQueue( arrow_schema_bytes, start_row_offset=t_row_set.startRowOffset, @@ -156,6 +161,9 @@ def __init__( self.lz4_compressed = lz4_compressed self.description = description + logger.debug( + f"creating cloud fetch queue for {len(result_links)} links and max_download_threads {self.max_download_threads}." + ) self.download_manager = ResultFileDownloadManager( self.max_download_threads, self.lz4_compressed ) diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index 97bf407a..1bad7f21 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -137,71 +137,4 @@ def test_find_next_file_index_one_scheduled_next_row_8000(self, mock_submit): manager.add_file_links(links) manager._schedule_downloads() - assert manager._find_next_file_index(8000) is None - - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=True) - @patch("concurrent.futures.ThreadPoolExecutor.submit") - def test_check_if_download_successful_happy(self, mock_submit, mock_is_file_download_successful): - links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) - manager._schedule_downloads() - - status = manager._check_if_download_successful(manager.download_handlers[0]) - assert status - assert manager.num_consecutive_result_file_download_retries == 0 - - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=False) - def test_check_if_download_successful_link_expired(self, mock_is_file_download_successful): - manager = self.create_download_manager() - handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) - handler.is_link_expired = True - - status = manager._check_if_download_successful(handler) - mock_is_file_download_successful.assert_called() - assert not status - assert manager.fetch_need_retry - - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=False) - def test_check_if_download_successful_download_timed_out_no_retries(self, mock_is_file_download_successful): - manager = self.create_download_manager() - handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) - handler.is_download_timedout = True - - status = manager._check_if_download_successful(handler) - mock_is_file_download_successful.assert_called() - assert not status - assert manager.fetch_need_retry - - @patch("concurrent.futures.ThreadPoolExecutor.submit") - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=False) - def test_check_if_download_successful_download_timed_out_1_retry(self, mock_is_file_download_successful, mock_submit): - manager = self.create_download_manager() - manager.downloadable_result_settings = download_manager.DownloadableResultSettings( - is_lz4_compressed=True, - download_timeout=0, - max_consecutive_file_download_retries=1, - ) - handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) - handler.is_download_timedout = True - - status = manager._check_if_download_successful(handler) - assert mock_is_file_download_successful.call_count == 2 - assert mock_submit.call_count == 1 - assert not status - assert manager.fetch_need_retry - - 
@patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=False) - def test_check_if_download_successful_other_reason(self, mock_is_file_download_successful): - manager = self.create_download_manager() - handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) - - status = manager._check_if_download_successful(handler) - mock_is_file_download_successful.assert_called() - assert not status - assert manager.fetch_need_retry + assert manager._find_next_file_index(8000) is None \ No newline at end of file diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 6e13c949..ebb7149e 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -13,18 +13,21 @@ class DownloaderTests(unittest.TestCase): def test_run_link_expired(self, mock_time): settings = Mock() result_link = Mock() + result_link.startRowOffset = 0 + result_link.rowCount = 100 # Already expired result_link.expiryTime = 999 d = downloader.ResultSetDownloadHandler(settings, result_link) assert not d.is_link_expired d.run() assert d.is_link_expired - mock_time.assert_called_once() @patch('time.time', return_value=1000) def test_run_link_past_expiry_buffer(self, mock_time): settings = Mock(link_expiry_buffer_secs=5) result_link = Mock() + result_link.startRowOffset = 0 + result_link.rowCount = 100 # Within the expiry buffer time result_link.expiryTime = 1004 d = downloader.ResultSetDownloadHandler(settings, result_link) @@ -33,13 +36,15 @@ def test_run_link_past_expiry_buffer(self, mock_time): assert d.is_link_expired mock_time.assert_called_once() - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=False)))) + @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=False, status_code=500)))) @patch('time.time', return_value=1000) def test_run_get_response_not_ok(self, mock_time, mock_session): - settings = Mock(link_expiry_buffer_secs=0, download_timeout=0) + settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, max_retries = 5, backoff_factor = 2) settings.download_timeout = 0 settings.use_proxy = False result_link = Mock(expiryTime=1001) + result_link.startRowOffset = 0 + result_link.rowCount = 100 d = downloader.ResultSetDownloadHandler(settings, result_link) d.run() @@ -48,11 +53,13 @@ def test_run_get_response_not_ok(self, mock_time, mock_session): assert d.is_download_finished.is_set() @patch('requests.Session', - return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, content=b"1234567890" * 9)))) + return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, status_code=200, content=b"1234567890" * 9)))) @patch('time.time', return_value=1000) def test_run_uncompressed_data_length_incorrect(self, mock_time, mock_session): - settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, is_lz4_compressed=False) + settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, is_lz4_compressed=False, max_retries = 5, backoff_factor = 2) result_link = Mock(bytesNum=100, expiryTime=1001) + result_link.startRowOffset = 0 + result_link.rowCount = 100 d = downloader.ResultSetDownloadHandler(settings, result_link) d.run() @@ -60,12 +67,14 @@ def test_run_uncompressed_data_length_incorrect(self, mock_time, mock_session): assert not d.is_file_downloaded_successfully assert d.is_download_finished.is_set() - @patch('requests.Session', 
return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True)))) + @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, status_code=200)))) @patch('time.time', return_value=1000) def test_run_compressed_data_length_incorrect(self, mock_time, mock_session): - settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) + settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, max_retries = 5, backoff_factor = 2) settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) + result_link.startRowOffset = 0 + result_link.rowCount = 100 mock_session.return_value.get.return_value.content = \ b'\x04"M\x18h@Z\x00\x00\x00\x00\x00\x00\x00\xec\x14\x00\x00\x00\xaf1234567890\n\x008P67890\x00\x00\x00\x00' @@ -76,13 +85,14 @@ def test_run_compressed_data_length_incorrect(self, mock_time, mock_session): assert d.is_download_finished.is_set() @patch('requests.Session', - return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, content=b"1234567890" * 10)))) + return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, status_code=200, content=b"1234567890" * 10)))) @patch('time.time', return_value=1000) def test_run_uncompressed_successful(self, mock_time, mock_session): - settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) + settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, max_retries = 5, backoff_factor = 2) settings.is_lz4_compressed = False result_link = Mock(bytesNum=100, expiryTime=1001) - + result_link.startRowOffset = 0 + result_link.rowCount = 100 d = downloader.ResultSetDownloadHandler(settings, result_link) d.run() @@ -90,14 +100,20 @@ def test_run_uncompressed_successful(self, mock_time, mock_session): assert d.is_file_downloaded_successfully assert d.is_download_finished.is_set() - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True)))) + @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock( + ok=True, + content=b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00', + status_code=200 + )))) @patch('time.time', return_value=1000) def test_run_compressed_successful(self, mock_time, mock_session): - settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) + settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, max_retries = 5, backoff_factor = 2) settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) - mock_session.return_value.get.return_value.content = \ - b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' + result_link.startRowOffset = 0 + result_link.rowCount = 100 + # mock_session.return_value.get.return_value.content = \ + # b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' d = downloader.ResultSetDownloadHandler(settings, result_link) d.run() @@ -111,6 +127,8 @@ def test_run_compressed_successful(self, mock_time, mock_session): def test_download_connection_error(self, mock_time, mock_session): settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True) result_link = Mock(bytesNum=100, expiryTime=1001) + result_link.startRowOffset = 0 + result_link.rowCount = 100 mock_session.return_value.get.return_value.content = \ 
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' @@ -125,6 +143,8 @@ def test_download_connection_error(self, mock_time, mock_session): def test_download_timeout(self, mock_time, mock_session): settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True) result_link = Mock(bytesNum=100, expiryTime=1001) + result_link.startRowOffset = 0 + result_link.rowCount = 100 mock_session.return_value.get.return_value.content = \ b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' @@ -148,6 +168,8 @@ def test_is_file_download_successful_has_finished(self, mock_wait): def test_is_file_download_successful_times_outs(self): settings = Mock(download_timeout=1) result_link = Mock() + result_link.startRowOffset = 0 + result_link.rowCount = 100 handler = downloader.ResultSetDownloadHandler(settings, result_link) status = handler.is_file_download_successful() From cef6cf34daa264ba756aff33de1a6a1d38fe8bb1 Mon Sep 17 00:00:00 2001 From: Andre Furlan Date: Wed, 21 Feb 2024 11:00:40 -0800 Subject: [PATCH 2/2] bump version to 3.1.1 Signed-off-by: Andre Furlan --- CHANGELOG.md | 4 +++ pyproject.toml | 2 +- src/databricks/sql/__init__.py | 34 +-------------------- src/databricks/sql/cloudfetch/downloader.py | 8 +++-- src/databricks/sql/utils.py | 3 ++ 5 files changed, 14 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c972ab6..b830a251 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Release History +# 3.1.1 (2024-02-21) + +- Fix: Cloud fetch file download errors (#356) + # 3.1.0 (2024-02-16) - Revert retry-after behavior to be exponential backoff (#349) diff --git a/pyproject.toml b/pyproject.toml index 784954fc..fae61739 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "3.1.0" +version = "3.1.1" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index e1bc2e91..2cc54e25 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -41,38 +41,6 @@ def filter(self, record): logging.getLogger("urllib3.connectionpool").addFilter(RedactUrlQueryParamsFilter()) -import re - - -class RedactUrlQueryParamsFilter(logging.Filter): - pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)") - mask = r"\1\2=" - - def __init__(self): - super().__init__() - - def redact(self, string): - return re.sub(self.pattern, self.mask, str(string)) - - def filter(self, record): - record.msg = self.redact(str(record.msg)) - if isinstance(record.args, dict): - for k in record.args.keys(): - record.args[k] = ( - self.redact(record.args[k]) - if isinstance(record.arg[k], str) - else record.args[k] - ) - else: - record.args = tuple( - (self.redact(arg) if isinstance(arg, str) else arg) - for arg in record.args - ) - - return True - - -logging.getLogger("urllib3.connectionpool").addFilter(RedactUrlQueryParamsFilter()) class DBAPITypeObject(object): def __init__(self, *values): @@ -94,7 +62,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "3.1.0" +__version__ = "3.1.1" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index acfe0bc2..f553a7df 100644 --- 
a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60)) +DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 180)) @dataclass @@ -221,9 +221,11 @@ def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=6 return response else: logger.error(response) - except requests.RequestException as e: + except Exception as e: # if this is not redacted, it will print the pre-signed URL - logger.error(f"request failed with exception: {re.sub(pattern, mask, str(e))}") + logger.error( + f"request failed with exception: {re.sub(pattern, mask, str(e))}" + ) finally: session.close() # Exponential backoff before the next attempt diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 0d0caf87..91cc2309 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -236,6 +236,9 @@ def _create_next_table(self) -> Union[pyarrow.Table, None]: # The server rarely prepares the exact number of rows requested by the client in cloud fetch. # Subsequently, we drop the extraneous rows in the last file if more rows are retrieved than requested if arrow_table.num_rows > downloaded_file.row_count: + logger.debug( + f"received {arrow_table.num_rows} rows, expected {downloaded_file.row_count} rows. Dropping extraneous rows." + ) self.start_row_index += downloaded_file.row_count return arrow_table.slice(0, downloaded_file.row_count)
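
The redaction changes introduced in `src/databricks/sql/__init__.py` by the first commit amount to a standard `logging.Filter` that masks query-string values (for example the signature of a pre-signed cloud-fetch URL) before a record is emitted. The sketch below is a minimal, self-contained rendering of that pattern, not the connector's code: it assumes only the standard `logging` and `re` modules, the class name mirrors the patch, and the mask string is illustrative. It treats `record.args` uniformly whether it arrives as a mapping or a tuple.

```python
import logging
import re


class RedactUrlQueryParamsFilter(logging.Filter):
    """Mask query-string values (e.g. pre-signed URL signatures) in log records."""

    pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)")
    mask = r"\1\2="  # keep the parameter name, drop its value

    def redact(self, value):
        return self.pattern.sub(self.mask, str(value))

    def filter(self, record):
        record.msg = self.redact(record.msg)
        if isinstance(record.args, dict):
            record.args = {
                k: self.redact(v) if isinstance(v, str) else v
                for k, v in record.args.items()
            }
        elif record.args:
            record.args = tuple(
                self.redact(a) if isinstance(a, str) else a for a in record.args
            )
        return True


# Attach the filter to the logger that would otherwise echo full request URLs.
logging.getLogger("urllib3.connectionpool").addFilter(RedactUrlQueryParamsFilter())
```

Because `record.args` may be either a mapping or a tuple (the `logging` module allows both), the filter branches on the type and redacts only string values, leaving everything else untouched.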
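
The retry behaviour added to `downloader.py` can be summarised outside the connector. The following is a minimal, self-contained approximation of the `http_get_with_retry` helper, assuming only the `requests` package; the function name and parameters mirror the patch, but this is a sketch rather than the connector's actual code path. Each failed attempt waits `backoff_factor ** attempt` seconds, query parameters are redacted before any error is logged, and `None` is returned once `max_retries` attempts are exhausted.

```python
import logging
import re
import time

import requests

logger = logging.getLogger(__name__)

# Matches "?key=value" or "&key=value" so signed query parameters can be masked.
_QUERY_PARAM = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)")


def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
    """Fetch a URL with exponential backoff; return the response or None."""
    for attempt in range(max_retries):
        try:
            with requests.Session() as session:
                response = session.get(url, timeout=download_timeout)
                if response.status_code == 200:
                    return response
                # Log the failing status without echoing the pre-signed URL.
                logger.error(
                    "attempt %d failed with status %s", attempt, response.status_code
                )
        except requests.RequestException as e:
            # Redact query parameters so the pre-signed URL never reaches the logs.
            logger.error("request failed: %s", _QUERY_PARAM.sub(r"\1\2=", str(e)))

        wait_time = backoff_factor ** attempt  # 1s, 2s, 4s, ... between attempts
        logger.info("retrying in %d seconds...", wait_time)
        time.sleep(wait_time)

    logger.error("exceeded maximum number of retries (%d)", max_retries)
    return None
```

One design note: `requests` only honours a timeout passed per request, so the sketch passes `timeout=` to `session.get()` rather than assigning it to the session object, and the `with` block ensures the session is closed on every attempt.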