COM-12459: Optimise AV1 Argon test suite extraction process #199

Merged 1 commit on Nov 6, 2024
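In brief: before this change, the Argon path queued one serial job per test vector, each re-opening the same .zip source file; after it, a single worker downloads the archive once and extracts every test vector from it with Python's zipfile module. A minimal standalone sketch of that strategy (the helper name and arguments are illustrative, not fluster API):

import zipfile

def extract_members(zip_path: str, dest_dir: str, wanted: list) -> None:
    """Extract only the wanted members from a zip archive."""
    with zipfile.ZipFile(zip_path, "r") as zf:
        names = set(zf.namelist())
        for member in wanted:
            if member in names:
                zf.extract(member, dest_dir)
            else:
                print(f"WARNING: {member} not found in {zip_path}")

The suite itself is fetched through fluster's download subcommand (e.g. ./fluster.py download AV1_ARGON_VECTORS; exact flags may vary by version).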
fluster/test_suite.py: 214 changes (131 additions, 83 deletions)
@@ -26,6 +26,7 @@
from shutil import rmtree
from typing import cast, List, Dict, Optional, Type, Any
import urllib.error
import zipfile


from fluster.test_vector import TestVector
@@ -45,17 +46,41 @@ def __init__(
extract_all: bool,
keep_file: bool,
test_suite_name: str,
- test_vector: TestVector,
retries: int,
):
self.out_dir = out_dir
self.verify = verify
self.extract_all = extract_all
self.keep_file = keep_file
self.test_suite_name = test_suite_name
- self.test_vector = test_vector
self.retries = retries

# This is added to avoid having to create an extra ancestor class
def set_test_vector(self, test_vector: TestVector) -> None:
"""Setter function for member variable test vector"""
# pylint: disable=attribute-defined-outside-init

self.test_vector = test_vector


class DownloadWorkAV1Argon(DownloadWork):
"""Context to pass to AV1 Argon download worker"""

def __init__(
self,
out_dir: str,
verify: bool,
extract_all: bool,
keep_file: bool,
test_suite_name: str,
test_vectors: Dict[str, TestVector],
retries: int,
):
super().__init__(
out_dir, verify, extract_all, keep_file, test_suite_name, retries
)
self.test_vectors = test_vectors


class Context:
"""Context for TestSuite"""
@@ -225,50 +250,82 @@ def _download_worker(ctx: DownloadWork) -> None:
os.remove(dest_path)

@staticmethod
- def _download_worker_av1_argon(ctx: DownloadWork) -> None:
- """Download and extract a av1 argon test vector"""
- test_vector = ctx.test_vector
def _download_worker_av1_argon(ctx: DownloadWorkAV1Argon) -> None:
"""Download AV1 Argon test suite and extract all test vectors"""

test_vectors = ctx.test_vectors
# Use the first test vector in the Dict as the reference for downloading the
# .zip source file that contains all test vectors
test_vector_0 = test_vectors[next(iter(test_vectors))]
dest_dir = os.path.join(ctx.out_dir, ctx.test_suite_name)
- dest_path = os.path.join(dest_dir, os.path.basename(test_vector.source))
- if not os.path.exists(dest_dir):
- os.makedirs(dest_dir)
- # Catch the exception that download may throw to make sure pickle can serialize it properly
- # This avoids:
- # Error sending result: '<multiprocessing.pool.ExceptionWithTraceback object at 0x7fd7811ecee0>'.
- # Reason: 'TypeError("cannot pickle '_io.BufferedReader' object")'
- if not os.path.exists(dest_path):
- print(f"\tDownloading test vector {test_vector.name} from {dest_dir}")
- for i in range(ctx.retries):
- try:
- exception_str = ""
- utils.download(test_vector.source, dest_dir)
- except urllib.error.URLError as ex:
- exception_str = str(ex)
- print(
- f"\tUnable to download {test_vector.source} to {dest_dir}, {exception_str}, retry count={i+1}"
- )
- continue
- except Exception as ex:
- raise Exception(str(ex)) from ex
- break
-
- if exception_str:
- raise Exception(exception_str)
- if test_vector.source_checksum != "__skip__":
- checksum = utils.file_checksum(dest_path)
- if test_vector.source_checksum != checksum:
- raise Exception(
- f"Checksum error for test vector '{test_vector.name}': '{checksum}' instead of "
- f"'{test_vector.source_checksum}'"
- )
- if utils.is_extractable(dest_path):
- print(f"\tExtracting test vector {test_vector.name} to {dest_dir}")
- utils.extract(
- dest_path,
- dest_dir,
- file=test_vector.input_file if not ctx.extract_all else None,
# Local path to source file
dest_path = os.path.join(dest_dir, os.path.basename(test_vector_0.source))

# Clean up existing corrupt source file
if (
ctx.verify
and os.path.exists(dest_path)
and utils.is_extractable(dest_path)
and test_vector_0.source_checksum != utils.file_checksum(dest_path)
):
os.remove(dest_path)
print(
f"\tRemoved source file {dest_path}: checksum doesn't match the expected value"
)

os.makedirs(dest_dir, exist_ok=True)

print(f"\tDownloading source file from {test_vector_0.source}")
for i in range(ctx.retries):
try:
exception_str = ""
utils.download(test_vector_0.source, dest_dir)
except urllib.error.URLError as ex:
exception_str = str(ex)
print(
f"\tUnable to download {test_vector_0.source} to {dest_dir}, "
f"{exception_str}, retry count={i+1}"
)
continue
except Exception as ex:
raise Exception(str(ex)) from ex
break

if exception_str:
raise Exception(exception_str)

# Check that source file was downloaded correctly
if test_vector_0.source_checksum != "__skip__":
checksum = utils.file_checksum(dest_path)
if test_vector_0.source_checksum != checksum:
raise Exception(
f"Checksum error for source file '{os.path.basename(test_vector_0.source)}': "
f"'{checksum}' instead of '{test_vector_0.source_checksum}'"
)
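# (Aside: utils.file_checksum is not shown in this diff. A plausible
# equivalent, assuming an MD5 hex digest of the whole file (fluster may
# use a different algorithm), would be:
#
# import hashlib
#
# def file_checksum(path: str) -> str:
#     digest = hashlib.md5()
#     with open(path, "rb") as f:
#         for chunk in iter(lambda: f.read(1 << 20), b""):
#             digest.update(chunk)
#     return digest.hexdigest()
# )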

# Extract all test vectors from compressed source file
try:
with zipfile.ZipFile(dest_path, "r") as zip_file:
print(
f"\tExtracting test vectors from {os.path.basename(test_vector_0.source)}"
)
for test_vector_iter in test_vectors.values():
if test_vector_iter.input_file in zip_file.namelist():
zip_file.extract(test_vector_iter.input_file, dest_dir)
else:
print(
f"WARNING: test vector {test_vector_iter.input_file} was not found inside source file "
f"{os.path.basename(test_vector_iter.source)}"
)
except zipfile.BadZipFile as bad_zip_exception:
raise Exception(
f"{dest_path} could not be opened as zip file. Delete the file manually and re-try."
) from bad_zip_exception

# Remove source file, if applicable
if not ctx.keep_file:
os.remove(dest_path)
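(The retry loop above deliberately re-raises unexpected errors as a plain Exception: multiprocessing has to pickle exceptions to send them back to the parent process, and some exception objects, e.g. ones holding an open file handle, cannot be pickled. A minimal standalone sketch of the same download-with-retries pattern; the function and its arguments are illustrative, not fluster API:)

import urllib.error
import urllib.request

def download_with_retries(url: str, dest_path: str, retries: int) -> None:
    last_error = ""
    for attempt in range(retries):
        try:
            last_error = ""
            urllib.request.urlretrieve(url, dest_path)
        except urllib.error.URLError as ex:
            last_error = str(ex)
            print(f"retry count={attempt + 1}: {last_error}")
            continue
        except Exception as ex:
            # Re-raise as a plain, picklable Exception
            raise Exception(str(ex)) from ex
        break
    if last_error:
        raise Exception(last_error)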

def download(
self,
jobs: int,
@@ -279,52 +336,53 @@ def download(
retries: int = 1,
) -> None:
"""Download the test suite"""
- if not os.path.exists(out_dir):
- os.makedirs(out_dir)
- if self.name == "AV1_ARGON_VECTORS":
- # Only one job to download the zip file for Argon.
- jobs = 1
- dest_dir = os.path.join(out_dir, self.name)
- test_vector_key = self.test_vectors[list(self.test_vectors)[0]].source
- dest_folder = os.path.splitext(os.path.basename(test_vector_key))[0]
- dest_path = os.path.join(dest_dir, dest_folder)
- if (
- verify
- and os.path.exists(dest_path)
- and self.test_vectors[test_vector_key].source_checksum
- == utils.file_checksum(dest_path)
- ):
- # Remove file only in case the input file was extractable.
- # Otherwise, we'd be removing the original file we want to work
- # with every even time we execute the download subcommand.
- if utils.is_extractable(dest_path) and not keep_file:
- os.remove(dest_path)
- print(f"Downloading test suite {self.name} using {jobs} parallel jobs")
os.makedirs(out_dir, exist_ok=True)

with Pool(jobs) as pool:

def _callback_error(err: Any) -> None:
print(f"\nError downloading -> {err}\n")
pool.terminate()

downloads = []
- for test_vector in self.test_vectors.values():
- dwork = DownloadWork(

if self.name != "AV1_ARGON_VECTORS":
print(f"Downloading test suite {self.name} using {jobs} parallel jobs")
for test_vector in self.test_vectors.values():
dwork = DownloadWork(
out_dir,
verify,
extract_all,
keep_file,
self.name,
retries,
)
dwork.set_test_vector(test_vector)
downloads.append(
pool.apply_async(
self._download_worker,
args=(dwork,),
error_callback=_callback_error,
)
)
else:
print(
f"Downloading test suite {self.name} using 1 job (no parallel execution possible)"
)
dwork_av1 = DownloadWorkAV1Argon(
out_dir,
verify,
extract_all,
keep_file,
self.name,
- test_vector,
self.test_vectors,
retries,
)
if self.name == "AV1_ARGON_VECTORS":
download_worker = self._download_worker_av1_argon
else:
download_worker = self._download_worker
# We can only use 1 parallel job because all test vectors are inside the same .zip source file
downloads.append(
pool.apply_async(
- download_worker,
- args=(dwork,),
self._download_worker_av1_argon,
args=(dwork_av1,),
error_callback=_callback_error,
)
)
@@ -335,16 +393,6 @@ def _callback_error(err: Any) -> None:
if not job.successful():
sys.exit("Some download failed")

if self.name == "AV1_ARGON_VECTORS":
if not dwork.keep_file:
os.remove(
os.path.join(
dwork.out_dir,
dwork.test_suite_name,
os.path.basename(dwork.test_vector.source),
)
)

print("All downloads finished")

@staticmethod