From dbce25113ebcba3eb724f279cf5b993e84da6f63 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Wed, 14 Jun 2023 10:56:26 -0700 Subject: [PATCH 01/12] Cloud Fetch download manager Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- .../sql/cloudfetch/download_manager.py | 139 +++++++++++ tests/unit/test_download_manager.py | 215 ++++++++++++++++++ 2 files changed, 354 insertions(+) create mode 100644 src/databricks/sql/cloudfetch/download_manager.py create mode 100644 tests/unit/test_download_manager.py diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py new file mode 100644 index 00000000..447ea616 --- /dev/null +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -0,0 +1,139 @@ +from collections import namedtuple +from concurrent.futures import ThreadPoolExecutor +from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler + + +class ResultFileDownloadManager: + + def __init__(self, max_download_threads, lz4_compressed): + self.download_handlers = [] + self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1) + self.downloadable_result_settings = _get_downloadable_result_settings(lz4_compressed) + self.fetch_need_retry = False + self.num_consecutive_result_file_download_retries = 0 + self.cloud_fetch_index = 0 + + def add_file_links(self, t_spark_arrow_result_links, next_row_index): + for t_spark_arrow_result_link in t_spark_arrow_result_links: + if t_spark_arrow_result_link.rowCount <= 0: + continue + self.download_handlers.append(ResultSetDownloadHandler( + self.downloadable_result_settings, t_spark_arrow_result_link)) + self.cloud_fetch_index = next_row_index + + def get_next_downloaded_file(self, next_row_index): + if not self.download_handlers: + return None + + # Remove handlers we don't need anymore + self._remove_past_handlers(next_row_index) + + # Schedule the downloads + self._schedule_downloads() + + # Find next file + idx = self._find_next_file_index(next_row_index) + if idx is None: + return None + handler = self.download_handlers[idx] + + # Check (and wait) for download status + if self._check_if_download_successful(handler): + # Buffer should be empty so set buffer to new ArrowQueue with result_file + result = DownloadedFile( + handler.result_file, + handler.result_link.startRowOffset, + handler.result_link.rowCount, + ) + self.cloud_fetch_index += handler.result_link.rowCount + self.download_handlers.pop(idx) + # Return True upon successful download to continue loop and not force a retry + return result + # Download was not successful for next download item, force a retry + return None + + def _remove_past_handlers(self, next_row_index): + """ + Remove any download handlers whose start to end range doesn't include the next row to be fetched + i.e. 
no need to download + """ + i = 0 + while i < len(self.download_handlers): + result_link = self.download_handlers[i].result_link + if result_link.startRowOffset + result_link.rowCount > next_row_index: + i += 1 + continue + self.download_handlers.pop(i) + + def _schedule_downloads(self): + """ + Schedule downloads for all download handlers if not already scheduled + """ + for handler in self.download_handlers: + if handler.is_download_scheduled: + continue + try: + self.thread_pool.submit(handler) + except: + break + handler.is_download_scheduled = True + + def _find_next_file_index(self, next_row_index): + # Get the next downloaded file + next_indices = [i for i, handler in enumerate(self.download_handlers) + if handler.is_download_scheduled and handler.result_link.startRowOffset == next_row_index] + return next_indices[0] if len(next_indices) > 0 else None + + def _check_if_download_successful(self, handler): + if not handler.is_file_download_successful(): + if handler.is_link_expired: + self._stop_all_downloads_and_clear_handlers() + self.fetch_need_retry = True + return False + elif handler.is_download_timedout: + if self.num_consecutive_result_file_download_retries >= \ + self.downloadable_result_settings.max_consecutive_file_download_retries: + # raise Exception("File download exceeded max retry limit") + self.fetch_need_retry = True + return False + self.num_consecutive_result_file_download_retries += 1 + self.thread_pool.submit(handler) + return self._check_if_download_successful(handler) + else: + self.fetch_need_retry = True + return False + + self.num_consecutive_result_file_download_retries = 0 + self.fetch_need_retry = False + return True + + def _stop_all_downloads_and_clear_handlers(self): + self.download_handlers = [] + + +DownloadableResultSettings = namedtuple( + "DownloadableResultSettings", + "is_lz4_compressed result_file_link_expiry_buffer download_timeout use_proxy disable_proxy_for_cloud_fetch " + "proxy_host proxy_port proxy_uid proxy_pwd max_consecutive_file_download_retries download_retry_wait_time" +) + +DownloadedFile = namedtuple( + "DownloadedFile", + "file_bytes start_row_offset row_count" +) + + +def _get_downloadable_result_settings(lz4_compressed): + return DownloadableResultSettings( + is_lz4_compressed=lz4_compressed, + result_file_link_expiry_buffer=0, + download_timeout=0, + use_proxy=False, + disable_proxy_for_cloud_fetch=False, + proxy_host="", + proxy_port=0, + proxy_uid="", + proxy_pwd="", + max_consecutive_file_download_retries=0, + download_retry_wait_time=0.1 + ) diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py new file mode 100644 index 00000000..54adb539 --- /dev/null +++ b/tests/unit/test_download_manager.py @@ -0,0 +1,215 @@ +import unittest +from unittest.mock import patch, MagicMock + +import databricks.sql.cloudfetch.download_manager as download_manager +import databricks.sql.cloudfetch.downloader as downloader +from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink + + +class DownloadManagerTests(unittest.TestCase): + """ + Unit tests for checking download manager logic. 
+ """ + + def create_download_manager(self): + max_download_threads = 10 + lz4_compressed = True + return download_manager.ResultFileDownloadManager(max_download_threads, lz4_compressed) + + def create_result_link( + self, + file_link: str = "fileLink", + start_row_offset: int = 0, + row_count: int = 8000, + bytes_num: int = 20971520 + ): + return TSparkArrowResultLink(file_link, None, start_row_offset, row_count, bytes_num) + + def create_result_links(self, num_files: int, start_row_offset: int = 0): + result_links = [] + for i in range(num_files): + file_link = "fileLink_" + str(i) + result_link = self.create_result_link(file_link=file_link, start_row_offset=start_row_offset) + result_links.append(result_link) + start_row_offset += result_link.rowCount + return result_links + + def test_add_file_links_zero_row_count(self): + links = [self.create_result_link(row_count=0, bytes_num=0)] + manager = self.create_download_manager() + manager.add_file_links(links, 0) + + assert not manager.download_handlers + + def test_add_file_links_success(self): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links, 0) + + assert len(manager.download_handlers) == 10 + + def test_remove_past_handlers_one(self): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links, 0) + + manager._remove_past_handlers(8000) + assert len(manager.download_handlers) == 9 + + def test_remove_past_handlers_all(self): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links, 0) + + manager._remove_past_handlers(8000*10) + assert len(manager.download_handlers) == 0 + + @patch("concurrent.futures.ThreadPoolExecutor.submit") + def test_schedule_downloads_partial_already_scheduled(self, mock_submit): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links, 0) + + for i in range(5): + manager.download_handlers[i].is_download_scheduled = True + + manager._schedule_downloads() + assert mock_submit.call_count == 5 + assert sum([1 if handler.is_download_scheduled else 0 for handler in manager.download_handlers]) == 10 + + @patch("concurrent.futures.ThreadPoolExecutor.submit") + def test_schedule_downloads_will_not_schedule_twice(self, mock_submit): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links, 0) + + for i in range(5): + manager.download_handlers[i].is_download_scheduled = True + + manager._schedule_downloads() + assert mock_submit.call_count == 5 + assert sum([1 if handler.is_download_scheduled else 0 for handler in manager.download_handlers]) == 10 + + manager._schedule_downloads() + assert mock_submit.call_count == 5 + + @patch("concurrent.futures.ThreadPoolExecutor.submit", side_effect=[True, KeyError("foo")]) + def test_schedule_downloads_submit_fails(self, mock_submit): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links, 0) + + manager._schedule_downloads() + assert mock_submit.call_count == 2 + assert sum([1 if handler.is_download_scheduled else 0 for handler in manager.download_handlers]) == 1 + + @patch("concurrent.futures.ThreadPoolExecutor.submit") + def test_find_next_file_index_all_scheduled_next_row_0(self, mock_submit): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + 
manager.add_file_links(links, 0) + manager._schedule_downloads() + + assert manager._find_next_file_index(0) == 0 + + @patch("concurrent.futures.ThreadPoolExecutor.submit") + def test_find_next_file_index_all_scheduled_next_row_7999(self, mock_submit): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links, 0) + manager._schedule_downloads() + + assert manager._find_next_file_index(7999) is None + + @patch("concurrent.futures.ThreadPoolExecutor.submit") + def test_find_next_file_index_all_scheduled_next_row_8000(self, mock_submit): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links, 0) + manager._schedule_downloads() + + assert manager._find_next_file_index(8000) == 1 + + @patch("concurrent.futures.ThreadPoolExecutor.submit", side_effect=[True, KeyError("foo")]) + def test_find_next_file_index_one_scheduled_next_row_8000(self, mock_submit): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links, 0) + manager._schedule_downloads() + + assert manager._find_next_file_index(8000) is None + + @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", + return_value=True) + @patch("concurrent.futures.ThreadPoolExecutor.submit") + def test_check_if_download_successful_happy(self, mock_submit, mock_is_file_download_successful): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links, 0) + manager._schedule_downloads() + + status = manager._check_if_download_successful(manager.download_handlers[0]) + assert status + assert manager.num_consecutive_result_file_download_retries == 0 + + @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", + return_value=False) + def test_check_if_download_successful_link_expired(self, mock_is_file_download_successful): + manager = self.create_download_manager() + handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) + handler.is_link_expired = True + + status = manager._check_if_download_successful(handler) + mock_is_file_download_successful.assert_called() + assert not status + assert manager.fetch_need_retry + + @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", + return_value=False) + def test_check_if_download_successful_download_timed_out_no_retries(self, mock_is_file_download_successful): + manager = self.create_download_manager() + handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) + handler.is_download_timedout = True + + status = manager._check_if_download_successful(handler) + mock_is_file_download_successful.assert_called() + assert not status + assert manager.fetch_need_retry + + @patch("concurrent.futures.ThreadPoolExecutor.submit") + @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", + return_value=False) + def test_check_if_download_successful_download_timed_out_1_retry(self, mock_is_file_download_successful, mock_submit): + manager = self.create_download_manager() + manager.downloadable_result_settings = download_manager.DownloadableResultSettings( + is_lz4_compressed=True, + result_file_link_expiry_buffer=0, + download_timeout=0, + use_proxy=False, + 
disable_proxy_for_cloud_fetch=False, + proxy_host="", + proxy_port=0, + proxy_uid="", + proxy_pwd="", + max_consecutive_file_download_retries=1, + download_retry_wait_time=0.1 + ) + handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) + handler.is_download_timedout = True + + status = manager._check_if_download_successful(handler) + assert mock_is_file_download_successful.call_count == 2 + assert mock_submit.call_count == 1 + assert not status + assert manager.fetch_need_retry + + @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", + return_value=False) + def test_check_if_download_successful_other_reason(self, mock_is_file_download_successful): + manager = self.create_download_manager() + handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) + + status = manager._check_if_download_successful(handler) + mock_is_file_download_successful.assert_called() + assert not status + assert manager.fetch_need_retry From 70c6b6e363087fdf9525aa6b01ed992c7dd726cb Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Wed, 21 Jun 2023 06:07:22 -0700 Subject: [PATCH 02/12] Bug fix: submit handler.run Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/cloudfetch/download_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 447ea616..40ee5862 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -73,7 +73,7 @@ def _schedule_downloads(self): if handler.is_download_scheduled: continue try: - self.thread_pool.submit(handler) + self.thread_pool.submit(handler.run) except: break handler.is_download_scheduled = True From 7460536f55a7297f9be6082940f2ad5d3669e3db Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 26 Jun 2023 03:51:25 -0700 Subject: [PATCH 03/12] Type annotations Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- .../sql/cloudfetch/download_manager.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 40ee5862..f4969ed7 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -1,11 +1,18 @@ +import logging + from collections import namedtuple from concurrent.futures import ThreadPoolExecutor +from typing import List, Union + from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler +from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink + +logger = logging.getLogger(__name__) class ResultFileDownloadManager: - def __init__(self, max_download_threads, lz4_compressed): + def __init__(self, max_download_threads: int, lz4_compressed: bool): self.download_handlers = [] self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1) self.downloadable_result_settings = _get_downloadable_result_settings(lz4_compressed) @@ -13,15 +20,15 @@ def __init__(self, max_download_threads, lz4_compressed): self.num_consecutive_result_file_download_retries = 0 self.cloud_fetch_index = 0 - def add_file_links(self, t_spark_arrow_result_links, 
next_row_index): - for t_spark_arrow_result_link in t_spark_arrow_result_links: - if t_spark_arrow_result_link.rowCount <= 0: + def add_file_links(self, t_spark_arrow_result_links: List[TSparkArrowResultLink], next_row_index: int) -> None: + for link in t_spark_arrow_result_links: + if link.rowCount <= 0: continue self.download_handlers.append(ResultSetDownloadHandler( - self.downloadable_result_settings, t_spark_arrow_result_link)) + self.downloadable_result_settings, link)) self.cloud_fetch_index = next_row_index - def get_next_downloaded_file(self, next_row_index): + def get_next_downloaded_file(self, next_row_index: int) -> Union[tuple, None]: if not self.download_handlers: return None @@ -52,7 +59,7 @@ def get_next_downloaded_file(self, next_row_index): # Download was not successful for next download item, force a retry return None - def _remove_past_handlers(self, next_row_index): + def _remove_past_handlers(self, next_row_index: int): """ Remove any download handlers whose start to end range doesn't include the next row to be fetched i.e. no need to download @@ -74,17 +81,18 @@ def _schedule_downloads(self): continue try: self.thread_pool.submit(handler.run) - except: + except Exception as e: + logger.error(e) break handler.is_download_scheduled = True - def _find_next_file_index(self, next_row_index): + def _find_next_file_index(self, next_row_index: int): # Get the next downloaded file next_indices = [i for i, handler in enumerate(self.download_handlers) if handler.is_download_scheduled and handler.result_link.startRowOffset == next_row_index] return next_indices[0] if len(next_indices) > 0 else None - def _check_if_download_successful(self, handler): + def _check_if_download_successful(self, handler: ResultSetDownloadHandler): if not handler.is_file_download_successful(): if handler.is_link_expired: self._stop_all_downloads_and_clear_handlers() @@ -93,7 +101,6 @@ def _check_if_download_successful(self, handler): elif handler.is_download_timedout: if self.num_consecutive_result_file_download_retries >= \ self.downloadable_result_settings.max_consecutive_file_download_retries: - # raise Exception("File download exceeded max retry limit") self.fetch_need_retry = True return False self.num_consecutive_result_file_download_retries += 1 From 76803dc4e88fbadb8eec58c8f755be727b82c1e0 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 26 Jun 2023 04:03:58 -0700 Subject: [PATCH 04/12] Namedtuple -> dataclass Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- .../sql/cloudfetch/download_manager.py | 49 +++++++------------ tests/unit/test_download_manager.py | 8 --- 2 files changed, 18 insertions(+), 39 deletions(-) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index f4969ed7..7e97f10f 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -1,7 +1,7 @@ import logging -from collections import namedtuple from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass from typing import List, Union from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler @@ -10,12 +10,27 @@ logger = logging.getLogger(__name__) +@dataclass +class DownloadableResultSettings: + is_lz4_compressed: bool + link_expiry_buffer: float = 0.0 + download_timeout: float = 0.0 + max_consecutive_file_download_retries: int = 0 + + +@dataclass +class DownloadedFile: + file_bytes: 
bytes + start_row_offset: int + row_count: int + + class ResultFileDownloadManager: def __init__(self, max_download_threads: int, lz4_compressed: bool): self.download_handlers = [] self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1) - self.downloadable_result_settings = _get_downloadable_result_settings(lz4_compressed) + self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed) self.fetch_need_retry = False self.num_consecutive_result_file_download_retries = 0 self.cloud_fetch_index = 0 @@ -28,7 +43,7 @@ def add_file_links(self, t_spark_arrow_result_links: List[TSparkArrowResultLink] self.downloadable_result_settings, link)) self.cloud_fetch_index = next_row_index - def get_next_downloaded_file(self, next_row_index: int) -> Union[tuple, None]: + def get_next_downloaded_file(self, next_row_index: int) -> Union[DownloadedFile, None]: if not self.download_handlers: return None @@ -116,31 +131,3 @@ def _check_if_download_successful(self, handler: ResultSetDownloadHandler): def _stop_all_downloads_and_clear_handlers(self): self.download_handlers = [] - - -DownloadableResultSettings = namedtuple( - "DownloadableResultSettings", - "is_lz4_compressed result_file_link_expiry_buffer download_timeout use_proxy disable_proxy_for_cloud_fetch " - "proxy_host proxy_port proxy_uid proxy_pwd max_consecutive_file_download_retries download_retry_wait_time" -) - -DownloadedFile = namedtuple( - "DownloadedFile", - "file_bytes start_row_offset row_count" -) - - -def _get_downloadable_result_settings(lz4_compressed): - return DownloadableResultSettings( - is_lz4_compressed=lz4_compressed, - result_file_link_expiry_buffer=0, - download_timeout=0, - use_proxy=False, - disable_proxy_for_cloud_fetch=False, - proxy_host="", - proxy_port=0, - proxy_uid="", - proxy_pwd="", - max_consecutive_file_download_retries=0, - download_retry_wait_time=0.1 - ) diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index 54adb539..201b3474 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -183,16 +183,8 @@ def test_check_if_download_successful_download_timed_out_1_retry(self, mock_is_f manager = self.create_download_manager() manager.downloadable_result_settings = download_manager.DownloadableResultSettings( is_lz4_compressed=True, - result_file_link_expiry_buffer=0, download_timeout=0, - use_proxy=False, - disable_proxy_for_cloud_fetch=False, - proxy_host="", - proxy_port=0, - proxy_uid="", - proxy_pwd="", max_consecutive_file_download_retries=1, - download_retry_wait_time=0.1 ) handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) handler.is_download_timedout = True From 90999956a209cd8dbe917ef5286595e2f8e76784 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 26 Jun 2023 14:15:57 -0700 Subject: [PATCH 05/12] Shutdown thread pool and clear handlers Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/cloudfetch/download_manager.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 7e97f10f..cf0fab77 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -45,6 +45,7 @@ def add_file_links(self, t_spark_arrow_result_links: List[TSparkArrowResultLink] def 
get_next_downloaded_file(self, next_row_index: int) -> Union[DownloadedFile, None]: if not self.download_handlers: + self._shutdown_manager() return None # Remove handlers we don't need anymore @@ -56,6 +57,7 @@ def get_next_downloaded_file(self, next_row_index: int) -> Union[DownloadedFile, # Find next file idx = self._find_next_file_index(next_row_index) if idx is None: + self._shutdown_manager() return None handler = self.download_handlers[idx] @@ -72,6 +74,7 @@ def get_next_downloaded_file(self, next_row_index: int) -> Union[DownloadedFile, # Return True upon successful download to continue loop and not force a retry return result # Download was not successful for next download item, force a retry + self._shutdown_manager() return None def _remove_past_handlers(self, next_row_index: int): @@ -110,7 +113,7 @@ def _find_next_file_index(self, next_row_index: int): def _check_if_download_successful(self, handler: ResultSetDownloadHandler): if not handler.is_file_download_successful(): if handler.is_link_expired: - self._stop_all_downloads_and_clear_handlers() + self._shutdown_manager() self.fetch_need_retry = True return False elif handler.is_download_timedout: @@ -129,5 +132,6 @@ def _check_if_download_successful(self, handler: ResultSetDownloadHandler): self.fetch_need_retry = False return True - def _stop_all_downloads_and_clear_handlers(self): + def _shutdown_manager(self): self.download_handlers = [] + self.thread_pool.shutdown(wait=False, cancel_futures=True) From 7018ef56770229a2036af7c4c008bf59c9d9d0c3 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 26 Jun 2023 20:02:38 -0700 Subject: [PATCH 06/12] Docstrings and comments Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- .../sql/cloudfetch/download_manager.py | 100 +++++++++++++----- tests/unit/test_download_manager.py | 24 ++--- 2 files changed, 84 insertions(+), 40 deletions(-) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index cf0fab77..b18d3e9a 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -12,50 +12,90 @@ @dataclass class DownloadableResultSettings: + """ + Class for settings common to each download handler. + + Attributes: + is_lz4_compressed (bool): Whether file is expected to be lz4 compressed. + link_expiry_buffer (int): Time in seconds to prevent download of a link before it expires. Default 0 secs. + download_timeout (int): Timeout for download requests. Default 60 secs. + max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down. + """ + is_lz4_compressed: bool - link_expiry_buffer: float = 0.0 - download_timeout: float = 0.0 + link_expiry_buffer: int = 0 + download_timeout: int = 60 max_consecutive_file_download_retries: int = 0 @dataclass class DownloadedFile: + """ + Class for the result file and metadata. + + Attributes: + file_bytes (bytes): Downloaded file in bytes. + start_row_offset (int): The offset of the starting row in relation to the full result. + row_count (int): Number of rows the file represents in the result. 
+ """ + file_bytes: bytes start_row_offset: int row_count: int class ResultFileDownloadManager: - def __init__(self, max_download_threads: int, lz4_compressed: bool): self.download_handlers = [] self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1) self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed) self.fetch_need_retry = False self.num_consecutive_result_file_download_retries = 0 - self.cloud_fetch_index = 0 - def add_file_links(self, t_spark_arrow_result_links: List[TSparkArrowResultLink], next_row_index: int) -> None: + def add_file_links( + self, t_spark_arrow_result_links: List[TSparkArrowResultLink] + ) -> None: + """ + Create download handler for each cloud fetch link. + + Args: + t_spark_arrow_result_links: List of cloud fetch links consisting of file URL and metadata. + """ for link in t_spark_arrow_result_links: if link.rowCount <= 0: continue - self.download_handlers.append(ResultSetDownloadHandler( - self.downloadable_result_settings, link)) - self.cloud_fetch_index = next_row_index + self.download_handlers.append( + ResultSetDownloadHandler(self.downloadable_result_settings, link) + ) - def get_next_downloaded_file(self, next_row_index: int) -> Union[DownloadedFile, None]: + def get_next_downloaded_file( + self, next_row_offset: int + ) -> Union[DownloadedFile, None]: + """ + Get next file that starts at given offset. + + This function gets the next downloaded file in which its rows start at the specified next_row_offset + in relation to the full result. File downloads are scheduled if not already, and once the correct + download handler is located, the function waits for the download status and returns the resulting file. + If there are no more downloads, a download was not successful, or the correct file could not be located, + this function shuts down the thread pool and returns None. + + Args: + next_row_offset (int): The offset of the starting row of the next file we want data from. + """ + # No more files to download from this batch of links if not self.download_handlers: self._shutdown_manager() return None # Remove handlers we don't need anymore - self._remove_past_handlers(next_row_index) + self._remove_past_handlers(next_row_offset) # Schedule the downloads self._schedule_downloads() # Find next file - idx = self._find_next_file_index(next_row_index) + idx = self._find_next_file_index(next_row_offset) if idx is None: self._shutdown_manager() return None @@ -69,7 +109,6 @@ def get_next_downloaded_file(self, next_row_index: int) -> Union[DownloadedFile, handler.result_link.startRowOffset, handler.result_link.rowCount, ) - self.cloud_fetch_index += handler.result_link.rowCount self.download_handlers.pop(idx) # Return True upon successful download to continue loop and not force a retry return result @@ -77,23 +116,18 @@ def get_next_downloaded_file(self, next_row_index: int) -> Union[DownloadedFile, self._shutdown_manager() return None - def _remove_past_handlers(self, next_row_index: int): - """ - Remove any download handlers whose start to end range doesn't include the next row to be fetched - i.e. 
no need to download - """ + def _remove_past_handlers(self, next_row_offset: int): + # Any link in which its start to end range doesn't include the next row to be fetched does not need downloading i = 0 while i < len(self.download_handlers): result_link = self.download_handlers[i].result_link - if result_link.startRowOffset + result_link.rowCount > next_row_index: + if result_link.startRowOffset + result_link.rowCount > next_row_offset: i += 1 continue self.download_handlers.pop(i) def _schedule_downloads(self): - """ - Schedule downloads for all download handlers if not already scheduled - """ + # Schedule downloads for all download handlers if not already scheduled. for handler in self.download_handlers: if handler.is_download_scheduled: continue @@ -104,24 +138,33 @@ def _schedule_downloads(self): break handler.is_download_scheduled = True - def _find_next_file_index(self, next_row_index: int): - # Get the next downloaded file - next_indices = [i for i, handler in enumerate(self.download_handlers) - if handler.is_download_scheduled and handler.result_link.startRowOffset == next_row_index] + def _find_next_file_index(self, next_row_offset: int): + # Get the handler index of the next file in order + next_indices = [ + i + for i, handler in enumerate(self.download_handlers) + if handler.is_download_scheduled + and handler.result_link.startRowOffset == next_row_offset + ] return next_indices[0] if len(next_indices) > 0 else None def _check_if_download_successful(self, handler: ResultSetDownloadHandler): + # Check (and wait until download finishes) if download was successful if not handler.is_file_download_successful(): if handler.is_link_expired: - self._shutdown_manager() self.fetch_need_retry = True return False elif handler.is_download_timedout: - if self.num_consecutive_result_file_download_retries >= \ - self.downloadable_result_settings.max_consecutive_file_download_retries: + # Consecutive file retries should not exceed threshold in settings + if ( + self.num_consecutive_result_file_download_retries + >= self.downloadable_result_settings.max_consecutive_file_download_retries + ): self.fetch_need_retry = True return False self.num_consecutive_result_file_download_retries += 1 + + # Re-submit handler run to thread pool and recursively check download status self.thread_pool.submit(handler) return self._check_if_download_successful(handler) else: @@ -133,5 +176,6 @@ def _check_if_download_successful(self, handler: ResultSetDownloadHandler): return True def _shutdown_manager(self): + # Clear download handlers and shutdown the thread pool to cancel pending futures self.download_handlers = [] self.thread_pool.shutdown(wait=False, cancel_futures=True) diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index 201b3474..97bf407a 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -37,21 +37,21 @@ def create_result_links(self, num_files: int, start_row_offset: int = 0): def test_add_file_links_zero_row_count(self): links = [self.create_result_link(row_count=0, bytes_num=0)] manager = self.create_download_manager() - manager.add_file_links(links, 0) + manager.add_file_links(links) assert not manager.download_handlers def test_add_file_links_success(self): links = self.create_result_links(num_files=10) manager = self.create_download_manager() - manager.add_file_links(links, 0) + manager.add_file_links(links) assert len(manager.download_handlers) == 10 def test_remove_past_handlers_one(self): links = 
self.create_result_links(num_files=10) manager = self.create_download_manager() - manager.add_file_links(links, 0) + manager.add_file_links(links) manager._remove_past_handlers(8000) assert len(manager.download_handlers) == 9 @@ -59,7 +59,7 @@ def test_remove_past_handlers_one(self): def test_remove_past_handlers_all(self): links = self.create_result_links(num_files=10) manager = self.create_download_manager() - manager.add_file_links(links, 0) + manager.add_file_links(links) manager._remove_past_handlers(8000*10) assert len(manager.download_handlers) == 0 @@ -68,7 +68,7 @@ def test_remove_past_handlers_all(self): def test_schedule_downloads_partial_already_scheduled(self, mock_submit): links = self.create_result_links(num_files=10) manager = self.create_download_manager() - manager.add_file_links(links, 0) + manager.add_file_links(links) for i in range(5): manager.download_handlers[i].is_download_scheduled = True @@ -81,7 +81,7 @@ def test_schedule_downloads_partial_already_scheduled(self, mock_submit): def test_schedule_downloads_will_not_schedule_twice(self, mock_submit): links = self.create_result_links(num_files=10) manager = self.create_download_manager() - manager.add_file_links(links, 0) + manager.add_file_links(links) for i in range(5): manager.download_handlers[i].is_download_scheduled = True @@ -97,7 +97,7 @@ def test_schedule_downloads_will_not_schedule_twice(self, mock_submit): def test_schedule_downloads_submit_fails(self, mock_submit): links = self.create_result_links(num_files=10) manager = self.create_download_manager() - manager.add_file_links(links, 0) + manager.add_file_links(links) manager._schedule_downloads() assert mock_submit.call_count == 2 @@ -107,7 +107,7 @@ def test_schedule_downloads_submit_fails(self, mock_submit): def test_find_next_file_index_all_scheduled_next_row_0(self, mock_submit): links = self.create_result_links(num_files=10) manager = self.create_download_manager() - manager.add_file_links(links, 0) + manager.add_file_links(links) manager._schedule_downloads() assert manager._find_next_file_index(0) == 0 @@ -116,7 +116,7 @@ def test_find_next_file_index_all_scheduled_next_row_0(self, mock_submit): def test_find_next_file_index_all_scheduled_next_row_7999(self, mock_submit): links = self.create_result_links(num_files=10) manager = self.create_download_manager() - manager.add_file_links(links, 0) + manager.add_file_links(links) manager._schedule_downloads() assert manager._find_next_file_index(7999) is None @@ -125,7 +125,7 @@ def test_find_next_file_index_all_scheduled_next_row_7999(self, mock_submit): def test_find_next_file_index_all_scheduled_next_row_8000(self, mock_submit): links = self.create_result_links(num_files=10) manager = self.create_download_manager() - manager.add_file_links(links, 0) + manager.add_file_links(links) manager._schedule_downloads() assert manager._find_next_file_index(8000) == 1 @@ -134,7 +134,7 @@ def test_find_next_file_index_all_scheduled_next_row_8000(self, mock_submit): def test_find_next_file_index_one_scheduled_next_row_8000(self, mock_submit): links = self.create_result_links(num_files=10) manager = self.create_download_manager() - manager.add_file_links(links, 0) + manager.add_file_links(links) manager._schedule_downloads() assert manager._find_next_file_index(8000) is None @@ -145,7 +145,7 @@ def test_find_next_file_index_one_scheduled_next_row_8000(self, mock_submit): def test_check_if_download_successful_happy(self, mock_submit, mock_is_file_download_successful): links = self.create_result_links(num_files=10) 
manager = self.create_download_manager() - manager.add_file_links(links, 0) + manager.add_file_links(links) manager._schedule_downloads() status = manager._check_if_download_successful(manager.download_handlers[0]) From 65af40fb4ed4f084003d9ee646d37e03e6924a34 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Tue, 27 Jun 2023 15:10:21 -0700 Subject: [PATCH 07/12] handler.run is the correct call Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/cloudfetch/download_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index b18d3e9a..1819535b 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -165,7 +165,7 @@ def _check_if_download_successful(self, handler: ResultSetDownloadHandler): self.num_consecutive_result_file_download_retries += 1 # Re-submit handler run to thread pool and recursively check download status - self.thread_pool.submit(handler) + self.thread_pool.submit(handler.run) return self._check_if_download_successful(handler) else: self.fetch_need_retry = True From f8cdfc540e28444ea79a06c5468ae5617ab636e1 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Tue, 27 Jun 2023 15:13:45 -0700 Subject: [PATCH 08/12] Link expiry buffer in secs Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/cloudfetch/download_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 1819535b..2c58bee9 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -17,13 +17,13 @@ class DownloadableResultSettings: Attributes: is_lz4_compressed (bool): Whether file is expected to be lz4 compressed. - link_expiry_buffer (int): Time in seconds to prevent download of a link before it expires. Default 0 secs. + link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs. download_timeout (int): Timeout for download requests. Default 60 secs. max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down. 
""" is_lz4_compressed: bool - link_expiry_buffer: int = 0 + link_expiry_buffer_secs: int = 0 download_timeout: int = 60 max_consecutive_file_download_retries: int = 0 From 14fa3521352c088ae5acf92a1fe83b7092567f4f Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Tue, 27 Jun 2023 16:40:07 -0700 Subject: [PATCH 09/12] Adding type annotations for download_handlers and downloadable_result_settings Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/cloudfetch/download_manager.py | 2 +- src/databricks/sql/cloudfetch/downloader.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 2c58bee9..ffc98b5f 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -46,7 +46,7 @@ class DownloadedFile: class ResultFileDownloadManager: def __init__(self, max_download_threads: int, lz4_compressed: bool): - self.download_handlers = [] + self.download_handlers: List[ResultSetDownloadHandler] = [] self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1) self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed) self.fetch_need_retry = False diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index d3c4a480..7750f5fb 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -5,6 +5,7 @@ import threading import time +from databricks.sql.cloudfetch.download_manager import DownloadableResultSettings from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink logger = logging.getLogger(__name__) @@ -13,7 +14,7 @@ class ResultSetDownloadHandler(threading.Thread): def __init__( self, - downloadable_result_settings, + downloadable_result_settings: DownloadableResultSettings, t_spark_arrow_result_link: TSparkArrowResultLink, ): super().__init__() From e8c728a44390a24c25c48a62e8b7e9ebcdf64697 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Tue, 27 Jun 2023 16:49:50 -0700 Subject: [PATCH 10/12] Move DownloadableResultSettings to downloader.py to avoid circular import Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- .../sql/cloudfetch/download_manager.py | 20 +------------------ src/databricks/sql/cloudfetch/downloader.py | 20 ++++++++++++++++++- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index ffc98b5f..b9a728c6 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -4,30 +4,12 @@ from dataclasses import dataclass from typing import List, Union -from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler +from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler, DownloadableResultSettings from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink logger = logging.getLogger(__name__) -@dataclass -class DownloadableResultSettings: - """ - Class for settings common to each download handler. - - Attributes: - is_lz4_compressed (bool): Whether file is expected to be lz4 compressed. - link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. 
Default 0 secs. - download_timeout (int): Timeout for download requests. Default 60 secs. - max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down. - """ - - is_lz4_compressed: bool - link_expiry_buffer_secs: int = 0 - download_timeout: int = 60 - max_consecutive_file_download_retries: int = 0 - - @dataclass class DownloadedFile: """ diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 7750f5fb..594e5267 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -1,16 +1,34 @@ import logging +from dataclasses import dataclass import requests import lz4.frame import threading import time -from databricks.sql.cloudfetch.download_manager import DownloadableResultSettings from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink logger = logging.getLogger(__name__) +@dataclass +class DownloadableResultSettings: + """ + Class for settings common to each download handler. + + Attributes: + is_lz4_compressed (bool): Whether file is expected to be lz4 compressed. + link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs. + download_timeout (int): Timeout for download requests. Default 60 secs. + max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down. + """ + + is_lz4_compressed: bool + link_expiry_buffer_secs: int = 0 + download_timeout: int = 60 + max_consecutive_file_download_retries: int = 0 + + class ResultSetDownloadHandler(threading.Thread): def __init__( self, From 585bf4ae5a8cc2f95200c3e22177beaac2ff7dae Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Tue, 27 Jun 2023 17:07:22 -0700 Subject: [PATCH 11/12] Black linting Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/cloudfetch/download_manager.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index b9a728c6..aac3ac33 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -4,7 +4,10 @@ from dataclasses import dataclass from typing import List, Union -from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler, DownloadableResultSettings +from databricks.sql.cloudfetch.downloader import ( + ResultSetDownloadHandler, + DownloadableResultSettings, +) from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink logger = logging.getLogger(__name__) From 243ba57105a9db37cbc06d745ee3296f73eef194 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Tue, 27 Jun 2023 17:22:19 -0700 Subject: [PATCH 12/12] Timeout is never None Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/cloudfetch/downloader.py | 7 +++++-- tests/unit/test_downloader.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 594e5267..019c4ef9 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -51,8 +51,11 @@ def is_file_download_successful(self) -> bool: This function will block until a file download finishes or until a timeout. 
""" - timeout = self.settings.download_timeout - timeout = timeout if timeout and timeout > 0 else None + timeout = ( + self.settings.download_timeout + if self.settings.download_timeout > 0 + else None + ) try: if not self.is_download_finished.wait(timeout=timeout): self.is_download_timedout = True diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index cee3a83c..6e13c949 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -136,7 +136,7 @@ def test_download_timeout(self, mock_time, mock_session): @patch("threading.Event.wait", return_value=True) def test_is_file_download_successful_has_finished(self, mock_wait): - for timeout in [None, 0, 1]: + for timeout in [0, 1]: with self.subTest(timeout=timeout): settings = Mock(download_timeout=timeout) result_link = Mock()