Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tasks #12

Merged
merged 10 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,57 @@ jobs:
- name: Run pytest
run: |
./script/unit-test

integration-test:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'

- name: Install dependencies
run: |
pip install -r requirements-test.txt
pip install .

- name: Run pytest
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
run: |
./script/integration-test

end-to-end-test:
runs-on: ubuntu-latest-4-cores
andreasjansson marked this conversation as resolved.
Show resolved Hide resolved

steps:
- uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'

- name: Install dependencies
run: |
pip install -r requirements-test.txt
pip install .

- name: Install Cog
run: |
sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)"
sudo chmod +x /usr/local/bin/cog

- name: cog login
run: |
echo ${{ secrets.COG_TOKEN }} | cog login --token-stdin

- name: Run pytest
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
run: |
./script/end-to-end-test
104 changes: 65 additions & 39 deletions cog_safe_push/ai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import functools
import json
import mimetypes
import os
Expand All @@ -10,16 +11,36 @@

from . import log
from .exceptions import AIError, ArgumentError
from .retry import retry


@retry(3)
def boolean(
def async_retry(attempts=3):
def decorator_retry(func):
@functools.wraps(func)
async def wrapper_retry(*args, **kwargs):
for attempt in range(1, attempts + 1):
try:
return await func(*args, **kwargs)
except Exception as e:
log.warning(f"Exception occurred: {e}")
if attempt < attempts:
log.warning(f"Retrying attempt {attempt}/{attempts}")
else:
log.warning(f"Giving up after {attempts} attempts")
raise
return None

return wrapper_retry

return decorator_retry


@async_retry(3)
async def boolean(
prompt: str, files: list[Path] | None = None, include_file_metadata: bool = False
) -> bool:
system_prompt = "You only answer YES or NO, and absolutely nothing else. Your response will be used in a programmatic context so it's important that you only ever answer with either the string YES or the string NO."
#system_prompt = "You are a helpful assistant"
output = call(
# system_prompt = "You are a helpful assistant"
output = await call(
system_prompt=system_prompt,
prompt=prompt.strip(),
files=files,
Expand All @@ -32,17 +53,17 @@ def boolean(
raise AIError(f"Failed to parse output as YES/NO: {output}")


@retry(3)
def json_object(prompt: str, files: list[Path] | None = None) -> dict:
@async_retry(3)
async def json_object(prompt: str, files: list[Path] | None = None) -> dict:
system_prompt = "You always respond with valid JSON, and nothing else (no backticks, etc.). Your outputs will be used in a programmatic context."
output = call(system_prompt=system_prompt, prompt=prompt.strip(), files=files)
output = await call(system_prompt=system_prompt, prompt=prompt.strip(), files=files)
try:
return json.loads(output)
except json.JSONDecodeError:
raise AIError(f"Failed to parse output as JSON: {output}")


def call(
async def call(
system_prompt: str,
prompt: str,
files: list[Path] | None = None,
Expand All @@ -53,36 +74,41 @@ def call(
raise ArgumentError("ANTHROPIC_API_KEY is not defined")

model = "claude-3-5-sonnet-20241022"
client = anthropic.Anthropic(api_key=api_key)

if files:
content = create_content_list(files)

if include_file_metadata:
prompt += "\n\nMetadata for the attached file(s):\n"
for path in files:
prompt += f"* " + file_info(path) + "\n"

content.append({"type": "text", "text": prompt})

log.vvv(f"Claude prompt with {len(files)} files: {prompt}")
else:
content = prompt
log.vvv(f"Claude prompt: {prompt}")

messages: list[anthropic.types.MessageParam] = [
{"role": "user", "content": content}
]

response = client.messages.create(
model=model,
messages=messages,
system=system_prompt,
max_tokens=4096,
stream=False,
temperature=1.0,
)
content = cast(anthropic.types.TextBlock, response.content[0])
client = anthropic.AsyncAnthropic(api_key=api_key)

try:
if files:
content = create_content_list(files)

if include_file_metadata:
prompt += "\n\nMetadata for the attached file(s):\n"
for path in files:
prompt += "* " + file_info(path) + "\n"

content.append({"type": "text", "text": prompt})

log.vvv(f"Claude prompt with {len(files)} files: {prompt}")
else:
content = prompt
log.vvv(f"Claude prompt: {prompt}")

messages: list[anthropic.types.MessageParam] = [
{"role": "user", "content": content}
]

response = await client.messages.create(
model=model,
messages=messages,
system=system_prompt,
max_tokens=4096,
stream=False,
temperature=1.0,
)
content = cast(anthropic.types.TextBlock, response.content[0])

finally:
await client.close()

output = content.text
log.vvv(f"Claude response: {output}")
return output
Expand Down
4 changes: 2 additions & 2 deletions cog_safe_push/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ class FuzzConfig(BaseModel):

fixed_inputs: dict[str, InputScalar] = {}
disabled_inputs: list[str] = []
duration: int = DEFAULT_FUZZ_DURATION
iterations: int | None = None
iterations: int = 10


class PredictConfig(BaseModel):
Expand Down Expand Up @@ -68,6 +67,7 @@ class Config(BaseModel):
predict: PredictConfig | None = None
train: TrainConfig | None = None
dockerfile: str | None = None
parallel: int = 4

def override(self, field: str, args: argparse.Namespace, arg: str):
if hasattr(args, arg) and getattr(args, arg) is not None:
Expand Down
Loading