Merge pull request #1037 from parea-ai/PAI-1442-openai-structured-outputs

openai-structured-outputs
jalexanderII authored Aug 7, 2024
2 parents 8f4604a + 33a5374 commit cfcda67
Showing 8 changed files with 250 additions and 19 deletions.
109 changes: 109 additions & 0 deletions cookbook/openai/tracing_with_openai_with_structured_output.py
@@ -0,0 +1,109 @@
import os

from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel

from parea import Parea

load_dotenv()

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
p = Parea(api_key=os.getenv("PAREA_API_KEY"))
p.wrap_openai_client(client)


class CalendarEvent(BaseModel):
name: str
date: str
participants: list[str]


def with_pydantic():
completion = client.beta.chat.completions.parse(
model="gpt-4o-2024-08-06",
messages=[
{"role": "system", "content": "Extract the event information."},
{"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
],
response_format=CalendarEvent,
)
event = completion.choices[0].message.parsed
print(event)


def with_json_schema():
response = client.chat.completions.create(
model="gpt-4o-2024-08-06",
messages=[
{"role": "system", "content": "You are a helpful math tutor. Guide the user through the solution step by step."},
{"role": "user", "content": "how can I solve 8x + 7 = -23"},
],
response_format={
"type": "json_schema",
"json_schema": {
"name": "math_response",
"schema": {
"type": "object",
"properties": {
"steps": {
"type": "array",
"items": {
"type": "object",
"properties": {"explanation": {"type": "string"}, "output": {"type": "string"}},
"required": ["explanation", "output"],
"additionalProperties": False,
},
},
"final_answer": {"type": "string"},
},
"required": ["steps", "final_answer"],
"additionalProperties": False,
},
"strict": True,
},
},
)
print(response.choices[0].message.content)


def with_tools():
tools = [
{
"type": "function",
"function": {
"name": "get_delivery_date",
"description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
"parameters": {
"type": "object",
"properties": {
"order_id": {
"type": "string",
"description": "The customer's order ID.",
},
},
"required": ["order_id"],
"additionalProperties": False,
},
},
"strict": True,
}
]

messages = [
{"role": "system", "content": "You are a helpful customer support assistant. Use the supplied tools to assist the user."},
{"role": "user", "content": "Hi, can you tell me the delivery date for my order with id 5?"},
]

response = client.chat.completions.create(
model="gpt-4o-2024-08-06",
messages=messages,
tools=tools,
)
print(response.choices[0].message.tool_calls)


if __name__ == "__main__":
with_pydantic()
with_json_schema()
with_tools()
5 changes: 5 additions & 0 deletions parea/constants.py
@@ -146,6 +146,11 @@ def str2bool(v):
"completion": 15.0,
"token_limit": {"max_completion_tokens": 4096, "max_prompt_tokens": 128000},
},
"gpt-4o-2024-08-06": {
"prompt": 5.0,
"completion": 15.0,
"token_limit": {"max_completion_tokens": 4096, "max_prompt_tokens": 128000},
},
"gpt-4o-mini": {
"prompt": 0.15,
"completion": 0.6,
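The new entry prices gpt-4o-2024-08-06 identically to the other gpt-4o snapshots. As a hedged sketch of how a pricing table like this is typically consumed by parea's `_compute_cost` (touched later in this diff) — assuming, as the $5/$15 figures for gpt-4o suggest, that prices are quoted per million tokens; the helper below is illustrative, not parea's exact code:

```python
# Illustrative sketch, not parea's actual implementation.
MODEL_INFO = {
    "gpt-4o-2024-08-06": {"prompt": 5.0, "completion": 15.0},  # USD per 1M tokens (assumed)
}

def compute_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
    info = MODEL_INFO.get(model)
    if info is None:
        return 0.0  # unknown model: report zero cost rather than failing the trace
    return (prompt_tokens * info["prompt"] + completion_tokens * info["completion"]) / 1_000_000

# 1,200 prompt tokens + 300 completion tokens => (1200*5.0 + 300*15.0) / 1e6 = $0.0105
print(compute_cost(1200, 300, "gpt-4o-2024-08-06"))
```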
6 changes: 4 additions & 2 deletions parea/utils/trace_utils.py
@@ -49,7 +49,9 @@ def log_in_thread(target_func: Callable, data: Dict[str, Any]):
logging_thread.start()


def merge(old, new):
def merge(old, new, key=None):
if key == "error" and old:
return json_dumps([old, new])
if isinstance(old, dict) and isinstance(new, dict):
return dict(ChainMap(new, old))
if isinstance(old, list) and isinstance(new, list):
@@ -112,7 +114,7 @@ def trace_insert(data: Dict[str, Any], trace_id: Optional[str] = None):
return
for key, new_value in data.items():
existing_value = current_trace_data.__getattribute__(key)
current_trace_data.__setattr__(key, merge(existing_value, new_value) if existing_value else new_value)
current_trace_data.__setattr__(key, merge(existing_value, new_value, key) if existing_value else new_value)
except Exception as e:
logger.debug(f"Error occurred inserting data into trace log, {e}", exc_info=e)

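The `key`-aware `merge` means a second error recorded on the same trace is appended into a JSON array instead of silently overwriting the first. A self-contained sketch of the resulting semantics — the list branch and final fallback are assumed from the surrounding (unchanged) code, and `json.dumps` stands in for parea's `json_dumps`:

```python
import json
from collections import ChainMap

def merge(old, new, key=None):
    if key == "error" and old:
        return json.dumps([old, new])      # accumulate errors instead of overwriting
    if isinstance(old, dict) and isinstance(new, dict):
        return dict(ChainMap(new, old))    # keys from `new` shadow keys from `old`
    if isinstance(old, list) and isinstance(new, list):
        return old + new                   # assumed pre-existing list behavior
    return new                             # assumed pre-existing fallback

assert merge("timeout", "rate limited", key="error") == '["timeout", "rate limited"]'
assert merge({"a": 1}, {"a": 2, "b": 3}) == {"a": 2, "b": 3}
assert merge([1], [2]) == [1, 2]
```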
12 changes: 12 additions & 0 deletions parea/utils/universal_encoder.py
@@ -81,6 +81,16 @@ def handle_dspy_response(self, obj) -> Any:
else:
return None

def handle_openai_not_given(self, obj) -> Any:
try:
from openai import NotGiven
except ImportError:
return None

if isinstance(obj, NotGiven):
return {"not_given": None}
return None

def default(self, obj: Any):
if isinstance(obj, str):
return obj
@@ -116,6 +126,8 @@ def default(self, obj: Any):
return obj.to_dict(orient="records")
elif dspy_response := self.handle_dspy_response(obj):
return dspy_response
elif is_openai_not_given := self.handle_openai_not_given(obj):
return is_openai_not_given["not_given"]
elif callable(obj):
try:
return f"<callable {obj.__name__}>"
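The encoder change makes openai's `NotGiven` sentinel serialize as JSON `null`, so traced kwargs that were never set no longer break logging. A minimal stand-in showing the intended effect (the class name is hypothetical; parea's real encoder handles many more types):

```python
import json
from openai import NOT_GIVEN  # the NotGiven sentinel instance

class SketchEncoder(json.JSONEncoder):
    def default(self, obj):
        try:
            from openai import NotGiven
            if isinstance(obj, NotGiven):
                return None  # rendered as JSON null
        except ImportError:
            pass
        return super().default(obj)

print(json.dumps({"max_tokens": NOT_GIVEN}, cls=SketchEncoder))  # {"max_tokens": null}
```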
21 changes: 18 additions & 3 deletions parea/wrapper/openai/openai.py
@@ -58,6 +58,13 @@ def get_original_methods(self, module_client=openai):
original_methods = {"chat.completions.create": module_client.chat.completions.create}
except openai.OpenAIError:
original_methods = {}

try:
latest_methods = {"beta.chat.completions.parse": module_client.beta.chat.completions.parse}
original_methods.update(latest_methods)
except Exception:
pass

return list(original_methods.keys())

def init(self, log: Callable, cache: Cache = None, module_client=openai):
@@ -103,7 +110,7 @@ def resolver(self, trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any],
trace_data.get()[trace_id].output_tokens = output_tokens
trace_data.get()[trace_id].total_tokens = total_tokens
trace_data.get()[trace_id].cost = _compute_cost(input_tokens, output_tokens, model)
trace_data.get()[trace_id].output = output
trace_data.get()[trace_id].output = json_dumps(output) if not isinstance(output, str) else output
return response

def gen_resolver(self, trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response, final_log):
@@ -269,7 +276,13 @@ def _kwargs_to_llm_configuration(kwargs)

@staticmethod
def _get_output(result: Any, model: Optional[str] = None) -> str:
if not isinstance(result, OpenAIObject) and isinstance(result, dict):
try:
from openai.types.chat import ParsedChatCompletion, ParsedChatCompletionMessage
except ImportError:
            # Placeholder types keep the isinstance checks below valid on older SDKs
            # (isinstance against None would raise a TypeError).

            class ParsedChatCompletion: ...

            class ParsedChatCompletionMessage: ...

if not isinstance(result, (OpenAIObject, ParsedChatCompletion)) and isinstance(result, dict):
result = convert_to_openai_object(
{
"choices": [
@@ -282,7 +295,9 @@ def _get_output(result: Any, model: Optional[str] = None) -> str:
}
)
response_message = result.choices[0].message
if not response_message.get("content", None) if is_old_openai else not response_message.content:
if isinstance(response_message, ParsedChatCompletionMessage):
completion = response_message.parsed.model_dump_json() if response_message.parsed else ""
elif not response_message.get("content", None) if is_old_openai else not response_message.content:
completion = OpenAIWrapper._format_function_call(response_message)
else:
completion = response_message.content
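With `beta.chat.completions.parse` now wrapped, a parsed Pydantic result is logged via `model_dump_json()`. Illustratively, for the `CalendarEvent` model from the cookbook above (the field values here are made up, since they depend on the model's response):

```python
from pydantic import BaseModel

class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]

# What _get_output would serialize for a successful parse() call:
event = CalendarEvent(name="science fair", date="Friday", participants=["Alice", "Bob"])
print(event.model_dump_json())
# {"name":"science fair","date":"Friday","participants":["Alice","Bob"]}
```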
35 changes: 26 additions & 9 deletions parea/wrapper/utils.py
@@ -6,7 +6,9 @@
from functools import lru_cache, wraps

import tiktoken
from openai import NotGiven
from openai import __version__ as openai_version
from pydantic._internal._model_construction import ModelMetaclass

from parea.constants import ALL_NON_AZURE_MODELS_INFO, AZURE_MODEL_INFO, TURN_OFF_PAREA_EVAL_LOGGING
from parea.parea_logger import parea_logger
@@ -220,9 +222,12 @@ def clean_json_string(s):

def _resolve_functions(kwargs):
if "functions" in kwargs:
return kwargs.get("functions", [])
f = kwargs.get("functions", [])
return None if isinstance(f, NotGiven) else f
elif "tools" in kwargs:
tools = kwargs["tools"]
if isinstance(tools, NotGiven):
return None
if isinstance(tools, list):
return [d.get("function", {}) for d in tools]

@@ -234,19 +239,27 @@ def _resolve_functions(kwargs):
def _kwargs_to_llm_configuration(kwargs, model=None) -> LLMInputs:
functions = _resolve_functions(kwargs)
function_call_default = "auto" if functions else None
function_call = kwargs.get("function_call", function_call_default) or kwargs.get("tool_choice", function_call_default)
response_format = kwargs.get("response_format", None)
response_format = {"type": "json_schema", "json_schema": str(response_format)} if isinstance(response_format, ModelMetaclass) else response_format
temp = kwargs.get("temperature", 1.0)
max_length = kwargs.get("max_tokens", None)
top_p = kwargs.get("top_p", 1.0)
frequency_penalty = kwargs.get("frequency_penalty", 0.0)
presence_penalty = kwargs.get("presence_penalty", 0.0)
return LLMInputs(
model=model or kwargs.get("model", None),
provider="openai",
messages=_convert_oai_messages(kwargs.get("messages", None)),
functions=functions,
function_call=kwargs.get("function_call", function_call_default) or kwargs.get("tool_choice", function_call_default),
function_call=None if isinstance(function_call, NotGiven) else function_call,
model_params=ModelParams(
temp=kwargs.get("temperature", 1.0),
max_length=kwargs.get("max_tokens", None),
top_p=kwargs.get("top_p", 1.0),
frequency_penalty=kwargs.get("frequency_penalty", 0.0),
presence_penalty=kwargs.get("presence_penalty", 0.0),
response_format=kwargs.get("response_format", None),
temp=None if isinstance(temp, NotGiven) else temp,
max_length=None if isinstance(max_length, NotGiven) else max_length,
top_p=None if isinstance(top_p, NotGiven) else top_p,
frequency_penalty=None if isinstance(frequency_penalty, NotGiven) else frequency_penalty,
presence_penalty=None if isinstance(presence_penalty, NotGiven) else presence_penalty,
response_format=response_format,
),
)

@@ -302,7 +315,11 @@ def _compute_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:

def _process_response(response, model_inputs, trace_id):
response_message = response.choices[0].message
if response_message.content:
    if response.choices[0].finish_reason == "content_filter":  # finish_reason lives on the choice, not the message
trace_insert({"error": "Error: The content was filtered due to policy violations."}, trace_id)
    if getattr(response_message, "refusal", None):  # attribute exists but is None on non-refused responses
completion = response_message.refusal
elif response_message.content:
completion = response_message.content
else:
completion = _format_function_call(response_message)
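
The `ModelParams` fields earlier in this file's diff all repeat the same `None if isinstance(x, NotGiven) else x` normalization, mapping openai's "argument was never passed" sentinel to a plain `None` before logging. A hedged sketch of that idiom factored into a helper (`given_or` is hypothetical, not a parea function):

```python
from openai import NOT_GIVEN, NotGiven

def given_or(value, default=None):
    """Map openai's NotGiven sentinel to a plain default; pass real values through."""
    return default if isinstance(value, NotGiven) else value

assert given_or(NOT_GIVEN) is None
assert given_or(0.2) == 0.2
assert given_or(NOT_GIVEN, default=1.0) == 1.0
```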