Skip to content

Commit

Permalink
struc outputs support
Browse files Browse the repository at this point in the history
  • Loading branch information
jalexanderII committed Aug 7, 2024
1 parent a8edb5c commit dfea695
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 29 deletions.
34 changes: 18 additions & 16 deletions cookbook/openai/tracing_with_openai_with_structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from dotenv import load_dotenv
from openai import OpenAI
from parea import Parea
from pydantic import BaseModel

from parea import Parea

load_dotenv()

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
Expand All @@ -18,17 +19,20 @@ class CalendarEvent(BaseModel):
participants: list[str]


completion = client.beta.chat.completions.parse(
model="gpt-4o-2024-08-06",
messages=[
{"role": "system", "content": "Extract the event information."},
{"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
],
response_format=CalendarEvent,
)
def with_pydantic():
completion = client.beta.chat.completions.parse(
model="gpt-4o-2024-08-06",
messages=[
{"role": "system", "content": "Extract the event information."},
{"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
],
response_format=CalendarEvent,
)
event = completion.choices[0].message.parsed
print(event)


def main():
def with_json_schema():
response = client.chat.completions.create(
model="gpt-4o-2024-08-06",
messages=[
Expand Down Expand Up @@ -60,11 +64,10 @@ def main():
},
},
)

print(response.choices[0].message.content)


def main2():
def with_tools():
tools = [
{
"type": "function",
Expand Down Expand Up @@ -101,7 +104,6 @@ def main2():


if __name__ == "__main__":
# event = completion.choices[0].message.parsed
# print(event)
# main()
main2()
with_pydantic()
with_json_schema()
with_tools()
1 change: 1 addition & 0 deletions cookbook/use_dataset_for_finetuning.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

from dotenv import load_dotenv

from parea import Parea

load_dotenv()
Expand Down
3 changes: 2 additions & 1 deletion parea/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Dict, Union

import os

from dotenv import load_dotenv

load_dotenv()
Expand Down
4 changes: 3 additions & 1 deletion parea/parea_logger.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any, Dict, Optional

import json
import logging
import os
from typing import Any, Dict, Optional

from attrs import asdict, define, field
from cattrs import structure

from parea.api_client import HTTPClient
from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
from parea.helpers import serialize_metadata_values
Expand Down
4 changes: 3 additions & 1 deletion parea/schemas/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import json
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from attrs import define, field, validators

from parea.schemas import EvaluationResult
from parea.schemas.log import EvaluatedLog, LLMInputs

Expand Down
3 changes: 2 additions & 1 deletion parea/utils/trace_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Generator, Iterator, List, Optional, Tuple

import contextvars
import inspect
import json
Expand All @@ -9,7 +11,6 @@
from datetime import datetime
from functools import wraps
from random import random
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Generator, Iterator, List, Optional, Tuple

from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID, TURN_OFF_PAREA_EVAL_LOGGING
from parea.helpers import gen_trace_id, is_logging_disabled, timezone_aware_now
Expand Down
3 changes: 2 additions & 1 deletion parea/utils/universal_encoder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any

import dataclasses
import datetime
import json
import logging
from decimal import Decimal
from enum import Enum
from typing import Any
from uuid import UUID

import attrs
Expand Down
11 changes: 6 additions & 5 deletions parea/wrapper/openai/openai.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Generator, Iterator, Optional, Sequence, TypeVar, Union

import json
import os
from collections import defaultdict
from datetime import datetime
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Generator, Iterator, Optional, Sequence, TypeVar, \
Union

import openai
from openai import __version__ as openai_version

from parea.helpers import timezone_aware_now
from parea.utils.universal_encoder import json_dumps
from parea.wrapper.utils import _calculate_input_tokens, _compute_cost, _format_function_call, \
_kwargs_to_llm_configuration, _num_tokens_from_string
from parea.wrapper.utils import _calculate_input_tokens, _compute_cost, _format_function_call, _kwargs_to_llm_configuration, _num_tokens_from_string

if openai_version.startswith("0."):
from openai.openai_object import OpenAIObject
from openai.util import convert_to_openai_object
else:
from openai.types.chat import ChatCompletion as OpenAIObject, ParsedChatCompletionMessage
from openai.types.chat import ChatCompletion as OpenAIObject
from openai.types.chat import ParsedChatCompletion as OpenAIObjectParsed
from openai.types.chat import ParsedChatCompletionMessage

def convert_to_openai_object(kwargs) -> OpenAIObject:
if "id" not in kwargs:
Expand Down
9 changes: 6 additions & 3 deletions parea/wrapper/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from typing import Callable, Dict, List, Optional, Union

import json
import re
import sys
from functools import lru_cache, wraps
from typing import Callable, Dict, List, Optional, Union

import tiktoken
from openai import __version__ as openai_version, NotGiven
from openai import NotGiven
from openai import __version__ as openai_version
from pydantic._internal._model_construction import ModelMetaclass

from parea.constants import ALL_NON_AZURE_MODELS_INFO, AZURE_MODEL_INFO, TURN_OFF_PAREA_EVAL_LOGGING
from parea.parea_logger import parea_logger
from parea.schemas.log import LLMInputs, Message, ModelParams, Role
from parea.schemas.models import UpdateTraceScenario
from parea.utils.trace_utils import fill_trace_data, get_current_trace_id, log_in_thread, trace_data, trace_insert
from parea.utils.universal_encoder import json_dumps
from pydantic._internal._model_construction import ModelMetaclass

is_openai_1 = openai_version.startswith("1.")
if is_openai_1:
Expand Down

0 comments on commit dfea695

Please sign in to comment.