diff --git a/fluster/test_suite.py b/fluster/test_suite.py index 018e58f..e306243 100644 --- a/fluster/test_suite.py +++ b/fluster/test_suite.py @@ -26,6 +26,7 @@ from shutil import rmtree from typing import cast, List, Dict, Optional, Type, Any import urllib.error +import zipfile from fluster.test_vector import TestVector @@ -45,7 +46,6 @@ def __init__( extract_all: bool, keep_file: bool, test_suite_name: str, - test_vector: TestVector, retries: int, ): self.out_dir = out_dir @@ -53,9 +53,34 @@ def __init__( self.extract_all = extract_all self.keep_file = keep_file self.test_suite_name = test_suite_name - self.test_vector = test_vector self.retries = retries + # This is added to avoid having to create an extra ancestor class + def set_test_vector(self, test_vector: TestVector) -> None: + """Setter function for member variable test vector""" + # pylint: disable=attribute-defined-outside-init + + self.test_vector = test_vector + + +class DownloadWorkAV1Argon(DownloadWork): + """Context to pass to AV1 Argon download worker""" + + def __init__( + self, + out_dir: str, + verify: bool, + extract_all: bool, + keep_file: bool, + test_suite_name: str, + test_vectors: Dict[str, TestVector], + retries: int, + ): + super().__init__( + out_dir, verify, extract_all, keep_file, test_suite_name, retries + ) + self.test_vectors = test_vectors + class Context: """Context for TestSuite""" @@ -225,50 +250,82 @@ def _download_worker(ctx: DownloadWork) -> None: os.remove(dest_path) @staticmethod - def _download_worker_av1_argon(ctx: DownloadWork) -> None: - """Download and extract a av1 argon test vector""" - test_vector = ctx.test_vector + def _download_worker_av1_argon(ctx: DownloadWorkAV1Argon) -> None: + """Download AV1 Argon test suite and extract all test vectors""" + + test_vectors = ctx.test_vectors + # Extract 1st test vector from the Dict to use as reference for the download process of .zip source file that + # contains all test vectors + test_vector_0 = test_vectors[next(iter(test_vectors))] dest_dir = os.path.join(ctx.out_dir, ctx.test_suite_name) - dest_path = os.path.join(dest_dir, os.path.basename(test_vector.source)) - if not os.path.exists(dest_dir): - os.makedirs(dest_dir) - # Catch the exception that download may throw to make sure pickle can serialize it properly - # This avoids: - # Error sending result: ''. - # Reason: 'TypeError("cannot pickle '_io.BufferedReader' object")' - if not os.path.exists(dest_path): - print(f"\tDownloading test vector {test_vector.name} from {dest_dir}") - for i in range(ctx.retries): - try: - exception_str = "" - utils.download(test_vector.source, dest_dir) - except urllib.error.URLError as ex: - exception_str = str(ex) - print( - f"\tUnable to download {test_vector.source} to {dest_dir}, {exception_str}, retry count={i+1}" - ) - continue - except Exception as ex: - raise Exception(str(ex)) from ex - break - - if exception_str: - raise Exception(exception_str) - if test_vector.source_checksum != "__skip__": - checksum = utils.file_checksum(dest_path) - if test_vector.source_checksum != checksum: - raise Exception( - f"Checksum error for test vector '{test_vector.name}': '{checksum}' instead of " - f"'{test_vector.source_checksum}'" - ) - if utils.is_extractable(dest_path): - print(f"\tExtracting test vector {test_vector.name} to {dest_dir}") - utils.extract( - dest_path, - dest_dir, - file=test_vector.input_file if not ctx.extract_all else None, + # Local path to source file + dest_path = os.path.join(dest_dir, os.path.basename(test_vector_0.source)) + + # Clean up existing corrupt source file + if ( + ctx.verify + and os.path.exists(dest_path) + and utils.is_extractable(dest_path) + and test_vector_0.source_checksum != utils.file_checksum(dest_path) + ): + os.remove(dest_path) + print( + f"\tRemoved source file {dest_path} from path, checksum doesn't match with expected" ) + os.makedirs(dest_dir, exist_ok=True) + + print(f"\tDownloading source file from {test_vector_0.source}") + for i in range(ctx.retries): + try: + exception_str = "" + utils.download(test_vector_0.source, dest_dir) + except urllib.error.URLError as ex: + exception_str = str(ex) + print( + f"\tUnable to download {test_vector_0.source} to {dest_dir}, " + f"{exception_str}, retry count={i+1}" + ) + continue + except Exception as ex: + raise Exception(str(ex)) from ex + break + + if exception_str: + raise Exception(exception_str) + + # Check that source file was downloaded correctly + if test_vector_0.source_checksum != "__skip__": + checksum = utils.file_checksum(dest_path) + if test_vector_0.source_checksum != checksum: + raise Exception( + f"Checksum error for source file '{os.path.basename(test_vector_0.source)}': " + f"'{checksum}' instead of '{test_vector_0.source_checksum}'" + ) + + # Extract all test vectors from compressed source file + try: + with zipfile.ZipFile(dest_path, "r") as zip_file: + print( + f"\tExtracting test vectors from {os.path.basename(test_vector_0.source)}" + ) + for test_vector_iter in test_vectors.values(): + if test_vector_iter.input_file in zip_file.namelist(): + zip_file.extract(test_vector_iter.input_file, dest_dir) + else: + print( + f"WARNING: test vector {test_vector_iter.input_file} was not found inside source file " + f"{os.path.basename(test_vector_iter.source)}" + ) + except zipfile.BadZipFile as bad_zip_exception: + raise Exception( + f"{dest_path} could not be opened as zip file. Delete the file manually and re-try." + ) from bad_zip_exception + + # Remove source file, if applicable + if not ctx.keep_file: + os.remove(dest_path) + def download( self, jobs: int, @@ -279,27 +336,8 @@ def download( retries: int = 1, ) -> None: """Download the test suite""" - if not os.path.exists(out_dir): - os.makedirs(out_dir) - if self.name == "AV1_ARGON_VECTORS": - # Only one job to download the zip file for Argon. - jobs = 1 - dest_dir = os.path.join(out_dir, self.name) - test_vector_key = self.test_vectors[list(self.test_vectors)[0]].source - dest_folder = os.path.splitext(os.path.basename(test_vector_key))[0] - dest_path = os.path.join(dest_dir, dest_folder) - if ( - verify - and os.path.exists(dest_path) - and self.test_vectors[test_vector_key].source_checksum - == utils.file_checksum(dest_path) - ): - # Remove file only in case the input file was extractable. - # Otherwise, we'd be removing the original file we want to work - # with every even time we execute the download subcommand. - if utils.is_extractable(dest_path) and not keep_file: - os.remove(dest_path) - print(f"Downloading test suite {self.name} using {jobs} parallel jobs") + os.makedirs(out_dir, exist_ok=True) + with Pool(jobs) as pool: def _callback_error(err: Any) -> None: @@ -307,24 +345,44 @@ def _callback_error(err: Any) -> None: pool.terminate() downloads = [] - for test_vector in self.test_vectors.values(): - dwork = DownloadWork( + + if self.name != "AV1_ARGON_VECTORS": + print(f"Downloading test suite {self.name} using {jobs} parallel jobs") + for test_vector in self.test_vectors.values(): + dwork = DownloadWork( + out_dir, + verify, + extract_all, + keep_file, + self.name, + retries, + ) + dwork.set_test_vector(test_vector) + downloads.append( + pool.apply_async( + self._download_worker, + args=(dwork,), + error_callback=_callback_error, + ) + ) + else: + print( + f"Downloading test suite {self.name} using 1 job (no parallel execution possible)" + ) + dwork_av1 = DownloadWorkAV1Argon( out_dir, verify, extract_all, keep_file, self.name, - test_vector, + self.test_vectors, retries, ) - if self.name == "AV1_ARGON_VECTORS": - download_worker = self._download_worker_av1_argon - else: - download_worker = self._download_worker + # We can only use 1 parallel job because all test vectors are inside the same .zip source file downloads.append( pool.apply_async( - download_worker, - args=(dwork,), + self._download_worker_av1_argon, + args=(dwork_av1,), error_callback=_callback_error, ) ) @@ -335,16 +393,6 @@ def _callback_error(err: Any) -> None: if not job.successful(): sys.exit("Some download failed") - if self.name == "AV1_ARGON_VECTORS": - if not dwork.keep_file: - os.remove( - os.path.join( - dwork.out_dir, - dwork.test_suite_name, - os.path.basename(dwork.test_vector.source), - ) - ) - print("All downloads finished") @staticmethod