From e38d59ed6cb9acb93b7863411838ec56e8c1981e Mon Sep 17 00:00:00 2001
From: Arjun G
Date: Sat, 25 May 2024 23:16:00 +0530
Subject: [PATCH] Added function calling support to GeminiClient

---
 autogen/oai/gemini.py | 137 +++++++++++++++++++++++++++++++++++-------
 1 file changed, 114 insertions(+), 23 deletions(-)

diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py
index 5c06a4def0c9..dfa38cf9a257 100644
--- a/autogen/oai/gemini.py
+++ b/autogen/oai/gemini.py
@@ -32,19 +32,22 @@ from __future__ import annotations

 import base64
+import copy
+import json
 import os
 import random
 import re
 import time
 import warnings
 from io import BytesIO
-from typing import Any, Dict, List, Mapping, Union
+from typing import Any, Dict, List

 import google.generativeai as genai
 import requests
-from google.ai.generativelanguage import Content, Part
+from google.ai.generativelanguage import Content, Part, Tool, FunctionDeclaration, FunctionCall, FunctionResponse
 from google.api_core.exceptions import InternalServerError
-from openai.types.chat import ChatCompletion
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion_message_tool_call import Function
 from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
 from openai.types.completion_usage import CompletionUsage
 from PIL import Image
@@ -112,6 +115,7 @@ def create(self, params: Dict) -> ChatCompletion:
         params.get("api_type", "google")  # not used
         messages = params.get("messages", [])
+        tools = params.get("tools", [])
         stream = params.get("stream", False)
         n_response = params.get("n", 1)
@@ -134,18 +138,19 @@ def create(self, params: Dict) -> ChatCompletion:
         if "vision" not in model_name:
             # A. create and call the chat model.
             gemini_messages = oai_messages_to_gemini_messages(messages)
+            gemini_tools = oai_tools_to_gemini_tools(tools)
             # we use chat model by default
             model = genai.GenerativeModel(
-                model_name, generation_config=generation_config, safety_settings=safety_settings
+                model_name, generation_config=generation_config, safety_settings=safety_settings, tools=gemini_tools
             )
             genai.configure(api_key=self.api_key)
-            chat = model.start_chat(history=gemini_messages[:-1])
+            chat: genai.ChatSession = model.start_chat(history=gemini_messages[:-1])
             max_retries = 5
             for attempt in range(max_retries):
-                ans = None
+                ans: Content = None
                 try:
-                    response = chat.send_message(gemini_messages[-1].parts[0].text, stream=stream)
+                    response = chat.send_message(gemini_messages[-1].parts, stream=stream)
                 except InternalServerError:
                     delay = 5 * (2**attempt)
                     warnings.warn(
@@ -157,14 +162,14 @@ def create(self, params: Dict) -> ChatCompletion:
                     raise RuntimeError(f"Google GenAI exception occurred while calling Gemini API: {e}")
                 else:
                     # `ans = response.text` is unstable. Use the following code instead.
-                    ans: str = chat.history[-1].parts[0].text
+                    ans: Content = chat.history[-1]
                     break

             if ans is None:
                 raise RuntimeError(f"Fail to get response from Google AI after retrying {attempt + 1} times.")

             prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens
-            completion_tokens = model.count_tokens(ans).total_tokens
+            completion_tokens = model.count_tokens(contents=Content(parts=ans.parts)).total_tokens
         elif model_name == "gemini-pro-vision":
             # B. handle the vision model
             # Gemini's vision model does not support chat history yet
@@ -174,7 +179,7 @@ def create(self, params: Dict) -> ChatCompletion:
             genai.configure(api_key=self.api_key)
             # chat = model.start_chat(history=gemini_messages[:-1])
             # response = chat.send_message(gemini_messages[-1])
-            user_message = oai_content_to_gemini_content(messages[-1]["content"])
+            user_message = oai_content_to_gemini_content(messages[-1])
             if len(messages) > 2:
                 warnings.warn(
                     "Warning: Gemini's vision model does not support chat history yet.",
@@ -184,14 +189,13 @@ def create(self, params: Dict) -> ChatCompletion:
             response = model.generate_content(user_message, stream=stream)
             # ans = response.text
-            ans: str = response._result.candidates[0].content.parts[0].text
+            ans: Content = response._result.candidates[0].content

             prompt_tokens = model.count_tokens(user_message).total_tokens
-            completion_tokens = model.count_tokens(ans).total_tokens
+            completion_tokens = model.count_tokens(ans.parts[0].text).total_tokens

         # 3. convert output
-        message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None)
-        choices = [Choice(finish_reason="stop", index=0, message=message)]
+        choices = gemini_content_to_oai_choices(ans)

         response_oai = ChatCompletion(
             id=str(random.randint(0, 1000)),
@@ -223,16 +227,34 @@ def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str
     return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6


-def oai_content_to_gemini_content(content: Union[str, List]) -> List:
+def oai_content_to_gemini_content(message: Dict[str, Any]) -> List:
     """Convert content from OAI format to Gemini format"""
     rst = []
-    if isinstance(content, str):
-        rst.append(Part(text=content))
+    if isinstance(message, str):
+        rst.append(Part(text=message))
+        return rst
+
+    if "tool_calls" in message:
+        rst.append(Part(function_call=FunctionCall(
+            name=message["tool_calls"][0]["function"]["name"],
+            args=json.loads(message["tool_calls"][0]["function"]["arguments"])
+        )))
+        return rst
+
+    if message["role"] == "tool":
+        rst.append(Part(function_response=FunctionResponse(
+            name=message["name"],
+            response=json.loads(message["content"])
+        )))
+        return rst
+
+    if isinstance(message["content"], str):
+        rst.append(Part(text=message["content"]))
         return rst

-    assert isinstance(content, list)
+    assert isinstance(message["content"], list)

-    for msg in content:
+    for msg in message["content"]:
         if isinstance(msg, dict):
             assert "type" in msg, f"Missing 'type' field in message: {msg}"
             if msg["type"] == "text":
@@ -254,6 +276,9 @@ def concat_parts(parts: List[Part]) -> List:
     """
     if not parts:
         return []
+
+    if len(parts) == 1:
+        return parts

     concatenated_parts = []
     previous_part = parts[0]
@@ -281,18 +306,34 @@ def oai_messages_to_gemini_messages(messages: list[Dict[str, Any]]) -> list[dict
     rst = []
     curr_parts = []
     for i, message in enumerate(messages):
-        parts = oai_content_to_gemini_content(message["content"])
+
+        # The OAI tool-response message does not carry the function "name" field, so look it up from the assistant message that issued the matching tool call.
+ if message["role"] == "tool": + message["name"] = [m for m in messages if "tool_calls" in m and m["tool_calls"][0]["id"] == message["tool_call_id"]][0]["tool_calls"][0]["function"]["name"] + + parts = oai_content_to_gemini_content(message) role = "user" if message["role"] in ["user", "system"] else "model" if prev_role is None or role == prev_role: - curr_parts += parts + # If the message is a function call or a function response, we need to separate it from the previous message. + if "function_call" in parts[0] or "function_response" in parts[0]: + if len(curr_parts) > 1: + rst.append(Content(parts=concat_parts(curr_parts), role=prev_role)) + elif len(curr_parts) == 1: + rst.append(Content(parts=curr_parts, role=None if curr_parts[0].function_response else role)) + rst.append(Content(parts=parts, role="user" if parts[0].function_response else role)) + curr_parts = [] + else: + curr_parts += parts elif role != prev_role: - rst.append(Content(parts=concat_parts(curr_parts), role=prev_role)) + if len(curr_parts) > 0: + rst.append(Content(parts=concat_parts(curr_parts), role=prev_role)) curr_parts = parts prev_role = role # handle the last message - rst.append(Content(parts=concat_parts(curr_parts), role=role)) + if len(curr_parts) > 0: + rst.append(Content(parts=concat_parts(curr_parts), role=role)) # The Gemini is restrict on order of roles, such that # 1. The messages should be interleaved between user and model. @@ -304,6 +345,56 @@ def oai_messages_to_gemini_messages(messages: list[Dict[str, Any]]) -> list[dict return rst +def oai_tools_to_gemini_tools(tools: List[Dict[str, Any]]) -> List[Tool]: + """Convert tools from OAI format to Gemini format.""" + function_declarations = [] + for tool in tools: + function_declaration = FunctionDeclaration( + name=tool["function"]["name"], + description=tool["function"]["description"], + parameters=oai_function_parameters_to_gemini_function_parameters(copy.deepcopy(tool["function"]["parameters"])) + ) + function_declarations.append(function_declaration) + return [Tool(function_declarations=function_declarations)] + +def oai_function_parameters_to_gemini_function_parameters(function_definition: dict[str, any]) -> dict[str, any]: + """ + Convert OpenAPI function definition parameters to Gemini function parameters definition. + The type key is renamed to type_ and the value is capitalized. 
+ """ + function_definition["type_"] = function_definition["type"].upper() + del function_definition["type"] + if "properties" in function_definition: + for key in function_definition["properties"]: + function_definition["properties"][key] = oai_function_parameters_to_gemini_function_parameters(function_definition["properties"][key]) + if "items" in function_definition: + function_definition["items"] = oai_function_parameters_to_gemini_function_parameters(function_definition["items"]) + return function_definition + + +def gemini_content_to_oai_choices(response: Content) -> List[Choice]: + """Convert response from Gemini format to OAI format.""" + text = None + tool_calls = None + for part in response.parts: + if part.text: + text = part.text + elif part.function_call: + arguments = Part.to_dict(part)["function_call"]["args"] + tool_calls = [ + ChatCompletionMessageToolCall( + id=str(random.randint(0, 1000)), + type="function", + function=Function( + name=part.function_call.name, + arguments=json.dumps(arguments) + ) + ) + ] + message = ChatCompletionMessage(role="assistant", content=text, function_call=None, tool_calls=tool_calls) + return [Choice(finish_reason="tool_calls" if tool_calls else "stop", index=0, message=message)] + + def _to_pil(data: str) -> Image.Image: """ Converts a base64 encoded image data string to a PIL Image object.