From 86d1615f3d96033f4ce290f474df09923561fcf4 Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Fri, 9 Feb 2024 17:15:45 -0700 Subject: [PATCH 1/3] Debug logging and temp code changes (#350) Signed-off-by: Raymond Cypher --- src/databricks/sql/cloudfetch/downloader.py | 48 ++++++++++++++++++--- src/databricks/sql/thrift_backend.py | 1 + src/databricks/sql/utils.py | 5 ++- 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 019c4ef9..697e09a5 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -1,15 +1,19 @@ import logging from dataclasses import dataclass - +from datetime import datetime import requests import lz4.frame import threading import time - +import os +from threading import get_ident from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink +logging.basicConfig(format="%(asctime)s %(message)s") logger = logging.getLogger(__name__) +DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60)) + @dataclass class DownloadableResultSettings: @@ -25,7 +29,7 @@ class DownloadableResultSettings: is_lz4_compressed: bool link_expiry_buffer_secs: int = 0 - download_timeout: int = 60 + download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT max_consecutive_file_download_retries: int = 0 @@ -57,16 +61,27 @@ def is_file_download_successful(self) -> bool: else None ) try: + msg = f"{datetime.now()} {(os.getpid(), get_ident())} wait for {timeout} for download file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount} " + logger.debug( + f"{datetime.now()} {(os.getpid(), get_ident())} wait for {timeout} for download file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount} " + ) + if not self.is_download_finished.wait(timeout=timeout): self.is_download_timedout = True logger.debug( - "Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format( + "{} {} Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format( + datetime.now(), + (os.getpid(), get_ident()), self.settings.download_timeout, self.result_link.startRowOffset, self.result_link.startRowOffset + self.result_link.rowCount, ) ) return False + + logger.debug( + f"{datetime.now()} {(os.getpid(), get_ident())} success wait for {timeout} for download file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount} " + ) except Exception as e: logger.error(e) return False @@ -92,10 +107,22 @@ def run(self): session.timeout = self.settings.download_timeout try: + logger.debug( + f"{datetime.now()} {(os.getpid(), get_ident())} start download file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount}" + ) + # Get the file via HTTP request response = session.get(self.result_link.fileLink) + logger.debug( + f"{datetime.now()} {(os.getpid(), get_ident())} finish download file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount}" + ) + if not response.ok: + logger.error( + f"{datetime.now()} {(os.getpid(), get_ident())} failed downloading file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount}" + ) + logger.error(response) self.is_file_downloaded_successfully = False return @@ -109,18 +136,27 @@ def run(self): self.result_file = decompressed_data # The size of the downloaded file should match the size specified from 
TSparkArrowResultLink - self.is_file_downloaded_successfully = ( - len(self.result_file) == self.result_link.bytesNum + success = len(self.result_file) == self.result_link.bytesNum + logger.debug( + f"{datetime.now()} {(os.getpid(), get_ident())} download successful file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount}" ) + self.is_file_downloaded_successfully = success except Exception as e: + logger.debug( + f"{datetime.now()} {(os.getpid(), get_ident())} exception download file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount}" + ) logger.error(e) self.is_file_downloaded_successfully = False finally: session and session.close() + logger.debug( + f"{datetime.now()} {(os.getpid(), get_ident())} signal finished file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount}" + ) # Awaken threads waiting for this to be true which signals the run is complete self.is_download_finished.set() + def _reset(self): """ Reset download-related flags for every retry of run() diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 69ac760a..4a9e80c4 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -975,6 +975,7 @@ def fetch_results( arrow_schema_bytes, description, ): + logger.debug("ThriftBackend fetch_results") assert op_handle is not None req = ttypes.TFetchResultsReq( diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 7c3a014b..14fa06af 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -156,8 +156,11 @@ def __init__( self.lz4_compressed = lz4_compressed self.description = description + # self.download_manager = ResultFileDownloadManager( + # self.max_download_threads, self.lz4_compressed + # ) self.download_manager = ResultFileDownloadManager( - self.max_download_threads, self.lz4_compressed + 1, self.lz4_compressed ) self.download_manager.add_file_links(result_links) From 90ec42880582642013d5bc16e906874829cd554a Mon Sep 17 00:00:00 2001 From: Andre Furlan Date: Thu, 15 Feb 2024 13:00:57 -0800 Subject: [PATCH 2/3] fixes for cloud fetch --- examples/query_execute.py | 31 +++++-- .../sql/cloudfetch/download_manager.py | 35 ++------ src/databricks/sql/cloudfetch/downloader.py | 88 ++++++++++++------- src/databricks/sql/exc.py | 4 + src/databricks/sql/thrift_backend.py | 21 ++++- src/databricks/sql/utils.py | 11 ++- 6 files changed, 112 insertions(+), 78 deletions(-) diff --git a/examples/query_execute.py b/examples/query_execute.py index ec79fd0e..b6919f71 100644 --- a/examples/query_execute.py +++ b/examples/query_execute.py @@ -1,13 +1,30 @@ +import threading from databricks import sql import os +import logging + + +logger = logging.getLogger("databricks.sql") +logger.setLevel(logging.INFO) +fh = logging.FileHandler('pysqllogs.log') +fh.setFormatter(logging.Formatter("%(asctime)s %(process)d %(thread)d %(message)s")) +fh.setLevel(logging.DEBUG) +logger.addHandler(fh) with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), http_path = os.getenv("DATABRICKS_HTTP_PATH"), - access_token = os.getenv("DATABRICKS_TOKEN")) as connection: - - with connection.cursor() as cursor: - cursor.execute("SELECT * FROM default.diamonds LIMIT 2") - result = cursor.fetchall() + access_token = os.getenv("DATABRICKS_TOKEN"), + # max_download_threads = 2 + ) as connection: - for row in result: - print(row) \ No newline at end of file + with connection.cursor( + # arraysize=100 + ) 
as cursor: + cursor.execute("SELECT * FROM range(0, 10000000) AS t1 LEFT JOIN (SELECT 1) AS t2") + # cursor.execute("SELECT * FROM andre.plotly_iot_dashboard.bronze_sensors limit 1000001") + try: + result = cursor.fetchall() + print(f"result length: {len(result)}") + except sql.exc.ResultSetDownloadError as e: + print(f"error: {e}") + # buffer_size_bytes=4857600 \ No newline at end of file diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 9a997f39..fdbbe55b 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -8,6 +8,7 @@ ResultSetDownloadHandler, DownloadableResultSettings, ) +from databricks.sql.exc import ResultSetDownloadError from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink logger = logging.getLogger(__name__) @@ -81,13 +82,15 @@ def get_next_downloaded_file( # Find next file idx = self._find_next_file_index(next_row_offset) + # is this correct? if idx is None: self._shutdown_manager() + logger.debug("could not find next file index") return None handler = self.download_handlers[idx] # Check (and wait) for download status - if self._check_if_download_successful(handler): + if handler.is_file_download_successful(): # Buffer should be empty so set buffer to new ArrowQueue with result_file result = DownloadedFile( handler.result_file, @@ -97,9 +100,9 @@ def get_next_downloaded_file( self.download_handlers.pop(idx) # Return True upon successful download to continue loop and not force a retry return result - # Download was not successful for next download item, force a retry + # Download was not successful for next download item. Fail self._shutdown_manager() - return None + raise ResultSetDownloadError(f"Download failed for result set starting at {next_row_offset}") def _remove_past_handlers(self, next_row_offset: int): # Any link in which its start to end range doesn't include the next row to be fetched does not need downloading @@ -133,32 +136,6 @@ def _find_next_file_index(self, next_row_offset: int): ] return next_indices[0] if len(next_indices) > 0 else None - def _check_if_download_successful(self, handler: ResultSetDownloadHandler): - # Check (and wait until download finishes) if download was successful - if not handler.is_file_download_successful(): - if handler.is_link_expired: - self.fetch_need_retry = True - return False - elif handler.is_download_timedout: - # Consecutive file retries should not exceed threshold in settings - if ( - self.num_consecutive_result_file_download_retries - >= self.downloadable_result_settings.max_consecutive_file_download_retries - ): - self.fetch_need_retry = True - return False - self.num_consecutive_result_file_download_retries += 1 - - # Re-submit handler run to thread pool and recursively check download status - self.thread_pool.submit(handler.run) - return self._check_if_download_successful(handler) - else: - self.fetch_need_retry = True - return False - - self.num_consecutive_result_file_download_retries = 0 - self.fetch_need_retry = False - return True def _shutdown_manager(self): # Clear download handlers and shutdown the thread pool diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 697e09a5..c880ce1b 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -61,26 +61,19 @@ def is_file_download_successful(self) -> bool: else None ) try: - msg = f"{datetime.now()} 
{(os.getpid(), get_ident())} wait for {timeout} for download file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount} " logger.debug( - f"{datetime.now()} {(os.getpid(), get_ident())} wait for {timeout} for download file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount} " + f"waiting for at most {timeout} seconds for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" ) if not self.is_download_finished.wait(timeout=timeout): self.is_download_timedout = True logger.debug( - "{} {} Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format( - datetime.now(), - (os.getpid(), get_ident()), - self.settings.download_timeout, - self.result_link.startRowOffset, - self.result_link.startRowOffset + self.result_link.rowCount, - ) + f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}" ) return False logger.debug( - f"{datetime.now()} {(os.getpid(), get_ident())} success wait for {timeout} for download file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount} " + f"finish waiting for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" ) except Exception as e: logger.error(e) @@ -95,37 +88,33 @@ def run(self): file, and signals to waiting threads that the download is finished and whether it was successful. """ self._reset() - - # Check if link is already expired or is expiring - if ResultSetDownloadHandler.check_link_expired( - self.result_link, self.settings.link_expiry_buffer_secs - ): - self.is_link_expired = True - return - - session = requests.Session() - session.timeout = self.settings.download_timeout + try: + # Check if link is already expired or is expiring + if ResultSetDownloadHandler.check_link_expired( + self.result_link, self.settings.link_expiry_buffer_secs + ): + self.is_link_expired = True + return + logger.debug( - f"{datetime.now()} {(os.getpid(), get_ident())} start download file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount}" + f"started to download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" ) # Get the file via HTTP request - response = session.get(self.result_link.fileLink) + response = http_get_with_retry(url=self.result_link.fileLink, download_timeout=self.settings.download_timeout) - logger.debug( - f"{datetime.now()} {(os.getpid(), get_ident())} finish download file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount}" - ) - - if not response.ok: + if not response: logger.error( - f"{datetime.now()} {(os.getpid(), get_ident())} failed downloading file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount}" + f"failed downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" ) - logger.error(response) - self.is_file_downloaded_successfully = False return + logger.debug( + f"success downloading file: startRow {self.result_link.startRowOffset}, rowCount 
{self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
+
             # Save (and decompress if needed) the downloaded file
             compressed_data = response.content
             decompressed_data = (
@@ -138,20 +127,19 @@ def run(self):
             # The size of the downloaded file should match the size specified from TSparkArrowResultLink
             success = len(self.result_file) == self.result_link.bytesNum
             logger.debug(
-                f"{datetime.now()} {(os.getpid(), get_ident())} download successful file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount}"
+                f"download successful file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
             )
             self.is_file_downloaded_successfully = success
         except Exception as e:
             logger.debug(
-                f"{datetime.now()} {(os.getpid(), get_ident())} exception download file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount}"
+                f"exception downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
             )
             logger.error(e)
             self.is_file_downloaded_successfully = False
         finally:
-            session and session.close()
             logger.debug(
-                f"{datetime.now()} {(os.getpid(), get_ident())} signal finished file: startRow {self.result_link.startRowOffset}, rowCount{self.result_link.rowCount}"
+                f"signal finished file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
             )
             # Awaken threads waiting for this to be true which signals the run is complete
             self.is_download_finished.set()

@@ -181,6 +169,9 @@ def check_link_expired(
             link.expiryTime < current_time
             or link.expiryTime - current_time < expiry_buffer_secs
         ):
+            logger.debug(
+                f"{(os.getpid(), get_ident())} - link expired"
+            )
             return True
         return False

@@ -207,3 +198,32 @@ def decompress_data(compressed_data: bytes) -> bytes:
             uncompressed_data += data
             start += num_bytes
         return uncompressed_data
+
+
+def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
+    attempts = 0
+
+    while attempts < max_retries:
+        try:
+            session = requests.Session()
+            # requests.Session has no global timeout attribute; pass it per request
+            response = session.get(url, timeout=download_timeout)
+
+            # Only HTTP 200 counts as success here
+            if response.status_code == 200:
+                return response
+            else:
+                logger.error(f"unexpected HTTP status {response.status_code} while downloading result")
+        except requests.RequestException as e:
+            logger.error(f"request failed with exception: {e}")
+        finally:
+            session.close()
+        # Exponential backoff before the next attempt
+        wait_time = backoff_factor ** attempts
+        logger.info(f"retrying in {wait_time} seconds...")
+        time.sleep(wait_time)
+
+        attempts += 1
+
+    logger.error(f"exceeded maximum number of retries ({max_retries}) while downloading result.")
+    return None
diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py
index 3b27283a..a63f0128 100644
--- a/src/databricks/sql/exc.py
+++ b/src/databricks/sql/exc.py
@@ -115,3 +115,7 @@ class SessionAlreadyClosedError(RequestError):
 
 class CursorAlreadyClosedError(RequestError):
     """Thrown if CancelOperation receives a code 404. 
ThriftBackend should gracefully proceed as this is expected."""
+
+
+class ResultSetDownloadError(RequestError):
+    """Thrown if there was an error during the download of a result set."""
\ No newline at end of file
diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py
index 4a9e80c4..be97b352 100644
--- a/src/databricks/sql/thrift_backend.py
+++ b/src/databricks/sql/thrift_backend.py
@@ -393,7 +393,7 @@ def attempt_request(attempt):
             try:
                 this_method_name = getattr(method, "__name__")
 
-                logger.debug("Sending request: {}()".format(this_method_name))
+                logger.debug("sending thrift request: {}()".format(this_method_name))
                 unsafe_logger.debug("Sending request: {}".format(request))
 
                 # These three lines are no-ops if the v3 retry policy is not in use
@@ -406,7 +406,7 @@
 
                 # We need to call type(response) here because thrift doesn't implement __name__ attributes for thrift responses
                 logger.debug(
-                    "Received response: {}()".format(type(response).__name__)
+                    "received thrift response: {}()".format(type(response).__name__)
                 )
                 unsafe_logger.debug("Received response: {}".format(response))
                 return response
@@ -764,6 +764,9 @@ def _results_message_to_execute_response(self, resp, operation_state):
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         if direct_results and direct_results.resultSet:
+            logger.debug(
+                "received direct results"
+            )
             assert direct_results.resultSet.results.startRowOffset == 0
             assert direct_results.resultSetMetadata
 
@@ -776,6 +779,9 @@
                 description=description,
             )
         else:
+            logger.debug(
+                "must fetch results"
+            )
             arrow_queue_opt = None
         return ExecuteResponse(
             arrow_queue=arrow_queue_opt,
@@ -835,11 +841,15 @@ def execute_command(
         max_bytes,
         lz4_compression,
         cursor,
         use_cloud_fetch=True,
         parameters=[],
     ):
         assert session_handle is not None
 
+        logger.debug(
+            f"executing: cloud fetch: {use_cloud_fetch}, max rows: {max_rows}, max bytes: {max_bytes}"
+        )
+
         spark_arrow_types = ttypes.TSparkArrowTypes(
             timestampAsArrow=self._use_arrow_native_timestamps,
             decimalAsArrow=self._use_arrow_native_decimals,
@@ -955,6 +965,9 @@ def get_columns(
         return self._handle_execute_response(resp, cursor)
 
     def _handle_execute_response(self, resp, cursor):
+        logger.debug(
+            "got execute response"
+        )
         cursor.active_op_handle = resp.operationHandle
         self._check_direct_results_for_error(resp.directResults)
 
@@ -975,7 +988,7 @@ def fetch_results(
         arrow_schema_bytes,
         description,
     ):
-        logger.debug("ThriftBackend fetch_results")
+        logger.debug("started to fetch results")
         assert op_handle is not None
 
         req = ttypes.TFetchResultsReq(
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index 14fa06af..e8b2ecfa 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -81,6 +81,9 @@ def build_queue(
         )
         return ArrowQueue(converted_arrow_table, n_valid_rows)
     elif row_set_type == TSparkRowSetType.URL_BASED_SET:
+        logger.debug(
+            f"built cloud fetch queue for {len(t_row_set.resultLinks)} links." 
+ ) return CloudFetchQueue( arrow_schema_bytes, start_row_offset=t_row_set.startRowOffset, @@ -156,11 +159,11 @@ def __init__( self.lz4_compressed = lz4_compressed self.description = description - # self.download_manager = ResultFileDownloadManager( - # self.max_download_threads, self.lz4_compressed - # ) + logger.debug( + f"creating cloud fetch queue for {len(result_links)} links and max_download_threads {self.max_download_threads}." + ) self.download_manager = ResultFileDownloadManager( - 1, self.lz4_compressed + self.max_download_threads, self.lz4_compressed ) self.download_manager.add_file_links(result_links) From 6548396c31b3e9e625ccf88d6ed71535b8d16c99 Mon Sep 17 00:00:00 2001 From: Andre Furlan Date: Thu, 15 Feb 2024 15:33:55 -0800 Subject: [PATCH 3/3] fix tests --- examples/query_execute.py | 6 +- .../sql/cloudfetch/download_manager.py | 2 - src/databricks/sql/cloudfetch/downloader.py | 18 ++- tests/unit/test_download_manager.py | 126 +++++++++--------- tests/unit/test_downloader.py | 32 ++++- 5 files changed, 102 insertions(+), 82 deletions(-) diff --git a/examples/query_execute.py b/examples/query_execute.py index b6919f71..de0cd658 100644 --- a/examples/query_execute.py +++ b/examples/query_execute.py @@ -5,7 +5,7 @@ logger = logging.getLogger("databricks.sql") -logger.setLevel(logging.INFO) +logger.setLevel(logging.DEBUG) fh = logging.FileHandler('pysqllogs.log') fh.setFormatter(logging.Formatter("%(asctime)s %(process)d %(thread)d %(message)s")) fh.setLevel(logging.DEBUG) @@ -20,8 +20,8 @@ with connection.cursor( # arraysize=100 ) as cursor: - cursor.execute("SELECT * FROM range(0, 10000000) AS t1 LEFT JOIN (SELECT 1) AS t2") - # cursor.execute("SELECT * FROM andre.plotly_iot_dashboard.bronze_sensors limit 1000001") + # cursor.execute("SELECT * FROM range(0, 10000000) AS t1 LEFT JOIN (SELECT 1) AS t2") + cursor.execute("SELECT * FROM andre.plotly_iot_dashboard.bronze_sensors limit 1000001") try: result = cursor.fetchall() print(f"result length: {len(result)}") diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index fdbbe55b..97c0dd8a 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -35,8 +35,6 @@ def __init__(self, max_download_threads: int, lz4_compressed: bool): self.download_handlers: List[ResultSetDownloadHandler] = [] self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1) self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed) - self.fetch_need_retry = False - self.num_consecutive_result_file_download_retries = 0 def add_file_links( self, t_spark_arrow_result_links: List[TSparkArrowResultLink] diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index c880ce1b..518c4298 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -24,13 +24,17 @@ class DownloadableResultSettings: is_lz4_compressed (bool): Whether file is expected to be lz4 compressed. link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs. download_timeout (int): Timeout for download requests. Default 60 secs. - max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down. + download_max_retries (int): Number of consecutive download retries before shutting down. 
+ max_retries (int): Number of consecutive download retries before shutting down. + backoff_factor (int): Factor to increase wait time between retries. + """ is_lz4_compressed: bool link_expiry_buffer_secs: int = 0 download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT - max_consecutive_file_download_retries: int = 0 + max_retries: int = 5 + backoff_factor: int = 2 class ResultSetDownloadHandler(threading.Thread): @@ -70,7 +74,8 @@ def is_file_download_successful(self) -> bool: logger.debug( f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}" ) - return False + # there are some weird cases when the is_download_finished is not set, but the file is downloaded successfully + return self.is_file_downloaded_successfully logger.debug( f"finish waiting for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}" @@ -103,7 +108,12 @@ def run(self): ) # Get the file via HTTP request - response = http_get_with_retry(url=self.result_link.fileLink, download_timeout=self.settings.download_timeout) + response = http_get_with_retry( + url=self.result_link.fileLink, + max_retries=self.settings.max_retries, + backoff_factor=self.settings.backoff_factor, + download_timeout=self.settings.download_timeout, + ) if not response: logger.error( diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index 97bf407a..0329f03c 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -139,69 +139,63 @@ def test_find_next_file_index_one_scheduled_next_row_8000(self, mock_submit): assert manager._find_next_file_index(8000) is None - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=True) - @patch("concurrent.futures.ThreadPoolExecutor.submit") - def test_check_if_download_successful_happy(self, mock_submit, mock_is_file_download_successful): - links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) - manager._schedule_downloads() - - status = manager._check_if_download_successful(manager.download_handlers[0]) - assert status - assert manager.num_consecutive_result_file_download_retries == 0 - - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=False) - def test_check_if_download_successful_link_expired(self, mock_is_file_download_successful): - manager = self.create_download_manager() - handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) - handler.is_link_expired = True - - status = manager._check_if_download_successful(handler) - mock_is_file_download_successful.assert_called() - assert not status - assert manager.fetch_need_retry - - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=False) - def test_check_if_download_successful_download_timed_out_no_retries(self, mock_is_file_download_successful): - manager = self.create_download_manager() - handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) - handler.is_download_timedout = True - - status = manager._check_if_download_successful(handler) - 
mock_is_file_download_successful.assert_called() - assert not status - assert manager.fetch_need_retry - - @patch("concurrent.futures.ThreadPoolExecutor.submit") - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=False) - def test_check_if_download_successful_download_timed_out_1_retry(self, mock_is_file_download_successful, mock_submit): - manager = self.create_download_manager() - manager.downloadable_result_settings = download_manager.DownloadableResultSettings( - is_lz4_compressed=True, - download_timeout=0, - max_consecutive_file_download_retries=1, - ) - handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) - handler.is_download_timedout = True - - status = manager._check_if_download_successful(handler) - assert mock_is_file_download_successful.call_count == 2 - assert mock_submit.call_count == 1 - assert not status - assert manager.fetch_need_retry - - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=False) - def test_check_if_download_successful_other_reason(self, mock_is_file_download_successful): - manager = self.create_download_manager() - handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) - - status = manager._check_if_download_successful(handler) - mock_is_file_download_successful.assert_called() - assert not status - assert manager.fetch_need_retry + # @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", + # return_value=True) + # @patch("concurrent.futures.ThreadPoolExecutor.submit") + # def test_check_if_download_successful_happy(self, mock_submit, mock_is_file_download_successful): + # links = self.create_result_links(num_files=10) + # manager = self.create_download_manager() + # manager.add_file_links(links) + # manager._schedule_downloads() + + # assert status + + # @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", + # return_value=False) + # def test_check_if_download_successful_link_expired(self, mock_is_file_download_successful): + # manager = self.create_download_manager() + # handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) + # handler.is_link_expired = True + + # status = manager._check_if_download_successful(handler) + # mock_is_file_download_successful.assert_called() + # assert not status + + # @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", + # return_value=False) + # def test_check_if_download_successful_download_timed_out_no_retries(self, mock_is_file_download_successful): + # manager = self.create_download_manager() + # handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) + # handler.is_download_timedout = True + + # status = manager._check_if_download_successful(handler) + # mock_is_file_download_successful.assert_called() + # assert not status + + # @patch("concurrent.futures.ThreadPoolExecutor.submit") + # @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", + # return_value=False) + # def test_check_if_download_successful_download_timed_out_1_retry(self, mock_is_file_download_successful, mock_submit): + # manager = self.create_download_manager() + # manager.downloadable_result_settings = 
download_manager.DownloadableResultSettings( + # is_lz4_compressed=True, + # download_timeout=0, + # max_consecutive_file_download_retries=1, + # ) + # handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) + # handler.is_download_timedout = True + + # status = manager._check_if_download_successful(handler) + # assert mock_is_file_download_successful.call_count == 2 + # assert mock_submit.call_count == 1 + # assert not status + + # @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", + # return_value=False) + # def test_check_if_download_successful_other_reason(self, mock_is_file_download_successful): + # manager = self.create_download_manager() + # handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) + + # status = manager._check_if_download_successful(handler) + # mock_is_file_download_successful.assert_called() + # assert not status diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 6e13c949..9dbc263e 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -13,18 +13,21 @@ class DownloaderTests(unittest.TestCase): def test_run_link_expired(self, mock_time): settings = Mock() result_link = Mock() + result_link.startRowOffset = 0 + result_link.rowCount = 100 # Already expired result_link.expiryTime = 999 d = downloader.ResultSetDownloadHandler(settings, result_link) assert not d.is_link_expired d.run() assert d.is_link_expired - mock_time.assert_called_once() @patch('time.time', return_value=1000) def test_run_link_past_expiry_buffer(self, mock_time): settings = Mock(link_expiry_buffer_secs=5) result_link = Mock() + result_link.startRowOffset = 0 + result_link.rowCount = 100 # Within the expiry buffer time result_link.expiryTime = 1004 d = downloader.ResultSetDownloadHandler(settings, result_link) @@ -33,13 +36,15 @@ def test_run_link_past_expiry_buffer(self, mock_time): assert d.is_link_expired mock_time.assert_called_once() - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=False)))) + @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=False, status_code=500)))) @patch('time.time', return_value=1000) def test_run_get_response_not_ok(self, mock_time, mock_session): settings = Mock(link_expiry_buffer_secs=0, download_timeout=0) settings.download_timeout = 0 settings.use_proxy = False result_link = Mock(expiryTime=1001) + result_link.startRowOffset = 0 + result_link.rowCount = 100 d = downloader.ResultSetDownloadHandler(settings, result_link) d.run() @@ -48,11 +53,13 @@ def test_run_get_response_not_ok(self, mock_time, mock_session): assert d.is_download_finished.is_set() @patch('requests.Session', - return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, content=b"1234567890" * 9)))) + return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, status_code=200, content=b"1234567890" * 9)))) @patch('time.time', return_value=1000) def test_run_uncompressed_data_length_incorrect(self, mock_time, mock_session): settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, is_lz4_compressed=False) result_link = Mock(bytesNum=100, expiryTime=1001) + result_link.startRowOffset = 0 + result_link.rowCount = 100 d = downloader.ResultSetDownloadHandler(settings, result_link) d.run() @@ -60,12 +67,14 @@ def test_run_uncompressed_data_length_incorrect(self, mock_time, 
mock_session): assert not d.is_file_downloaded_successfully assert d.is_download_finished.is_set() - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True)))) + @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, status_code=200)))) @patch('time.time', return_value=1000) def test_run_compressed_data_length_incorrect(self, mock_time, mock_session): settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) + result_link.startRowOffset = 0 + result_link.rowCount = 100 mock_session.return_value.get.return_value.content = \ b'\x04"M\x18h@Z\x00\x00\x00\x00\x00\x00\x00\xec\x14\x00\x00\x00\xaf1234567890\n\x008P67890\x00\x00\x00\x00' @@ -76,13 +85,14 @@ def test_run_compressed_data_length_incorrect(self, mock_time, mock_session): assert d.is_download_finished.is_set() @patch('requests.Session', - return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, content=b"1234567890" * 10)))) + return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, status_code=200, content=b"1234567890" * 10)))) @patch('time.time', return_value=1000) def test_run_uncompressed_successful(self, mock_time, mock_session): settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = False result_link = Mock(bytesNum=100, expiryTime=1001) - + result_link.startRowOffset = 0 + result_link.rowCount = 100 d = downloader.ResultSetDownloadHandler(settings, result_link) d.run() @@ -90,12 +100,14 @@ def test_run_uncompressed_successful(self, mock_time, mock_session): assert d.is_file_downloaded_successfully assert d.is_download_finished.is_set() - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True)))) + @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, status_code=200)))) @patch('time.time', return_value=1000) def test_run_compressed_successful(self, mock_time, mock_session): settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) + result_link.startRowOffset = 0 + result_link.rowCount = 100 mock_session.return_value.get.return_value.content = \ b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' @@ -111,6 +123,8 @@ def test_run_compressed_successful(self, mock_time, mock_session): def test_download_connection_error(self, mock_time, mock_session): settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True) result_link = Mock(bytesNum=100, expiryTime=1001) + result_link.startRowOffset = 0 + result_link.rowCount = 100 mock_session.return_value.get.return_value.content = \ b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' @@ -125,6 +139,8 @@ def test_download_connection_error(self, mock_time, mock_session): def test_download_timeout(self, mock_time, mock_session): settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True) result_link = Mock(bytesNum=100, expiryTime=1001) + result_link.startRowOffset = 0 + result_link.rowCount = 100 mock_session.return_value.get.return_value.content = \ b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' @@ -148,6 +164,8 @@ def 
test_is_file_download_successful_has_finished(self, mock_wait): def test_is_file_download_successful_times_outs(self): settings = Mock(download_timeout=1) result_link = Mock() + result_link.startRowOffset = 0 + result_link.rowCount = 100 handler = downloader.ResultSetDownloadHandler(settings, result_link) status = handler.is_file_download_successful()
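
A note on the completion handshake that the timeout handling above revolves around: ResultSetDownloadHandler.run() always sets its is_download_finished event in a finally block, and is_file_download_successful() waits on that event with a timeout, treating an elapsed wait as a cloud fetch timeout. The minimal sketch below illustrates the pattern; FakeDownloader and its fields are illustrative stand-ins, not classes from the library.

import threading
import time


class FakeDownloader(threading.Thread):
    """Illustrative stand-in for the handler's completion signalling."""

    def __init__(self, duration_secs):
        super().__init__()
        self.duration_secs = duration_secs
        self.result_file = None
        self.is_download_finished = threading.Event()

    def run(self):
        try:
            time.sleep(self.duration_secs)  # stand-in for the HTTP download
            self.result_file = b"arrow-bytes"
        finally:
            # Signal from finally so waiters are woken even if the download raises.
            self.is_download_finished.set()


downloader = FakeDownloader(duration_secs=0.1)
downloader.start()
# Event.wait returns False when the timeout elapses before set() is called;
# that False return is what the handler's callers treat as a download timeout.
if downloader.is_download_finished.wait(timeout=60):
    print(f"downloaded {len(downloader.result_file)} bytes")
else:
    print("download timed out")

Signalling from the finally block is what guarantees a waiter is always woken, which is why patch 2 can fold the link-expiry check into the same try/finally in run().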