From eb7768ff5d59f16cd877696bd8dd14983752b96e Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Fri, 29 Nov 2024 18:59:40 +0100 Subject: [PATCH 01/10] Run prediction tasks in parallel --- cog_safe_push/ai.py | 45 ++- cog_safe_push/config.py | 3 +- cog_safe_push/main.py | 281 +++++++----------- cog_safe_push/match_outputs.py | 32 +- cog_safe_push/predict.py | 201 +------------ cog_safe_push/retry.py | 24 -- cog_safe_push/schema.py | 2 +- cog_safe_push/task_context.py | 101 +++++++ cog_safe_push/tasks.py | 183 ++++++++++++ end-to-end-test/test_end_to_end.py | 143 +++++---- .../test_output_matches_prompt.py | 18 +- pyrightconfig.json | 2 +- script/lint | 9 +- test/test_match_outputs.py | 26 +- test/test_predict.py | 28 +- 15 files changed, 595 insertions(+), 503 deletions(-) delete mode 100644 cog_safe_push/retry.py create mode 100644 cog_safe_push/task_context.py create mode 100644 cog_safe_push/tasks.py diff --git a/cog_safe_push/ai.py b/cog_safe_push/ai.py index 16043d6..71a333c 100644 --- a/cog_safe_push/ai.py +++ b/cog_safe_push/ai.py @@ -1,4 +1,5 @@ import base64 +import functools import json import mimetypes import os @@ -10,16 +11,36 @@ from . import log from .exceptions import AIError, ArgumentError -from .retry import retry -@retry(3) -def boolean( +def async_retry(attempts=3): + def decorator_retry(func): + @functools.wraps(func) + async def wrapper_retry(*args, **kwargs): + for attempt in range(1, attempts + 1): + try: + return await func(*args, **kwargs) + except Exception as e: + log.warning(f"Exception occurred: {e}") + if attempt < attempts: + log.warning(f"Retrying attempt {attempt}/{attempts}") + else: + log.warning(f"Giving up after {attempts} attempts") + raise + return None + + return wrapper_retry + + return decorator_retry + + +@async_retry(3) +async def boolean( prompt: str, files: list[Path] | None = None, include_file_metadata: bool = False ) -> bool: system_prompt = "You only answer YES or NO, and absolutely nothing else. Your response will be used in a programmatic context so it's important that you only ever answer with either the string YES or the string NO." - #system_prompt = "You are a helpful assistant" - output = call( + # system_prompt = "You are a helpful assistant" + output = await call( system_prompt=system_prompt, prompt=prompt.strip(), files=files, @@ -32,17 +53,17 @@ def boolean( raise AIError(f"Failed to parse output as YES/NO: {output}") -@retry(3) -def json_object(prompt: str, files: list[Path] | None = None) -> dict: +@async_retry(3) +async def json_object(prompt: str, files: list[Path] | None = None) -> dict: system_prompt = "You always respond with valid JSON, and nothing else (no backticks, etc.). Your outputs will be used in a programmatic context." 
- output = call(system_prompt=system_prompt, prompt=prompt.strip(), files=files) + output = await call(system_prompt=system_prompt, prompt=prompt.strip(), files=files) try: return json.loads(output) except json.JSONDecodeError: raise AIError(f"Failed to parse output as JSON: {output}") -def call( +async def call( system_prompt: str, prompt: str, files: list[Path] | None = None, @@ -53,7 +74,7 @@ def call( raise ArgumentError("ANTHROPIC_API_KEY is not defined") model = "claude-3-5-sonnet-20241022" - client = anthropic.Anthropic(api_key=api_key) + client = anthropic.AsyncAnthropic(api_key=api_key) if files: content = create_content_list(files) @@ -61,7 +82,7 @@ def call( if include_file_metadata: prompt += "\n\nMetadata for the attached file(s):\n" for path in files: - prompt += f"* " + file_info(path) + "\n" + prompt += "* " + file_info(path) + "\n" content.append({"type": "text", "text": prompt}) @@ -74,7 +95,7 @@ def call( {"role": "user", "content": content} ] - response = client.messages.create( + response = await client.messages.create( model=model, messages=messages, system=system_prompt, diff --git a/cog_safe_push/config.py b/cog_safe_push/config.py index 129f484..8eae13e 100644 --- a/cog_safe_push/config.py +++ b/cog_safe_push/config.py @@ -37,7 +37,7 @@ class FuzzConfig(BaseModel): fixed_inputs: dict[str, InputScalar] = {} disabled_inputs: list[str] = [] duration: int = DEFAULT_FUZZ_DURATION - iterations: int | None = None + iterations: int = 10 class PredictConfig(BaseModel): @@ -68,6 +68,7 @@ class Config(BaseModel): predict: PredictConfig | None = None train: TrainConfig | None = None dockerfile: str | None = None + parallel: int = 4 def override(self, field: str, args: argparse.Namespace, arg: str): if hasattr(args, arg) and getattr(args, arg) is not None: diff --git a/cog_safe_push/main.py b/cog_safe_push/main.py index 3822457..68b00c4 100644 --- a/cog_safe_push/main.py +++ b/cog_safe_push/main.py @@ -1,16 +1,15 @@ import argparse +import asyncio import re import sys from pathlib import Path from typing import Any import pydantic -import replicate import yaml from replicate.exceptions import ReplicateError -from replicate.model import Model -from . import cog, lint, log, predict, schema +from . 
import cog, lint, log, schema from .config import ( DEFAULT_FUZZ_DURATION, DEFAULT_PREDICT_TIMEOUT, @@ -21,12 +20,16 @@ ) from .config import TestCase as ConfigTestCase from .exceptions import ArgumentError, CogSafePushError -from .predict import ( +from .task_context import TaskContext, make_task_context +from .tasks import ( AIOutput, + CheckOutputsMatch, ExactStringOutput, ExactURLOutput, - TestCase, - make_predict_inputs, + ExpectedOutput, + FuzzModel, + RunTestCase, + Task, ) DEFAULT_CONFIG_PATH = Path("cog-safe-push.yaml") @@ -179,7 +182,7 @@ def run_config(config: Config, no_push: bool): test_model_owner, test_model_name = parse_model(config.test_model) # small optimization - reuse_test_model = None + task_context = None if config.train: # Don't push twice in case train and predict are both defined @@ -195,23 +198,29 @@ def run_config(config: Config, no_push: bool): fuzz = FuzzConfig( fixed_inputs={}, disabled_inputs=[], duration=0, iterations=0 ) - reuse_test_model = cog_safe_push( + task_context = make_task_context( model_owner=model_owner, model_name=model_name, test_model_owner=test_model_owner, test_model_name=test_model_name, - no_push=train_no_push, test_hardware=config.test_hardware, train=True, train_destination_owner=destination_owner, train_destination_name=destination_name, + dockerfile=config.dockerfile, + ) + + cog_safe_push( + task_context=task_context, + no_push=train_no_push, + train=True, do_compare_outputs=False, predict_timeout=config.train.train_timeout, - test_cases=config_test_cases_to_test_cases(config.train.test_cases), + test_cases=parse_config_test_cases(config.train.test_cases), fuzz_fixed_inputs=fuzz.fixed_inputs, fuzz_disabled_inputs=fuzz.disabled_inputs, - fuzz_seconds=fuzz.duration, fuzz_iterations=fuzz.iterations, + parallel=config.parallel, ) if config.predict: @@ -221,53 +230,45 @@ def run_config(config: Config, no_push: bool): fuzz = FuzzConfig( fixed_inputs={}, disabled_inputs=[], duration=0, iterations=0 ) + if task_context is None: # has not been created in the training block above + task_context = make_task_context( + model_owner=model_owner, + model_name=model_name, + test_model_owner=test_model_owner, + test_model_name=test_model_name, + test_hardware=config.test_hardware, + dockerfile=config.dockerfile, + ) + cog_safe_push( - model_owner=model_owner, - model_name=model_name, - test_model_owner=test_model_owner, - test_model_name=test_model_name, + task_context=task_context, no_push=no_push, - test_hardware=config.test_hardware, train=False, do_compare_outputs=config.predict.compare_outputs, predict_timeout=config.predict.predict_timeout, - test_cases=config_test_cases_to_test_cases(config.predict.test_cases), + test_cases=parse_config_test_cases(config.predict.test_cases), fuzz_fixed_inputs=fuzz.fixed_inputs, fuzz_disabled_inputs=fuzz.disabled_inputs, - fuzz_seconds=fuzz.duration, fuzz_iterations=fuzz.iterations, - reuse_test_model=reuse_test_model, - dockerfile=config.dockerfile, + parallel=config.parallel, ) def cog_safe_push( - model_owner: str, - model_name: str, - test_model_owner: str, - test_model_name: str, - test_hardware: str, + task_context: TaskContext, no_push: bool = False, train: bool = False, - train_destination_owner: str | None = None, - train_destination_name: str | None = None, - train_destination_hardware: str = "cpu", do_compare_outputs: bool = True, predict_timeout: int = 300, - test_cases: list[TestCase] = [], + test_cases: list[tuple[dict[str, Any], ExpectedOutput]] = [], fuzz_fixed_inputs: dict = {}, 
fuzz_disabled_inputs: list = [], - fuzz_seconds: int = 30, - fuzz_iterations: int | None = None, - reuse_test_model: Model | None = None, - dockerfile: str | None = None, + fuzz_iterations: int = 10, + parallel=4, ): - if model_owner == test_model_owner and model_name == test_model_name: - raise ArgumentError("Can't use the same model as test model") - if no_push: log.info( - f"Running in test-only mode, no model will be pushed to {model_owner}/{model_name}" + f"Running in test-only mode, no model will be pushed to {task_context.model.owner}/{task_context.model.name}" ) if train: @@ -280,118 +281,94 @@ def cog_safe_push( "--fuzz-fixed-inputs keys must not be present in --fuzz-disabled-inputs" ) - model = get_model(model_owner, model_name) - if not model: - raise ArgumentError( - f"You need to create the model {model_owner}/{model_name} before running this script" - ) - - if reuse_test_model: - test_model = reuse_test_model - else: - test_model = get_or_create_model( - test_model_owner, test_model_name, test_hardware - ) - - if train: - train_destination = get_or_create_model( - train_destination_owner, train_destination_name, train_destination_hardware - ) - else: - train_destination = None - - if not reuse_test_model: - log.info("Pushing test model") - pushed_version_id = cog.push(test_model, dockerfile) - test_model.reload() - try: - assert ( - test_model.versions.list()[0].id == pushed_version_id - ), f"Pushed version ID {pushed_version_id} doesn't match latest version on {test_model_owner}/{test_model_name}: {test_model.versions.list()[0].id}" - except ReplicateError as e: - if e.status == 404: - # Assume it's an official model - # If it's an official model, can't check that the version matches - pass - else: - raise - log.info("Linting test model schema") - schema.lint(test_model, train=train) + schema.lint(task_context.test_model, train=train) model_has_versions = False try: - model_has_versions = bool(model.versions.list()) + model_has_versions = bool(task_context.model.versions.list()) except ReplicateError as e: if e.status == 404: # Assume it's an official model - model_has_versions = bool(model.latest_version) + model_has_versions = bool(task_context.model.latest_version) else: raise + tasks = [] + if model_has_versions: log.info("Checking schema backwards compatibility") - test_model_schemas = schema.get_schemas(test_model, train=train) - model_schemas = schema.get_schemas(model, train=train) + test_model_schemas = schema.get_schemas(task_context.test_model, train=train) + model_schemas = schema.get_schemas(task_context.model, train=train) schema.check_backwards_compatible( test_model_schemas, model_schemas, train=train ) if do_compare_outputs: - log.info( - "Checking that outputs match between existing version and test version" - ) - if test_cases: - compare_inputs = test_cases[0].inputs - is_deterministic = "seed" in compare_inputs - else: - schemas = schema.get_schemas(model, train=train) - compare_inputs, is_deterministic = make_predict_inputs( - schemas, - train=train, - only_required=True, - seed=1, - fixed_inputs=fuzz_fixed_inputs, - disabled_inputs=fuzz_disabled_inputs, + tasks.append( + CheckOutputsMatch( + context=task_context, + timeout_seconds=predict_timeout, + first_test_case_inputs=test_cases[0][0] if test_cases else None, + fuzz_fixed_inputs=fuzz_fixed_inputs, + fuzz_disabled_inputs=fuzz_disabled_inputs, ) - predict.check_outputs_match( - test_model=test_model, - model=model, - train=train, - train_destination=train_destination, - 
timeout_seconds=predict_timeout, - inputs=compare_inputs, - is_deterministic=is_deterministic, ) if test_cases: - log.info("Running test cases") - predict.run_test_cases( - model=test_model, - train=train, - train_destination=train_destination, - predict_timeout=predict_timeout, - test_cases=test_cases, - ) + for inputs, output in test_cases: + tasks.append( + RunTestCase( + context=task_context, + inputs=inputs, + output=output, + predict_timeout=predict_timeout, + ) + ) - if fuzz_seconds > 0: - log.info("Fuzzing test model") - predict.fuzz_model( - model=test_model, - train=train, - train_destination=train_destination, - timeout_seconds=fuzz_seconds, - max_iterations=fuzz_iterations, - fixed_inputs=fuzz_fixed_inputs, - disabled_inputs=fuzz_disabled_inputs, + for _ in range(fuzz_iterations): + tasks.append( + FuzzModel( + context=task_context, + fixed_inputs=fuzz_fixed_inputs, + disabled_inputs=fuzz_disabled_inputs, + predict_timeout=predict_timeout, + ) ) + asyncio.run(run_tasks(tasks, parallel=parallel)) + log.info("Tests were successful ✨") if not no_push: log.info("Pushing model...") - cog.push(model, dockerfile) + cog.push(task_context.model, task_context.dockerfile) + + +async def run_tasks(tasks: list[Task], parallel: int) -> None: + log.info(f"Running tasks with parallelism {parallel}") + + semaphore = asyncio.Semaphore(parallel) + errors: list[Exception] = [] + + async def run_with_semaphore(task: Task) -> None: + async with semaphore: + try: + await task.run() + except Exception as e: + errors.append(e) + + # Create task coroutines and run them concurrently + task_coroutines = [run_with_semaphore(task) for task in tasks] + + # Use gather to run tasks concurrently + await asyncio.gather(*task_coroutines, return_exceptions=True) - return test_model # for reuse + if errors: + # If there are multiple errors, we'll raise the first one + # but log all of them + for error in errors[1:]: + log.error(f"Additional error occurred: {error}") + raise errors[0] def parse_inputs(inputs_list: list[str]) -> dict[str, Any]: @@ -425,35 +402,6 @@ def parse_input_value(value: str) -> Any: return value -def get_or_create_model(model_owner, model_name, hardware) -> Model: - model = get_model(model_owner, model_name) - - if not model: - if not hardware: - raise ArgumentError( - f"Model {model_owner}/{model_name} doesn't exist, and you didn't specify hardware" - ) - - log.info(f"Creating model {model_owner}/{model_name} with hardware {hardware}") - model = replicate.models.create( - owner=model_owner, - name=model_name, - visibility="private", - hardware=hardware, - ) - return model - - -def get_model(owner, name) -> Model | None: - try: - model = replicate.models.get(f"{owner}/{name}") - except ReplicateError as e: - if e.status == 404: - return None - raise - return model - - def parse_model(model_owner_name: str) -> tuple[str, str]: pattern = r"^([a-z0-9_-]+)/([a-z0-9-.]+)$" match = re.match(pattern, model_owner_name) @@ -502,23 +450,24 @@ def parse_test_case(test_case_str: str) -> ConfigTestCase: return test_case -def config_test_cases_to_test_cases( +def parse_config_test_case( + config_test_case: ConfigTestCase, +) -> tuple[dict[str, Any], ExpectedOutput]: + output = None + if config_test_case.exact_string: + output = ExactStringOutput(string=config_test_case.exact_string) + elif config_test_case.match_url: + output = ExactURLOutput(url=config_test_case.match_url) + elif config_test_case.match_prompt: + output = AIOutput(prompt=config_test_case.match_prompt) + + return (config_test_case.inputs, 
output) + + +def parse_config_test_cases( config_test_cases: list[ConfigTestCase], -) -> list[TestCase]: - test_cases = [] - for tc in config_test_cases: - output = None - if tc.exact_string: - output = ExactStringOutput(string=tc.exact_string) - elif tc.match_url: - output = ExactURLOutput(url=tc.match_url) - elif tc.match_prompt: - output = AIOutput(prompt=tc.match_prompt) - - test_case = TestCase(inputs=tc.inputs, output=output) - test_cases.append(test_case) - - return test_cases +) -> list[tuple[dict[str, Any], ExpectedOutput]]: + return [parse_config_test_case(tc) for tc in config_test_cases] def print_help_config(): diff --git a/cog_safe_push/match_outputs.py b/cog_safe_push/match_outputs.py index 451e170..392fffc 100644 --- a/cog_safe_push/match_outputs.py +++ b/cog_safe_push/match_outputs.py @@ -11,7 +11,7 @@ from . import ai, log -def output_matches_prompt(output: Any, prompt: str) -> tuple[bool, str]: +async def output_matches_prompt(output: Any, prompt: str) -> tuple[bool, str]: urls = [] if isinstance(output, str) and is_url(output): urls = [output] @@ -38,7 +38,7 @@ def output_matches_prompt(output: Any, prompt: str) -> tuple[bool, str]: Description to evaluate: {prompt}""" - matches = ai.boolean( + matches = await ai.boolean( claude_prompt, files=tmp_files, include_file_metadata=True, @@ -50,7 +50,7 @@ def output_matches_prompt(output: Any, prompt: str) -> tuple[bool, str]: # If it's not a match, do best of three to avoid flaky tests multiple_matches = [matches] for _ in range(2): - matches = ai.boolean( + matches = await ai.boolean( claude_prompt, files=tmp_files, include_file_metadata=True, @@ -63,18 +63,20 @@ def output_matches_prompt(output: Any, prompt: str) -> tuple[bool, str]: return False, "AI determined that the output does not match the description" -def outputs_match(test_output, output, is_deterministic: bool) -> tuple[bool, str]: +async def outputs_match( + test_output, output, is_deterministic: bool +) -> tuple[bool, str]: if type(test_output) is not type(output): return False, "The types of the outputs don't match" if isinstance(output, str): if is_url(test_output) and is_url(output): - return urls_match(test_output, output, is_deterministic) + return await urls_match(test_output, output, is_deterministic) if is_url(test_output) or is_url(output): return False, "Only one output is a URL" - return strings_match(test_output, output, is_deterministic) + return await strings_match(test_output, output, is_deterministic) if isinstance(output, bool): if test_output == output: @@ -95,7 +97,7 @@ def outputs_match(test_output, output, is_deterministic: bool) -> tuple[bool, st if test_output.keys() != output.keys(): return False, "Dict keys don't match" for key in output: - matches, message = outputs_match( + matches, message = await outputs_match( test_output[key], output[key], is_deterministic ) if not matches: @@ -106,7 +108,7 @@ def outputs_match(test_output, output, is_deterministic: bool) -> tuple[bool, st if len(test_output) != len(output): return False, "List lengths don't match" for i in range(len(output)): - matches, message = outputs_match( + matches, message = await outputs_match( test_output[i], output[i], is_deterministic ) if not matches: @@ -118,12 +120,12 @@ def outputs_match(test_output, output, is_deterministic: bool) -> tuple[bool, st return True, "" -def strings_match(s1: str, s2: str, is_deterministic: bool) -> tuple[bool, str]: +async def strings_match(s1: str, s2: str, is_deterministic: bool) -> tuple[bool, str]: if is_deterministic: if s1 == s2: 
return True, "" return False, "Strings aren't the same" - fuzzy_match = ai.boolean( + fuzzy_match = await ai.boolean( f""" Have these two strings been generated by the same generative AI model inputs/prompt? @@ -136,13 +138,13 @@ def strings_match(s1: str, s2: str, is_deterministic: bool) -> tuple[bool, str]: return False, "Strings aren't similar" -def urls_match(url1: str, url2: str, is_deterministic: bool) -> tuple[bool, str]: +async def urls_match(url1: str, url2: str, is_deterministic: bool) -> tuple[bool, str]: # New model must return same extension as previous model if not extensions_match(url1, url2): return False, "URL extensions don't match" if is_image(url1): - return images_match(url1, url2, is_deterministic) + return await images_match(url1, url2, is_deterministic) if is_audio(url1): return audios_match(url1, url2, is_deterministic) @@ -179,7 +181,9 @@ def is_url(s: str) -> bool: return s.startswith(("http://", "https://")) -def images_match(url1: str, url2: str, is_deterministic: bool) -> tuple[bool, str]: +async def images_match( + url1: str, url2: str, is_deterministic: bool +) -> tuple[bool, str]: with download(url1) as tmp1, download(url2) as tmp2: img1 = Image.open(tmp1) img2 = Image.open(tmp2) @@ -196,7 +200,7 @@ def images_match(url1: str, url2: str, is_deterministic: bool) -> tuple[bool, st return False, "Images are not identical" return True, "" - fuzzy_match = ai.boolean( + fuzzy_match = await ai.boolean( "These two images have been generated by or modified by an AI model. Is it highly likely that those two predictions of the model had the same inputs?", files=[tmp1, tmp2], ) diff --git a/cog_safe_push/predict.py b/cog_safe_push/predict.py index d343731..b8ffb10 100644 --- a/cog_safe_push/predict.py +++ b/cog_safe_push/predict.py @@ -1,124 +1,21 @@ +import asyncio import json import time -from dataclasses import dataclass -from typing import Any, List +from typing import Any import replicate from replicate.exceptions import ReplicateError from replicate.model import Model -from . import ai, log, schema +from . 
import ai, log from .exceptions import ( AIError, - FuzzError, - OutputsDontMatchError, PredictionFailedError, PredictionTimeoutError, - TestCaseFailedError, ) -from .match_outputs import is_url, output_matches_prompt, outputs_match, urls_match -@dataclass -class ExactStringOutput: - string: str - - -@dataclass -class ExactURLOutput: - url: str - - -@dataclass -class AIOutput: - prompt: str - - -@dataclass -class TestCase: - inputs: dict[str, Any] - output: ExactStringOutput | ExactURLOutput | AIOutput | None - - -def check_outputs_match( - test_model: Model, - model: Model, - train: bool, - train_destination: Model | None, - timeout_seconds: float, - inputs: dict[str, Any], - is_deterministic: bool, -): - test_output = predict( - model=test_model, - train=train, - train_destination=train_destination, - inputs=inputs, - timeout_seconds=timeout_seconds, - ) - output = predict( - model=model, - train=train, - train_destination=train_destination, - inputs=inputs, - timeout_seconds=timeout_seconds, - ) - matches, error = outputs_match(test_output, output, is_deterministic) - if not matches: - raise OutputsDontMatchError( - f"Outputs don't match:\n\ntest output:\n{test_output}\n\nmodel output:\n{output}\n\n{error}" - ) - - -def fuzz_model( - model: Model, - train: bool, - train_destination: Model | None, - timeout_seconds: float, - max_iterations: int | None, - fixed_inputs: dict[str, Any], - disabled_inputs: list[str], -): - start_time = time.time() - inputs_history = [] - successful_predictions = 0 - while True: - schemas = schema.get_schemas(model, train=train) - predict_inputs, _ = make_predict_inputs( - schemas, - train=train, - only_required=False, - seed=None, - fixed_inputs=fixed_inputs, - disabled_inputs=disabled_inputs, - inputs_history=inputs_history, - ) - inputs_history.append(predict_inputs) - predict_timeout = start_time + timeout_seconds - time.time() - try: - output = predict( - model=model, - train=train, - train_destination=train_destination, - inputs=predict_inputs, - timeout_seconds=predict_timeout, - ) - except PredictionTimeoutError: - if not successful_predictions: - log.warning( - f"No predictions succeeded in {timeout_seconds}, try increasing --fuzz-seconds" - ) - return - except PredictionFailedError as e: - raise FuzzError(e) - if not output: - raise FuzzError("No output") - successful_predictions += 1 - if max_iterations is not None and successful_predictions == max_iterations: - return - - -def make_predict_inputs( +async def make_predict_inputs( schemas: dict, train: bool, only_required: bool, @@ -284,14 +181,14 @@ def make_predict_inputs( Return a new combination of inputs that you haven't used before. 
You have previously used these inputs: {inputs_history_str}""" - inputs = ai.json_object(prompt) + inputs = await ai.json_object(prompt) if set(required) - set(inputs.keys()): max_attempts = 5 if attempt == max_attempts: raise AIError( f"Failed to generate a json payload with the correct keys after {max_attempts} attempts, giving up" ) - return make_predict_inputs( + return await make_predict_inputs( schemas=schemas, train=train, only_required=only_required, @@ -316,7 +213,7 @@ def make_predict_inputs( return inputs, is_deterministic -def predict( +async def predict( model: Model, train: bool, train_destination: Model | None, @@ -327,6 +224,8 @@ def predict( f"Running {'training' if train else 'prediction'} with inputs:\n{json.dumps(inputs, indent=2)}" ) + start_time = time.time() + if train: assert train_destination version_ref = f"{model.owner}/{model.name}:{model.versions.list()[0].id}" @@ -337,21 +236,22 @@ def predict( ) else: try: - prediction = replicate.predictions.create( + prediction = await replicate.predictions.async_create( version=model.versions.list()[0].id, input=inputs ) except ReplicateError as e: if e.status == 404: # Assume it's an official model - prediction = replicate.predictions.create(model=model, input=inputs) + prediction = await replicate.predictions.async_create( + model=model, input=inputs + ) else: raise log.vv(f"Prediction URL: https://replicate.com/p/{prediction.id}") - start_time = time.time() while prediction.status not in ["succeeded", "failed", "canceled"]: - time.sleep(0.5) + await asyncio.sleep(0.5) if time.time() - start_time > timeout_seconds: raise PredictionTimeoutError() prediction.reload() @@ -359,81 +259,12 @@ def predict( if prediction.status == "failed": raise PredictionFailedError(prediction.error) - log.vv(f"Got output: {truncate(prediction.output)}") + duration = time.time() - start_time + log.vv(f"Got output: {truncate(prediction.output)} ({duration:.2f} sec)") return prediction.output -def run_test_cases( - model: Model, - train: bool, - train_destination: Model | None, - predict_timeout: int, - test_cases: List[TestCase], -): - for i, test_case in enumerate(test_cases): - log.info(f"Running test case {i + 1}/{len(test_cases)}") - - try: - output = predict( - model=model, - train=train, - train_destination=train_destination, - inputs=test_case.inputs, - timeout_seconds=predict_timeout, - ) - except PredictionFailedError as e: - raise TestCaseFailedError(f"Test case {i + 1} failed: {str(e)}") - - if test_case.output is None: - log.info(f"Test case {i + 1} passed (no output checker)") - continue - - if isinstance(test_case.output, ExactStringOutput): - if output != test_case.output.string: - raise TestCaseFailedError( - f"Test case {i + 1} failed: Expected '{test_case.output.string}', got '{truncate(output, 200)}'" - ) - elif isinstance(test_case.output, ExactURLOutput): - output_url = None - if isinstance(output, str) and is_url(output): - output_url = output - if ( - isinstance(output, list) - and len(output) == 1 - and isinstance(output[0], str) - and is_url(output[0]) - ): - output_url = output[0] - if output_url is not None: - matches, error = urls_match( - test_case.output.url, output_url, is_deterministic=True - ) - if not matches: - raise TestCaseFailedError( - f"Test case {i + 1} failed: URL mismatch. 
{error}" - ) - else: - raise TestCaseFailedError( - f"Test case {i + 1} failed: Expected URL, got '{truncate(output, 200)}'" - ) - elif isinstance(test_case.output, AIOutput): - try: - matches, error = output_matches_prompt(output, test_case.output.prompt) - if not matches: - raise TestCaseFailedError(f"Test case {i + 1} failed: {error}") - except AIError as e: - raise TestCaseFailedError( - f"Test case {i + 1} failed: AI error: {str(e)}" - ) - else: - raise ValueError(f"Unknown output type: {type(test_case.output)}") - - log.info(f"Test case {i + 1} passed") - - log.info(f"All {len(test_cases)} test cases passed") - - def truncate(s, max_length=500) -> str: s = str(s) if len(s) <= max_length: diff --git a/cog_safe_push/retry.py b/cog_safe_push/retry.py deleted file mode 100644 index 2f56369..0000000 --- a/cog_safe_push/retry.py +++ /dev/null @@ -1,24 +0,0 @@ -import functools - -from . import log - - -def retry(attempts=3): - def decorator_retry(func): - @functools.wraps(func) - def wrapper_retry(*args, **kwargs): - for attempt in range(1, attempts + 1): - try: - return func(*args, **kwargs) - except Exception as e: - log.warning(f"Exception occurred: {e}") - if attempt < attempts: - log.warning(f"Retrying attempt {attempt}/{attempts}") - else: - log.warning(f"Giving up after {attempts} attempts") - raise - return None - - return wrapper_retry - - return decorator_retry diff --git a/cog_safe_push/schema.py b/cog_safe_push/schema.py index 00fce26..70bcc00 100644 --- a/cog_safe_push/schema.py +++ b/cog_safe_push/schema.py @@ -116,7 +116,7 @@ def get_openapi_schema(model: Model) -> dict: raise -def get_schemas(model, train: bool): +def get_schemas(model, train: bool) -> dict: schemas = get_openapi_schema(model)["components"]["schemas"] unnecessary_keys = [ "HTTPValidationError", diff --git a/cog_safe_push/task_context.py b/cog_safe_push/task_context.py new file mode 100644 index 0000000..7b22945 --- /dev/null +++ b/cog_safe_push/task_context.py @@ -0,0 +1,101 @@ +from dataclasses import dataclass + +import replicate +from replicate.exceptions import ReplicateError +from replicate.model import Model + +from . 
import cog, log +from .exceptions import ArgumentError + + +@dataclass(frozen=True) +class TaskContext: + model: Model + test_model: Model + train_destination: Model | None + dockerfile: str | None + + def is_train(self): + return self.train_destination is not None + + +def make_task_context( + model_owner: str, + model_name: str, + test_model_owner: str, + test_model_name: str, + test_hardware: str, + train: bool = False, + dockerfile: str | None = None, + train_destination_owner: str | None = None, + train_destination_name: str | None = None, + train_destination_hardware: str = "cpu", +) -> TaskContext: + if model_owner == test_model_owner and model_name == test_model_name: + raise ArgumentError("Can't use the same model as test model") + + model = get_model(model_owner, model_name) + if not model: + raise ArgumentError( + f"You need to create the model {model_owner}/{model_name} before running this script" + ) + + test_model = get_or_create_model(test_model_owner, test_model_name, test_hardware) + + if train: + train_destination = get_or_create_model( + train_destination_owner, train_destination_name, train_destination_hardware + ) + else: + train_destination = None + + log.info("Pushing test model") + pushed_version_id = cog.push(test_model, dockerfile) + test_model.reload() + try: + assert ( + test_model.versions.list()[0].id == pushed_version_id + ), f"Pushed version ID {pushed_version_id} doesn't match latest version on {test_model_owner}/{test_model_name}: {test_model.versions.list()[0].id}" + except ReplicateError as e: + if e.status == 404: + # Assume it's an official model + # If it's an official model, can't check that the version matches + pass + else: + raise + + return TaskContext( + model=model, + test_model=test_model, + train_destination=train_destination, + dockerfile=dockerfile, + ) + + +def get_or_create_model(model_owner, model_name, hardware) -> Model: + model = get_model(model_owner, model_name) + + if not model: + if not hardware: + raise ArgumentError( + f"Model {model_owner}/{model_name} doesn't exist, and you didn't specify hardware" + ) + + log.info(f"Creating model {model_owner}/{model_name} with hardware {hardware}") + model = replicate.models.create( + owner=model_owner, + name=model_name, + visibility="private", + hardware=hardware, + ) + return model + + +def get_model(owner, name) -> Model | None: + try: + model = replicate.models.get(f"{owner}/{name}") + except ReplicateError as e: + if e.status == 404: + return None + raise + return model diff --git a/cog_safe_push/tasks.py b/cog_safe_push/tasks.py new file mode 100644 index 0000000..d34ecb6 --- /dev/null +++ b/cog_safe_push/tasks.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass +from typing import Any, Protocol + +from . import log, schema +from .exceptions import ( + AIError, + FuzzError, + OutputsDontMatchError, + PredictionFailedError, + PredictionTimeoutError, + TestCaseFailedError, +) +from .match_outputs import is_url, output_matches_prompt, outputs_match, urls_match +from .predict import make_predict_inputs, predict, truncate +from .task_context import TaskContext + + +@dataclass +class ExactStringOutput: + string: str + + +@dataclass +class ExactURLOutput: + url: str + + +@dataclass +class AIOutput: + prompt: str + + +ExpectedOutput = ExactStringOutput | ExactURLOutput | AIOutput | None + + +class Task(Protocol): + async def run(self) -> None: ... 
+ + +@dataclass +class CheckOutputsMatch(Task): + context: TaskContext + timeout_seconds: int + first_test_case_inputs: dict[str, Any] | None + fuzz_fixed_inputs: dict[str, Any] + fuzz_disabled_inputs: list[str] + + async def run(self) -> None: + if self.first_test_case_inputs is not None: + inputs = self.first_test_case_inputs + is_deterministic = "seed" in inputs + else: + schemas = schema.get_schemas( + self.context.model, train=self.context.is_train() + ) + inputs, is_deterministic = await make_predict_inputs( + schemas, + train=self.context.is_train(), + only_required=True, + seed=1, + fixed_inputs=self.fuzz_fixed_inputs, + disabled_inputs=self.fuzz_disabled_inputs, + ) + + log.v( + f"Checking outputs match between existing version and test version, with inputs: {inputs}" + ) + test_output = await predict( + model=self.context.test_model, + train=self.context.is_train(), + train_destination=self.context.train_destination, + inputs=inputs, + timeout_seconds=self.timeout_seconds, + ) + output = await predict( + model=self.context.model, + train=self.context.is_train(), + train_destination=self.context.train_destination, + inputs=inputs, + timeout_seconds=self.timeout_seconds, + ) + matches, error = await outputs_match(test_output, output, is_deterministic) + if not matches: + raise OutputsDontMatchError( + f"Outputs don't match:\n\ntest output:\n{test_output}\n\nmodel output:\n{output}\n\n{error}" + ) + + +@dataclass +class RunTestCase(Task): + context: TaskContext + inputs: dict[str, Any] + output: ExactStringOutput | ExactURLOutput | AIOutput | None + predict_timeout: int + + async def run(self) -> None: + log.v(f"Running test case with inputs: {self.inputs}") + try: + output = await predict( + model=self.context.test_model, + train=self.context.is_train(), + train_destination=self.context.train_destination, + inputs=self.inputs, + timeout_seconds=self.predict_timeout, + ) + except PredictionFailedError as e: + raise TestCaseFailedError(f"Test case failed: {str(e)}") + + if self.output is None: + return + + if isinstance(self.output, ExactStringOutput): + if output != self.output.string: + raise TestCaseFailedError( + f"Test case failed: Expected '{self.output.string}', got '{truncate(output, 200)}'" + ) + elif isinstance(self.output, ExactURLOutput): + output_url = None + if isinstance(output, str) and is_url(output): + output_url = output + if ( + isinstance(output, list) + and len(output) == 1 + and isinstance(output[0], str) + and is_url(output[0]) + ): + output_url = output[0] + if output_url is not None: + matches, error = await urls_match( + self.output.url, output_url, is_deterministic=True + ) + if not matches: + raise TestCaseFailedError( + f"Test case failed: URL mismatch. 
{error}" + ) + else: + raise TestCaseFailedError( + f"Test case failed: Expected URL, got '{truncate(output, 200)}'" + ) + elif isinstance(self.output, AIOutput): + try: + matches, error = await output_matches_prompt(output, self.output.prompt) + if not matches: + raise TestCaseFailedError(f"Test case failed: {error}") + except AIError as e: + raise TestCaseFailedError(f"Test case failed: AI error: {str(e)}") + + +# TODO(andreas): need to make inputs history work +@dataclass +class FuzzModel(Task): + context: TaskContext + fixed_inputs: dict + disabled_inputs: list[str] + predict_timeout: int + + async def run(self) -> None: + schemas = schema.get_schemas( + self.context.test_model, train=self.context.is_train() + ) + predict_inputs, _ = await make_predict_inputs( + schemas, + train=self.context.is_train(), + only_required=False, + seed=None, + fixed_inputs=self.fixed_inputs, + disabled_inputs=self.disabled_inputs, + ) + log.v(f"Fuzzing with inputs: {predict_inputs}") + try: + output = await predict( + model=self.context.test_model, + train=self.context.is_train(), + train_destination=self.context.train_destination, + inputs=predict_inputs, + timeout_seconds=self.predict_timeout, + ) + except PredictionTimeoutError: + raise FuzzError("Prediction timed out") + except PredictionFailedError as e: + raise FuzzError(f"Prediction failed: {e}") + if not output: + raise FuzzError("No output") diff --git a/end-to-end-test/test_end_to_end.py b/end-to-end-test/test_end_to_end.py index 695826e..07b1cda 100644 --- a/end-to-end-test/test_end_to_end.py +++ b/end-to-end-test/test_end_to_end.py @@ -11,7 +11,8 @@ from cog_safe_push import log from cog_safe_push.exceptions import * from cog_safe_push.main import cog_safe_push -from cog_safe_push.predict import AIOutput, ExactStringOutput, ExactURLOutput, TestCase +from cog_safe_push.task_context import make_task_context +from cog_safe_push.tasks import AIOutput, ExactStringOutput, ExactURLOutput log.set_verbosity(3) @@ -27,53 +28,58 @@ def test_cog_safe_push(): try: with fixture_dir("base"): cog_safe_push( - model_owner, - model_name, - model_owner, - test_model_name, - "cpu", + make_task_context( + model_owner, model_name, model_owner, test_model_name, "cpu" + ), test_cases=[ - TestCase( - inputs={"text": "world"}, - output=ExactStringOutput("hello world"), + ( + {"text": "world"}, + ExactStringOutput("hello world"), ), - TestCase( - inputs={"text": "world"}, - output=AIOutput("the text hello world"), + ( + {"text": "world"}, + AIOutput("the text hello world"), ), ], ) with fixture_dir("same-schema"): - cog_safe_push(model_owner, model_name, model_owner, test_model_name, "cpu") + cog_safe_push( + make_task_context( + model_owner, model_name, model_owner, test_model_name, "cpu" + ) + ) with fixture_dir("schema-lint-error"): with pytest.raises(SchemaLintError): cog_safe_push( - model_owner, model_name, model_owner, test_model_name, "cpu" + make_task_context( + model_owner, model_name, model_owner, test_model_name, "cpu" + ) ) with fixture_dir("incompatible-schema"): with pytest.raises(IncompatibleSchemaError): cog_safe_push( - model_owner, model_name, model_owner, test_model_name, "cpu" + make_task_context( + model_owner, model_name, model_owner, test_model_name, "cpu" + ) ) with fixture_dir("outputs-dont-match"): with pytest.raises(OutputsDontMatchError): cog_safe_push( - model_owner, model_name, model_owner, test_model_name, "cpu" + make_task_context( + model_owner, model_name, model_owner, test_model_name, "cpu" + ) ) with 
fixture_dir("additive-schema-fuzz-error"): with pytest.raises(FuzzError): cog_safe_push( - model_owner, - model_name, - model_owner, - test_model_name, - "cpu", - fuzz_seconds=120, + make_task_context( + model_owner, model_name, model_owner, test_model_name, "cpu" + ), ) finally: @@ -92,60 +98,61 @@ def test_cog_safe_push_images(): try: with fixture_dir("image-base"): cog_safe_push( - model_owner, - model_name, - model_owner, - test_model_name, - "cpu", - fuzz_seconds=60, + make_task_context( + model_owner, model_name, model_owner, test_model_name, "cpu" + ), test_cases=[ - TestCase( - inputs={ + ( + { "image": "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg", "width": 1024, "height": 639, }, - output=ExactURLOutput( + ExactURLOutput( "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg" ), ), - TestCase( - inputs={ + ( + { "image": "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg", "width": 200, "height": 100, }, - output=AIOutput("An image of a car"), + AIOutput("An image of a car"), ), - TestCase( - inputs={ + ( + { "image": "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg", "width": 200, "height": 100, }, - output=AIOutput("A jpg image"), + AIOutput("A jpg image"), ), - TestCase( - inputs={ + ( + { "image": "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg", "width": 200, "height": 100, }, - output=AIOutput("A image with width 200px and height 100px"), + AIOutput("A image with width 200px and height 100px"), ), - TestCase( - inputs={ + ( + { "image": "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg", "width": 200, "height": 100, }, - output=None, + None, ), ], ) with fixture_dir("image-base"): - cog_safe_push(model_owner, model_name, model_owner, test_model_name, "cpu") + cog_safe_push( + make_task_context( + model_owner, model_name, model_owner, test_model_name, "cpu" + ) + ) finally: delete_model(model_owner, model_name) @@ -162,10 +169,18 @@ def test_cog_safe_push_images_with_seed(): try: with fixture_dir("image-base-seed"): - cog_safe_push(model_owner, model_name, model_owner, test_model_name, "cpu") + cog_safe_push( + make_task_context( + model_owner, model_name, model_owner, test_model_name, "cpu" + ) + ) with fixture_dir("image-base-seed"): - cog_safe_push(model_owner, model_name, model_owner, test_model_name, "cpu") + cog_safe_push( + make_task_context( + model_owner, model_name, model_owner, test_model_name, "cpu" + ) + ) finally: delete_model(model_owner, model_name) @@ -183,27 +198,31 @@ def test_cog_safe_push_train(): try: with fixture_dir("train"): cog_safe_push( - model_owner, - model_name, - model_owner, - test_model_name, - "cpu", - train=True, - train_destination_owner=model_owner, - train_destination_name=test_model_name + "-dest", + make_task_context( + model_owner, + model_name, + model_owner, + test_model_name, + "cpu", + train=True, + train_destination_owner=model_owner, + train_destination_name=test_model_name + "-dest", + ), fuzz_iterations=1, ) with fixture_dir("train"): cog_safe_push( - model_owner, - model_name, - model_owner, - test_model_name, - "cpu", - train=True, - train_destination_owner=model_owner, - train_destination_name=test_model_name + "-dest", + make_task_context( + model_owner, + model_name, + model_owner, + test_model_name, + "cpu", + train=True, + train_destination_owner=model_owner, + train_destination_name=test_model_name + "-dest", + ), fuzz_iterations=1, ) diff --git a/integration-test/test_output_matches_prompt.py 
b/integration-test/test_output_matches_prompt.py index 74a18e9..daa0318 100644 --- a/integration-test/test_output_matches_prompt.py +++ b/integration-test/test_output_matches_prompt.py @@ -1,8 +1,8 @@ -import pytest from pathlib import Path -from cog_safe_push.match_outputs import output_matches_prompt -from cog_safe_push import log +import pytest + +from cog_safe_push.match_outputs import output_matches_prompt # log.set_verbosity(3) @@ -68,22 +68,22 @@ def get_captioned_images( @pytest.mark.parametrize( - "file_url,prompt", + ("file_url", "prompt"), get_captioned_images(positive_images), ids=lambda x: Path(x[0]).name if isinstance(x, tuple) else x, ) -def test_image_output_matches_prompt_positive(file_url: str, prompt: str): +async def test_image_output_matches_prompt_positive(file_url: str, prompt: str): """Test that images in the positive directory match their prompts.""" - matches, message = output_matches_prompt(file_url, prompt) + matches, message = await output_matches_prompt(file_url, prompt) assert matches, f"Image should match prompt '{prompt}'. Error: {message}" @pytest.mark.parametrize( - "file_url,prompt", + ("file_url", "prompt"), get_captioned_images(negative_images), ids=lambda x: Path(x[0]).name if isinstance(x, tuple) else x, ) -def test_image_output_matches_prompt_negative(file_url: str, prompt: str): +async def test_image_output_matches_prompt_negative(file_url: str, prompt: str): """Test that images in the negative directory don't match their prompts.""" - matches, _ = output_matches_prompt(file_url, prompt) + matches, _ = await output_matches_prompt(file_url, prompt) assert not matches, f"Image should not match prompt '{prompt}'" diff --git a/pyrightconfig.json b/pyrightconfig.json index 1657c7a..cd78deb 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -6,7 +6,7 @@ "exclude": [ "**/node_modules", "**/__pycache__", - "integration_test/fixtures", + "end-to-end-test/fixtures", ], "ignore": [ diff --git a/script/lint b/script/lint index 37258e1..0e651cb 100755 --- a/script/lint +++ b/script/lint @@ -1,5 +1,10 @@ #!/bin/bash -eu -ruff check -ruff format --check +if [[ ${1:-} == "--fix" ]]; then + ruff check --fix + ruff format +else + ruff check + ruff format --check +fi pyright diff --git a/test/test_match_outputs.py b/test/test_match_outputs.py index 4064d99..959738d 100644 --- a/test/test_match_outputs.py +++ b/test/test_match_outputs.py @@ -151,23 +151,25 @@ def test_extensions_match(): assert not extensions_match("file1.jpg", "file2.png") -def test_urls_with_different_extensions(): - result, message = outputs_match( +async def test_urls_with_different_extensions(): + result, message = await outputs_match( "http://example.com/file1.jpg", "http://example.com/file2.png", False ) assert not result assert message == "URL extensions don't match" -def test_one_url_one_non_url(): - result, message = outputs_match("http://example.com/image.jpg", "not_a_url", False) +async def test_one_url_one_non_url(): + result, message = await outputs_match( + "http://example.com/image.jpg", "not_a_url", False + ) assert not result assert message == "Only one output is a URL" @patch("cog_safe_push.match_outputs.download") @patch("PIL.Image.open") -def test_images_with_different_sizes(mock_image_open, mock_download): +async def test_images_with_different_sizes(mock_image_open, mock_download): assert mock_download mock_image1 = MagicMock() mock_image2 = MagicMock() @@ -175,7 +177,7 @@ def test_images_with_different_sizes(mock_image_open, mock_download): mock_image2.size = (200, 200) 
mock_image_open.side_effect = [mock_image1, mock_image2] - result, message = outputs_match( + result, message = await outputs_match( "http://example.com/image1.jpg", "http://example.com/image2.jpg", False ) assert not result @@ -183,8 +185,8 @@ def test_images_with_different_sizes(mock_image_open, mock_download): @patch("cog_safe_push.log.warning") -def test_unknown_url_format(mock_warning): - result, _ = outputs_match( +async def test_unknown_url_format(mock_warning): + result, _ = await outputs_match( "http://example.com/unknown.xyz", "http://example.com/unknown.xyz", False ) assert result @@ -194,23 +196,23 @@ def test_unknown_url_format(mock_warning): @patch("cog_safe_push.log.warning") -def test_unknown_output_type(mock_warning): +async def test_unknown_output_type(mock_warning): class UnknownType: pass - result, _ = outputs_match(UnknownType(), UnknownType(), False) + result, _ = await outputs_match(UnknownType(), UnknownType(), False) assert result mock_warning.assert_called_with(f"Unknown type: {type(UnknownType())}") -def test_large_structure_performance(): +async def test_large_structure_performance(): import time large_struct1 = {"key" + str(i): i for i in range(10000)} large_struct2 = {"key" + str(i): i for i in range(10000)} start_time = time.time() - result, _ = outputs_match(large_struct1, large_struct2, False) + result, _ = await outputs_match(large_struct1, large_struct2, False) end_time = time.time() assert result diff --git a/test/test_predict.py b/test/test_predict.py index 2741dd1..b7bedc9 100644 --- a/test/test_predict.py +++ b/test/test_predict.py @@ -31,10 +31,10 @@ def sample_schemas(): @patch("cog_safe_push.predict.ai.json_object") -def test_make_predict_inputs_basic(mock_json_object, sample_schemas): +async def test_make_predict_inputs_basic(mock_json_object, sample_schemas): mock_json_object.return_value = {"text": "hello", "number": 42, "choice": "A"} - inputs, is_deterministic = make_predict_inputs( + inputs, is_deterministic = await make_predict_inputs( sample_schemas, train=False, only_required=True, @@ -47,11 +47,11 @@ def test_make_predict_inputs_basic(mock_json_object, sample_schemas): assert not is_deterministic -def test_make_predict_inputs_with_seed(sample_schemas): +async def test_make_predict_inputs_with_seed(sample_schemas): with patch("cog_safe_push.predict.ai.json_object") as mock_json_object: mock_json_object.return_value = {"text": "hello", "number": 42, "choice": "A"} - inputs, is_deterministic = make_predict_inputs( + inputs, is_deterministic = await make_predict_inputs( sample_schemas, train=False, only_required=True, @@ -64,11 +64,11 @@ def test_make_predict_inputs_with_seed(sample_schemas): assert is_deterministic -def test_make_predict_inputs_with_fixed_inputs(sample_schemas): +async def test_make_predict_inputs_with_fixed_inputs(sample_schemas): with patch("cog_safe_push.predict.ai.json_object") as mock_json_object: mock_json_object.return_value = {"text": "hello", "number": 42, "choice": "A"} - inputs, _ = make_predict_inputs( + inputs, _ = await make_predict_inputs( sample_schemas, train=False, only_required=True, @@ -80,7 +80,7 @@ def test_make_predict_inputs_with_fixed_inputs(sample_schemas): assert inputs["text"] == "fixed" -def test_make_predict_inputs_with_disabled_inputs(sample_schemas): +async def test_make_predict_inputs_with_disabled_inputs(sample_schemas): with patch("cog_safe_push.predict.ai.json_object") as mock_json_object: mock_json_object.return_value = { "text": "hello", @@ -89,7 +89,7 @@ def 
test_make_predict_inputs_with_disabled_inputs(sample_schemas): "optional": True, } - inputs, _ = make_predict_inputs( + inputs, _ = await make_predict_inputs( sample_schemas, train=False, only_required=False, @@ -101,7 +101,7 @@ def test_make_predict_inputs_with_disabled_inputs(sample_schemas): assert "optional" not in inputs -def test_make_predict_inputs_with_inputs_history(sample_schemas): +async def test_make_predict_inputs_with_inputs_history(sample_schemas): with patch("cog_safe_push.predict.ai.json_object") as mock_json_object: mock_json_object.return_value = {"text": "new", "number": 100, "choice": "C"} @@ -110,7 +110,7 @@ def test_make_predict_inputs_with_inputs_history(sample_schemas): {"text": "older", "number": 21, "choice": "B"}, ] - inputs, _ = make_predict_inputs( + inputs, _ = await make_predict_inputs( sample_schemas, train=False, only_required=True, @@ -124,14 +124,14 @@ def test_make_predict_inputs_with_inputs_history(sample_schemas): assert inputs != inputs_history[1] -def test_make_predict_inputs_ai_error(sample_schemas): +async def test_make_predict_inputs_ai_error(sample_schemas): with patch("cog_safe_push.predict.ai.json_object") as mock_json_object: mock_json_object.side_effect = [ {"text": "hello"}, # Missing required fields {"text": "hello", "number": 42, "choice": "A"}, # Correct input ] - inputs, _ = make_predict_inputs( + inputs, _ = await make_predict_inputs( sample_schemas, train=False, only_required=True, @@ -144,14 +144,14 @@ def test_make_predict_inputs_ai_error(sample_schemas): assert mock_json_object.call_count == 2 -def test_make_predict_inputs_max_attempts_reached(sample_schemas): +async def test_make_predict_inputs_max_attempts_reached(sample_schemas): with patch("cog_safe_push.predict.ai.json_object") as mock_json_object: mock_json_object.return_value = { "text": "hello" } # Always missing required fields with pytest.raises(AIError): - make_predict_inputs( + await make_predict_inputs( sample_schemas, train=False, only_required=True, From b385806f2b5c4fb835fea3f70cba401d32c0a832 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Sun, 1 Dec 2024 22:01:08 +0100 Subject: [PATCH 02/10] Separate task to generate fuzz inputs + run all tests in ci --- .github/workflows/ci.yaml | 40 ++++++++++ cog_safe_push/ai.py | 61 ++++++++------- cog_safe_push/config.py | 1 - cog_safe_push/main.py | 30 +++++--- cog_safe_push/predict.py | 12 ++- cog_safe_push/tasks.py | 44 +++++++---- end-to-end-test/test_end_to_end.py | 24 +++--- integration-test/pytest.ini | 3 + .../test_output_matches_prompt.py | 5 ++ requirements-test.txt | 2 + script/end-to-end-test | 3 + script/integration-test | 2 +- script/unit-test | 2 +- setup.py | 2 +- test/pytest.ini | 3 + test/test_main.py | 18 ++--- test/test_match_outputs.py | 74 +++++++++---------- 17 files changed, 205 insertions(+), 121 deletions(-) create mode 100644 integration-test/pytest.ini create mode 100644 script/end-to-end-test create mode 100644 test/pytest.ini diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index eb51912..cfccacb 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -46,3 +46,43 @@ jobs: - name: Run pytest run: | ./script/unit-test + + integration-test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + pip install -r requirements-test.txt + pip install . 
+ + - name: Run pytest + run: | + ./script/integration-test + + end-to-end-test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + pip install -r requirements-test.txt + pip install . + + - name: Run pytest + run: | + ./script/end-to-end-test diff --git a/cog_safe_push/ai.py b/cog_safe_push/ai.py index 71a333c..851ceb7 100644 --- a/cog_safe_push/ai.py +++ b/cog_safe_push/ai.py @@ -76,34 +76,39 @@ async def call( model = "claude-3-5-sonnet-20241022" client = anthropic.AsyncAnthropic(api_key=api_key) - if files: - content = create_content_list(files) - - if include_file_metadata: - prompt += "\n\nMetadata for the attached file(s):\n" - for path in files: - prompt += "* " + file_info(path) + "\n" - - content.append({"type": "text", "text": prompt}) - - log.vvv(f"Claude prompt with {len(files)} files: {prompt}") - else: - content = prompt - log.vvv(f"Claude prompt: {prompt}") - - messages: list[anthropic.types.MessageParam] = [ - {"role": "user", "content": content} - ] - - response = await client.messages.create( - model=model, - messages=messages, - system=system_prompt, - max_tokens=4096, - stream=False, - temperature=1.0, - ) - content = cast(anthropic.types.TextBlock, response.content[0]) + try: + if files: + content = create_content_list(files) + + if include_file_metadata: + prompt += "\n\nMetadata for the attached file(s):\n" + for path in files: + prompt += "* " + file_info(path) + "\n" + + content.append({"type": "text", "text": prompt}) + + log.vvv(f"Claude prompt with {len(files)} files: {prompt}") + else: + content = prompt + log.vvv(f"Claude prompt: {prompt}") + + messages: list[anthropic.types.MessageParam] = [ + {"role": "user", "content": content} + ] + + response = await client.messages.create( + model=model, + messages=messages, + system=system_prompt, + max_tokens=4096, + stream=False, + temperature=1.0, + ) + content = cast(anthropic.types.TextBlock, response.content[0]) + + finally: + await client.close() + output = content.text log.vvv(f"Claude response: {output}") return output diff --git a/cog_safe_push/config.py b/cog_safe_push/config.py index 8eae13e..3e1ca05 100644 --- a/cog_safe_push/config.py +++ b/cog_safe_push/config.py @@ -36,7 +36,6 @@ class FuzzConfig(BaseModel): fixed_inputs: dict[str, InputScalar] = {} disabled_inputs: list[str] = [] - duration: int = DEFAULT_FUZZ_DURATION iterations: int = 10 diff --git a/cog_safe_push/main.py b/cog_safe_push/main.py index 68b00c4..4dfc282 100644 --- a/cog_safe_push/main.py +++ b/cog_safe_push/main.py @@ -2,6 +2,7 @@ import asyncio import re import sys +from asyncio import Queue from pathlib import Path from typing import Any @@ -11,7 +12,6 @@ from . import cog, lint, log, schema from .config import ( - DEFAULT_FUZZ_DURATION, DEFAULT_PREDICT_TIMEOUT, Config, FuzzConfig, @@ -28,6 +28,7 @@ ExactURLOutput, ExpectedOutput, FuzzModel, + MakeFuzzInputs, RunTestCase, Task, ) @@ -106,15 +107,9 @@ def parse_args_and_config() -> tuple[Config, bool]: type=parse_fuzz_disabled_inputs, default=argparse.SUPPRESS, ) - parser.add_argument( - "--fuzz-duration", - help=f"Number of seconds to run fuzzing. Set to 0 for no fuzzing. Default: {DEFAULT_FUZZ_DURATION}", - type=int, - default=argparse.SUPPRESS, - ) parser.add_argument( "--fuzz-iterations", - help="Maximum number of iterations to run fuzzing. 
Leave blank to run for the full --fuzz-seconds", + help="Maximum number of iterations to run fuzzing.", type=int, default=argparse.SUPPRESS, ) @@ -166,7 +161,6 @@ def parse_args_and_config() -> tuple[Config, bool]: config.predict_override("predict_timeout", args, "predict_timeout") config.predict_fuzz_override("fixed_inputs", args, "fuzz_fixed_inputs") config.predict_fuzz_override("disabled_inputs", args, "fuzz_disabled_inputs") - config.predict_fuzz_override("duration", args, "fuzz_duration") config.predict_fuzz_override("iterations", args, "fuzz_iterations") if not config.test_model: @@ -325,15 +319,25 @@ def cog_safe_push( ) ) - for _ in range(fuzz_iterations): + if fuzz_iterations > 0: + fuzz_inputs_queue = Queue(maxsize=fuzz_iterations) tasks.append( - FuzzModel( + MakeFuzzInputs( context=task_context, + inputs_queue=fuzz_inputs_queue, + num_inputs=fuzz_iterations, fixed_inputs=fuzz_fixed_inputs, disabled_inputs=fuzz_disabled_inputs, - predict_timeout=predict_timeout, ) ) + for _ in range(fuzz_iterations): + tasks.append( + FuzzModel( + context=task_context, + inputs_queue=fuzz_inputs_queue, + predict_timeout=predict_timeout, + ) + ) asyncio.run(run_tasks(tasks, parallel=parallel)) @@ -353,7 +357,9 @@ async def run_tasks(tasks: list[Task], parallel: int) -> None: async def run_with_semaphore(task: Task) -> None: async with semaphore: try: + print(f"starting task {type(task)}") await task.run() + print(f"finished task {type(task)}") except Exception as e: errors.append(e) diff --git a/cog_safe_push/predict.py b/cog_safe_push/predict.py index b8ffb10..a0453bf 100644 --- a/cog_safe_push/predict.py +++ b/cog_safe_push/predict.py @@ -157,6 +157,8 @@ async def make_predict_inputs( * https://storage.googleapis.com/cog-safe-push-public/de-experiment-german-word.ogg * https://storage.googleapis.com/cog-safe-push-public/de-ionendosis-german-word.ogg +If the schema has default values for some of the inputs, feel free to either use the defaults or come up with new values. + """ ) @@ -178,7 +180,7 @@ async def make_predict_inputs( inputs_history_str = "\n".join(["* " + json.dumps(i) for i in inputs_history]) prompt += f""" -Return a new combination of inputs that you haven't used before. You have previously used these inputs: +Return a new combination of inputs that you haven't used before, ideally that's quite diverse from inputs you've used before. 
You have previously used these inputs: {inputs_history_str}""" inputs = await ai.json_object(prompt) @@ -236,13 +238,17 @@ async def predict( ) else: try: - prediction = await replicate.predictions.async_create( + # await async_create doesn't seem to work here, throws + # RuntimeError: Event loop is closed + # But since we're async sleeping this should only block + # a very short time + prediction = replicate.predictions.create( version=model.versions.list()[0].id, input=inputs ) except ReplicateError as e: if e.status == 404: # Assume it's an official model - prediction = await replicate.predictions.async_create( + prediction = replicate.predictions.create( model=model, input=inputs ) else: diff --git a/cog_safe_push/tasks.py b/cog_safe_push/tasks.py index d34ecb6..9dbb84e 100644 --- a/cog_safe_push/tasks.py +++ b/cog_safe_push/tasks.py @@ -1,3 +1,5 @@ +import asyncio +from asyncio import Queue from dataclasses import dataclass from typing import Any, Protocol @@ -146,33 +148,49 @@ async def run(self) -> None: raise TestCaseFailedError(f"Test case failed: AI error: {str(e)}") -# TODO(andreas): need to make inputs history work @dataclass -class FuzzModel(Task): +class MakeFuzzInputs(Task): context: TaskContext + num_inputs: int + inputs_queue: Queue[dict[str, Any]] fixed_inputs: dict disabled_inputs: list[str] - predict_timeout: int async def run(self) -> None: schemas = schema.get_schemas( self.context.test_model, train=self.context.is_train() ) - predict_inputs, _ = await make_predict_inputs( - schemas, - train=self.context.is_train(), - only_required=False, - seed=None, - fixed_inputs=self.fixed_inputs, - disabled_inputs=self.disabled_inputs, - ) - log.v(f"Fuzzing with inputs: {predict_inputs}") + inputs_history = [] + for _ in range(self.num_inputs): + inputs, _ = await make_predict_inputs( + schemas, + train=self.context.is_train(), + only_required=False, + seed=None, + fixed_inputs=self.fixed_inputs, + disabled_inputs=self.disabled_inputs, + inputs_history=inputs_history, + ) + await self.inputs_queue.put(inputs) + inputs_history.append(inputs) + + +@dataclass +class FuzzModel(Task): + context: TaskContext + inputs_queue: Queue[dict[str, Any]] + predict_timeout: int + + async def run(self) -> None: + inputs = await asyncio.wait_for(self.inputs_queue.get(), timeout=60) + + log.v(f"Fuzzing with inputs: {inputs}") try: output = await predict( model=self.context.test_model, train=self.context.is_train(), train_destination=self.context.train_destination, - inputs=predict_inputs, + inputs=inputs, timeout_seconds=self.predict_timeout, ) except PredictionTimeoutError: diff --git a/end-to-end-test/test_end_to_end.py b/end-to-end-test/test_end_to_end.py index 07b1cda..ebb2008 100644 --- a/end-to-end-test/test_end_to_end.py +++ b/end-to-end-test/test_end_to_end.py @@ -1,4 +1,4 @@ -import datetime +import uuid import json import os from contextlib import contextmanager, suppress @@ -14,14 +14,12 @@ from cog_safe_push.task_context import make_task_context from cog_safe_push.tasks import AIOutput, ExactStringOutput, ExactURLOutput -log.set_verbosity(3) +log.set_verbosity(2) def test_cog_safe_push(): model_owner = "replicate-internal" - model_name = "test-cog-safe-push-" + datetime.datetime.now().strftime( - "%y%m%d-%H%M%S" - ) + model_name = generate_model_name() test_model_name = model_name + "-test" create_model(model_owner, model_name) @@ -89,9 +87,7 @@ def test_cog_safe_push(): def test_cog_safe_push_images(): model_owner = "replicate-internal" - model_name = "test-cog-safe-push-" + 
datetime.datetime.now().strftime( - "%y%m%d-%H%M%S" - ) + model_name = generate_model_name() test_model_name = model_name + "-test" create_model(model_owner, model_name) @@ -161,9 +157,7 @@ def test_cog_safe_push_images(): def test_cog_safe_push_images_with_seed(): model_owner = "replicate-internal" - model_name = "test-cog-safe-push-" + datetime.datetime.now().strftime( - "%y%m%d-%H%M%S" - ) + model_name = generate_model_name() test_model_name = model_name + "-test" create_model(model_owner, model_name) @@ -189,9 +183,7 @@ def test_cog_safe_push_images_with_seed(): def test_cog_safe_push_train(): model_owner = "replicate-internal" - model_name = "test-cog-safe-push-" + datetime.datetime.now().strftime( - "%y%m%d-%H%M%S" - ) + model_name = generate_model_name() test_model_name = model_name + "-test" create_model(model_owner, model_name) @@ -232,6 +224,10 @@ def test_cog_safe_push_train(): delete_model(model_owner, test_model_name + "-dest") +def generate_model_name(): + return "test-cog-safe-push-" + uuid.uuid4().hex + + def create_model(model_owner, model_name): replicate.models.create( owner=model_owner, diff --git a/integration-test/pytest.ini b/integration-test/pytest.ini new file mode 100644 index 0000000..d4edf3a --- /dev/null +++ b/integration-test/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +asyncio_mode=auto +asyncio_default_fixture_loop_scope="function" \ No newline at end of file diff --git a/integration-test/test_output_matches_prompt.py b/integration-test/test_output_matches_prompt.py index daa0318..749124f 100644 --- a/integration-test/test_output_matches_prompt.py +++ b/integration-test/test_output_matches_prompt.py @@ -29,6 +29,11 @@ "an anime illustration", "a lake", ], + "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg": [ + "An image of a car", + "A jpg image", + "A image with width 1024px and height 639px", + ], } negative_images = { diff --git a/requirements-test.txt b/requirements-test.txt index 0cbf6bd..da399e6 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,3 +1,5 @@ pytest +pytest-asyncio +pytest-xdist ruff pyright diff --git a/script/end-to-end-test b/script/end-to-end-test new file mode 100644 index 0000000..ef86228 --- /dev/null +++ b/script/end-to-end-test @@ -0,0 +1,3 @@ +#!/bin/bash -eu + +pytest -s end-to-end-test/ diff --git a/script/integration-test b/script/integration-test index e88d525..c90b239 100755 --- a/script/integration-test +++ b/script/integration-test @@ -1,3 +1,3 @@ #!/bin/bash -eu -pytest -s integration_test/ +pytest -n4 -s integration-test/ diff --git a/script/unit-test b/script/unit-test index 06f347a..0d5ea2f 100755 --- a/script/unit-test +++ b/script/unit-test @@ -1,3 +1,3 @@ #!/bin/bash -eu -pytest test/ \ No newline at end of file +pytest test/ diff --git a/setup.py b/setup.py index dcc7912..842ba78 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ version="0.0.1", packages=find_packages(), install_requires=[ - "replicate>=0.31.0,<1", + "replicate>=1.0.3,<2", "anthropic>=0.21.3,<1", "pillow>=10.0.0", "ruff>=0.6.1,<1", diff --git a/test/pytest.ini b/test/pytest.ini new file mode 100644 index 0000000..d4edf3a --- /dev/null +++ b/test/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +asyncio_mode=auto +asyncio_default_fixture_loop_scope="function" \ No newline at end of file diff --git a/test/test_main.py b/test/test_main.py index 6c6d25a..231511c 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -66,8 +66,8 @@ def test_parse_args_fuzz_options(monkeypatch): "key1=value1;key2=42", "--fuzz-disabled-inputs", 
"key3;key4", - "--fuzz-duration", - "120", + "--fuzz-iterations", + "5", ], ) config, _ = parse_args_and_config() @@ -75,7 +75,7 @@ def test_parse_args_fuzz_options(monkeypatch): assert config.predict.fuzz is not None assert config.predict.fuzz.fixed_inputs == {"key1": "value1", "key2": 42} assert config.predict.fuzz.disabled_inputs == ["key3", "key4"] - assert config.predict.fuzz.duration == 120 + assert config.predict.fuzz.iterations == 5 def test_parse_args_test_case(monkeypatch): @@ -139,7 +139,7 @@ def test_parse_config_file(tmp_path, monkeypatch): key1: value1 disabled_inputs: - key2 - duration: 150 + iterations: 15 """ config_file = tmp_path / "cog-safe-push.yaml" config_file.write_text(config_yaml) @@ -159,7 +159,7 @@ def test_parse_config_file(tmp_path, monkeypatch): assert config.predict.test_cases[0].exact_string == "expected output" assert config.predict.fuzz.fixed_inputs == {"key1": "value1"} assert config.predict.fuzz.disabled_inputs == ["key2"] - assert config.predict.fuzz.duration == 150 + assert config.predict.fuzz.iterations == 15 def test_config_override_with_args(tmp_path, monkeypatch): @@ -261,8 +261,7 @@ def test_parse_config_with_train(tmp_path, monkeypatch): input1: value1 match_prompt: An AI generated output fuzz: - duration: 300 - iterations: 10 + iterations: 8 """ config_file = tmp_path / "cog-safe-push.yaml" config_file.write_text(config_yaml) @@ -280,8 +279,7 @@ def test_parse_config_with_train(tmp_path, monkeypatch): assert len(config.train.test_cases) == 1 assert config.train.test_cases[0].inputs == {"input1": "value1"} assert config.train.test_cases[0].match_prompt == "An AI generated output" - assert config.train.fuzz.duration == 300 - assert config.train.fuzz.iterations == 10 + assert config.train.fuzz.iterations == 8 def test_parse_args_with_default_config(tmp_path, monkeypatch): @@ -376,7 +374,7 @@ def test_parse_config_missing_fuzz_section(tmp_path, monkeypatch): config_file.write_text(config_yaml) monkeypatch.setattr( "sys.argv", - ["cog-safe-push", "--config", str(config_file), "--fuzz-duration", "600"], + ["cog-safe-push", "--config", str(config_file), "--fuzz-iterations", "20"], ) with pytest.raises(ArgumentError, match="missing a predict.fuzz section"): diff --git a/test/test_match_outputs.py b/test/test_match_outputs.py index 959738d..7c41143 100644 --- a/test/test_match_outputs.py +++ b/test/test_match_outputs.py @@ -12,18 +12,18 @@ ) -def test_identical_strings(): - assert outputs_match("hello", "hello", True) == (True, "") +async def test_identical_strings(): + assert await outputs_match("hello", "hello", True) == (True, "") -def test_different_strings_deterministic(): - assert outputs_match("hello", "world", True) == (False, "Strings aren't the same") +async def test_different_strings_deterministic(): + assert await outputs_match("hello", "world", True) == (False, "Strings aren't the same") @patch("cog_safe_push.predict.ai.boolean") -def test_different_strings_non_deterministic(mock_ai_boolean): +async def test_different_strings_non_deterministic(mock_ai_boolean): mock_ai_boolean.return_value = True - assert outputs_match("The quick brown fox", "A fast auburn canine", False) == ( + assert await outputs_match("The quick brown fox", "A fast auburn canine", False) == ( True, "", ) @@ -31,92 +31,92 @@ def test_different_strings_non_deterministic(mock_ai_boolean): mock_ai_boolean.reset_mock() mock_ai_boolean.return_value = False - assert outputs_match( + assert await outputs_match( "The quick brown fox", "Something completely different", False ) == 
(False, "Strings aren't similar") mock_ai_boolean.assert_called_once() -def test_identical_booleans(): - assert outputs_match(True, True, True) == (True, "") +async def test_identical_booleans(): + assert await outputs_match(True, True, True) == (True, "") -def test_different_booleans(): - assert outputs_match(True, False, True) == (False, "Booleans aren't identical") +async def test_different_booleans(): + assert await outputs_match(True, False, True) == (False, "Booleans aren't identical") -def test_identical_integers(): - assert outputs_match(42, 42, True) == (True, "") +async def test_identical_integers(): + assert await outputs_match(42, 42, True) == (True, "") -def test_different_integers(): - assert outputs_match(42, 43, True) == (False, "Integers aren't identical") +async def test_different_integers(): + assert await outputs_match(42, 43, True) == (False, "Integers aren't identical") -def test_close_floats(): - assert outputs_match(3.14, 3.14001, True) == (True, "") +async def test_close_floats(): + assert await outputs_match(3.14, 3.14001, True) == (True, "") -def test_different_floats(): - assert outputs_match(3.14, 3.25, True) == (False, "Floats aren't identical") +async def test_different_floats(): + assert await outputs_match(3.14, 3.25, True) == (False, "Floats aren't identical") -def test_identical_dicts(): +async def test_identical_dicts(): dict1 = {"a": 1, "b": "hello"} dict2 = {"a": 1, "b": "hello"} - assert outputs_match(dict1, dict2, True) == (True, "") + assert await outputs_match(dict1, dict2, True) == (True, "") -def test_different_dict_values(): +async def test_different_dict_values(): dict1 = {"a": 1, "b": "hello"} dict2 = {"a": 1, "b": "world"} - assert outputs_match(dict1, dict2, True) == (False, "In b: Strings aren't the same") + assert await outputs_match(dict1, dict2, True) == (False, "In b: Strings aren't the same") -def test_different_dict_keys(): +async def test_different_dict_keys(): dict1 = {"a": 1, "b": "hello"} dict2 = {"a": 1, "c": "hello"} - assert outputs_match(dict1, dict2, True) == (False, "Dict keys don't match") + assert await outputs_match(dict1, dict2, True) == (False, "Dict keys don't match") -def test_identical_lists(): +async def test_identical_lists(): list1 = [1, "hello", True] list2 = [1, "hello", True] - assert outputs_match(list1, list2, True) == (True, "") + assert await outputs_match(list1, list2, True) == (True, "") -def test_different_list_values(): +async def test_different_list_values(): list1 = [1, "hello", True] list2 = [1, "world", True] - assert outputs_match(list1, list2, True) == ( + assert await outputs_match(list1, list2, True) == ( False, "At index 1: Strings aren't the same", ) -def test_different_list_lengths(): +async def test_different_list_lengths(): list1 = [1, 2, 3] list2 = [1, 2] - assert outputs_match(list1, list2, True) == (False, "List lengths don't match") + assert await outputs_match(list1, list2, True) == (False, "List lengths don't match") -def test_nested_structures(): +async def test_nested_structures(): struct1 = {"a": [1, {"b": "hello"}], "c": True} struct2 = {"a": [1, {"b": "hello"}], "c": True} - assert outputs_match(struct1, struct2, True) == (True, "") + assert await outputs_match(struct1, struct2, True) == (True, "") -def test_different_nested_structures(): +async def test_different_nested_structures(): struct1 = {"a": [1, {"b": "hello"}], "c": True} struct2 = {"a": [1, {"b": "world"}], "c": True} - assert outputs_match(struct1, struct2, True) == ( + assert await outputs_match(struct1, struct2, 
True) == ( False, "In a: At index 1: In b: Strings aren't the same", ) -def test_different_types(): - assert outputs_match("42", 42, True) == ( +async def test_different_types(): + assert await outputs_match("42", 42, True) == ( False, "The types of the outputs don't match", ) From 4b1c7f7fa254385fff4e17048f8f711069b35c32 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Sun, 1 Dec 2024 22:04:57 +0100 Subject: [PATCH 03/10] lint --- cog_safe_push/main.py | 4 ++-- cog_safe_push/predict.py | 4 +--- end-to-end-test/test_end_to_end.py | 2 +- test/test_match_outputs.py | 24 +++++++++++++++++++----- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/cog_safe_push/main.py b/cog_safe_push/main.py index 4dfc282..bd8fa39 100644 --- a/cog_safe_push/main.py +++ b/cog_safe_push/main.py @@ -190,7 +190,7 @@ def run_config(config: Config, no_push: bool): fuzz = config.train.fuzz else: fuzz = FuzzConfig( - fixed_inputs={}, disabled_inputs=[], duration=0, iterations=0 + fixed_inputs={}, disabled_inputs=[], iterations=0 ) task_context = make_task_context( model_owner=model_owner, @@ -222,7 +222,7 @@ def run_config(config: Config, no_push: bool): fuzz = config.predict.fuzz else: fuzz = FuzzConfig( - fixed_inputs={}, disabled_inputs=[], duration=0, iterations=0 + fixed_inputs={}, disabled_inputs=[], iterations=0 ) if task_context is None: # has not been created in the training block above task_context = make_task_context( diff --git a/cog_safe_push/predict.py b/cog_safe_push/predict.py index a0453bf..9c2fd6c 100644 --- a/cog_safe_push/predict.py +++ b/cog_safe_push/predict.py @@ -248,9 +248,7 @@ async def predict( except ReplicateError as e: if e.status == 404: # Assume it's an official model - prediction = replicate.predictions.create( - model=model, input=inputs - ) + prediction = replicate.predictions.create(model=model, input=inputs) else: raise diff --git a/end-to-end-test/test_end_to_end.py b/end-to-end-test/test_end_to_end.py index ebb2008..a8037f2 100644 --- a/end-to-end-test/test_end_to_end.py +++ b/end-to-end-test/test_end_to_end.py @@ -1,6 +1,6 @@ -import uuid import json import os +import uuid from contextlib import contextmanager, suppress from pathlib import Path diff --git a/test/test_match_outputs.py b/test/test_match_outputs.py index 7c41143..4d667e3 100644 --- a/test/test_match_outputs.py +++ b/test/test_match_outputs.py @@ -17,13 +17,18 @@ async def test_identical_strings(): async def test_different_strings_deterministic(): - assert await outputs_match("hello", "world", True) == (False, "Strings aren't the same") + assert await outputs_match("hello", "world", True) == ( + False, + "Strings aren't the same", + ) @patch("cog_safe_push.predict.ai.boolean") async def test_different_strings_non_deterministic(mock_ai_boolean): mock_ai_boolean.return_value = True - assert await outputs_match("The quick brown fox", "A fast auburn canine", False) == ( + assert await outputs_match( + "The quick brown fox", "A fast auburn canine", False + ) == ( True, "", ) @@ -42,7 +47,10 @@ async def test_identical_booleans(): async def test_different_booleans(): - assert await outputs_match(True, False, True) == (False, "Booleans aren't identical") + assert await outputs_match(True, False, True) == ( + False, + "Booleans aren't identical", + ) async def test_identical_integers(): @@ -70,7 +78,10 @@ async def test_identical_dicts(): async def test_different_dict_values(): dict1 = {"a": 1, "b": "hello"} dict2 = {"a": 1, "b": "world"} - assert await outputs_match(dict1, dict2, True) == (False, "In b: 
Strings aren't the same") + assert await outputs_match(dict1, dict2, True) == ( + False, + "In b: Strings aren't the same", + ) async def test_different_dict_keys(): @@ -97,7 +108,10 @@ async def test_different_list_values(): async def test_different_list_lengths(): list1 = [1, 2, 3] list2 = [1, 2] - assert await outputs_match(list1, list2, True) == (False, "List lengths don't match") + assert await outputs_match(list1, list2, True) == ( + False, + "List lengths don't match", + ) async def test_nested_structures(): From d8b2ba4f7bcdcc3210fab59e7b09bf0d9476973c Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Sun, 1 Dec 2024 22:06:51 +0100 Subject: [PATCH 04/10] Add secrets to test --- .github/workflows/ci.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index cfccacb..ac30fb3 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -64,6 +64,8 @@ jobs: pip install . - name: Run pytest + env: + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} run: | ./script/integration-test @@ -84,5 +86,8 @@ jobs: pip install . - name: Run pytest + env: + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }} run: | ./script/end-to-end-test From 2a374a4a6f8009d4c179e2bce9a9e68449498908 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Sun, 1 Dec 2024 22:07:51 +0100 Subject: [PATCH 05/10] lint --- cog_safe_push/main.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/cog_safe_push/main.py b/cog_safe_push/main.py index bd8fa39..fd0b837 100644 --- a/cog_safe_push/main.py +++ b/cog_safe_push/main.py @@ -189,9 +189,7 @@ def run_config(config: Config, no_push: bool): if config.train.fuzz: fuzz = config.train.fuzz else: - fuzz = FuzzConfig( - fixed_inputs={}, disabled_inputs=[], iterations=0 - ) + fuzz = FuzzConfig(fixed_inputs={}, disabled_inputs=[], iterations=0) task_context = make_task_context( model_owner=model_owner, model_name=model_name, @@ -221,9 +219,7 @@ def run_config(config: Config, no_push: bool): if config.predict.fuzz: fuzz = config.predict.fuzz else: - fuzz = FuzzConfig( - fixed_inputs={}, disabled_inputs=[], iterations=0 - ) + fuzz = FuzzConfig(fixed_inputs={}, disabled_inputs=[], iterations=0) if task_context is None: # has not been created in the training block above task_context = make_task_context( model_owner=model_owner, From 49ebcf9963e15b42f5e8449c5fcf531ddf81555d Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Sun, 1 Dec 2024 22:08:20 +0100 Subject: [PATCH 06/10] make end-to-end-test executable --- script/end-to-end-test | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 script/end-to-end-test diff --git a/script/end-to-end-test b/script/end-to-end-test old mode 100644 new mode 100755 From c86d2cb4ac5c3b025e67d191deb7aef2a32779fe Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Sun, 1 Dec 2024 22:08:58 +0100 Subject: [PATCH 07/10] run end-to-end test in parallel --- script/end-to-end-test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/script/end-to-end-test b/script/end-to-end-test index ef86228..bdcca0c 100755 --- a/script/end-to-end-test +++ b/script/end-to-end-test @@ -1,3 +1,3 @@ #!/bin/bash -eu -pytest -s end-to-end-test/ +pytest -n4 -s -v end-to-end-test/ From 8c3685bd9ef269427bae49fbd5b73e8a3d5a60ae Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Sun, 1 Dec 2024 22:12:10 +0100 Subject: [PATCH 08/10] Install cog in CI --- .github/workflows/ci.yaml | 9 
+++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ac30fb3..e8bebda 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -85,6 +85,15 @@ jobs: pip install -r requirements-test.txt pip install . + - name: Install Cog + run: | + sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)" + sudo chmod +x /usr/local/bin/cog + + - name: cog login + run: | + echo ${{ secrets.COG_TOKEN }} | cog login --token-stdin + - name: Run pytest env: ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} From 5af3bcea41d16bcaf0bf028e579a1036f3e24193 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Sun, 1 Dec 2024 22:22:55 +0100 Subject: [PATCH 09/10] don't fail e2e test if delete causes httpx.RemoteProtocolError --- end-to-end-test/test_end_to_end.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/end-to-end-test/test_end_to_end.py b/end-to-end-test/test_end_to_end.py index a8037f2..a0ae233 100644 --- a/end-to-end-test/test_end_to_end.py +++ b/end-to-end-test/test_end_to_end.py @@ -4,6 +4,7 @@ from contextlib import contextmanager, suppress from pathlib import Path +import httpx import pytest import replicate from replicate.exceptions import ReplicateException @@ -244,16 +245,17 @@ def delete_model(model_owner, model_name): # model likely doesn't exist return - for version in model.versions.list(): - print(f"Deleting version {version.id}") + with suppress(httpx.RemoteProtocolError): + for version in model.versions.list(): + print(f"Deleting version {version.id}") + with suppress(json.JSONDecodeError): + # bug in replicate-python causes delete to throw JSONDecodeError + model.versions.delete(version.id) + + print(f"Deleting model {model_owner}/{model_name}") with suppress(json.JSONDecodeError): # bug in replicate-python causes delete to throw JSONDecodeError - model.versions.delete(version.id) - - print(f"Deleting model {model_owner}/{model_name}") - with suppress(json.JSONDecodeError): - # bug in replicate-python causes delete to throw JSONDecodeError - replicate.models.delete(model_owner, model_name) + replicate.models.delete(model_owner, model_name) @contextmanager From 453c2838cbb35555ae4fef9f1866778c7bb04ba7 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Sun, 1 Dec 2024 22:24:46 +0100 Subject: [PATCH 10/10] run e2e test on beefier ci instance --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e8bebda..6903a82 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -70,7 +70,7 @@ jobs: ./script/integration-test end-to-end-test: - runs-on: ubuntu-latest + runs-on: ubuntu-latest-4-cores steps: - uses: actions/checkout@v3
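
For readers skimming the series, the concurrency shape these patches converge on can be reduced to a few lines: one producer task feeds an asyncio.Queue with generated inputs, a fixed number of consumer tasks drain it (each waiting up to 60 seconds for an item), and every task runs under a semaphore that caps parallelism. The sketch below is illustrative only — `make_inputs`, `run_prediction`, `PARALLEL`, and `ITERATIONS` are placeholder names, not identifiers from cog_safe_push — but the Queue/Semaphore wiring mirrors the `MakeFuzzInputs`, `FuzzModel`, and `run_tasks` changes in the diff above.

```python
import asyncio
from asyncio import Queue

PARALLEL = 4     # example cap; in the real code this comes from the config's parallel setting
ITERATIONS = 10  # example count; mirrors the fuzz iterations default shown in the diff


async def make_inputs(queue: Queue, num_inputs: int) -> None:
    # Producer: stands in for the MakeFuzzInputs task, which generates
    # input combinations and pushes them onto the shared queue.
    for i in range(num_inputs):
        await asyncio.sleep(0.1)  # placeholder for the input-generation call
        await queue.put({"prompt": f"example input {i}"})


async def run_prediction(queue: Queue) -> None:
    # Consumer: stands in for the FuzzModel task, which waits for one set
    # of inputs and runs a single prediction with it.
    inputs = await asyncio.wait_for(queue.get(), timeout=60)
    await asyncio.sleep(0.1)  # placeholder for the prediction itself
    print(f"predicted with {inputs}")


async def run_tasks(coros, parallel: int) -> None:
    # Same shape as run_tasks in the diff: a semaphore caps how many
    # coroutines execute at once, and exceptions are collected rather
    # than aborting the remaining tasks. (What the real implementation
    # does with the collected errors afterwards is outside this sketch.)
    semaphore = asyncio.Semaphore(parallel)
    errors: list[Exception] = []

    async def run_with_semaphore(coro):
        async with semaphore:
            try:
                await coro
            except Exception as e:
                errors.append(e)

    await asyncio.gather(*(run_with_semaphore(c) for c in coros))
    if errors:
        raise errors[0]


if __name__ == "__main__":
    queue: Queue = Queue(maxsize=ITERATIONS)
    coros = [make_inputs(queue, ITERATIONS)]
    coros += [run_prediction(queue) for _ in range(ITERATIONS)]
    asyncio.run(run_tasks(coros, parallel=PARALLEL))
```

Running the sketch prints ten "predicted with ..." lines while never holding more than four coroutines active at a time, which is the same behaviour the series introduces for fuzzing: the producer is just another task in the pool, so input generation overlaps with the predictions that consume it.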