[Add] Added function calling support to GeminiClient #2792

Closed · wants to merge 2 commits
137 changes: 114 additions & 23 deletions autogen/oai/gemini.py
@@ -32,19 +32,22 @@
 from __future__ import annotations
 
 import base64
+import copy
+import json
 import os
 import random
 import re
 import time
 import warnings
 from io import BytesIO
-from typing import Any, Dict, List, Mapping, Union
+from typing import Any, Dict, List
 
 import google.generativeai as genai
 import requests
-from google.ai.generativelanguage import Content, Part
+from google.ai.generativelanguage import Content, Part, Tool, FunctionDeclaration, FunctionCall, FunctionResponse
 from google.api_core.exceptions import InternalServerError
-from openai.types.chat import ChatCompletion
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion_message_tool_call import Function
 from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
 from openai.types.completion_usage import CompletionUsage
 from PIL import Image
@@ -112,6 +115,7 @@ def create(self, params: Dict) -> ChatCompletion:
 
         params.get("api_type", "google")  # not used
         messages = params.get("messages", [])
+        tools = params.get("tools", [])
         stream = params.get("stream", False)
         n_response = params.get("n", 1)
 
@@ -134,18 +138,19 @@ def create(self, params: Dict) -> ChatCompletion:
         if "vision" not in model_name:
             # A. create and call the chat model.
             gemini_messages = oai_messages_to_gemini_messages(messages)
+            gemini_tools = oai_tools_to_gemini_tools(tools)
 
             # we use chat model by default
             model = genai.GenerativeModel(
-                model_name, generation_config=generation_config, safety_settings=safety_settings
+                model_name, generation_config=generation_config, safety_settings=safety_settings, tools=gemini_tools
             )
             genai.configure(api_key=self.api_key)
-            chat = model.start_chat(history=gemini_messages[:-1])
+            chat: genai.ChatSession = model.start_chat(history=gemini_messages[:-1])
             max_retries = 5
             for attempt in range(max_retries):
-                ans = None
+                ans: Content = None
                 try:
-                    response = chat.send_message(gemini_messages[-1].parts[0].text, stream=stream)
+                    response = chat.send_message(gemini_messages[-1].parts, stream=stream)
                 except InternalServerError:
                     delay = 5 * (2**attempt)
                     warnings.warn(
Expand All @@ -157,14 +162,14 @@ def create(self, params: Dict) -> ChatCompletion:
raise RuntimeError(f"Google GenAI exception occurred while calling Gemini API: {e}")
else:
# `ans = response.text` is unstable. Use the following code instead.
ans: str = chat.history[-1].parts[0].text
ans: Content = chat.history[-1]
break

if ans is None:
raise RuntimeError(f"Fail to get response from Google AI after retrying {attempt + 1} times.")

prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens
completion_tokens = model.count_tokens(ans).total_tokens
completion_tokens = model.count_tokens(contents=Content(parts=ans.parts)).total_tokens
elif model_name == "gemini-pro-vision":
# B. handle the vision model
# Gemini's vision model does not support chat history yet
@@ -174,7 +179,7 @@ def create(self, params: Dict) -> ChatCompletion:
             genai.configure(api_key=self.api_key)
             # chat = model.start_chat(history=gemini_messages[:-1])
             # response = chat.send_message(gemini_messages[-1])
-            user_message = oai_content_to_gemini_content(messages[-1]["content"])
+            user_message = oai_content_to_gemini_content(messages[-1])
             if len(messages) > 2:
                 warnings.warn(
                     "Warning: Gemini's vision model does not support chat history yet.",
@@ -184,14 +189,13 @@ def create(self, params: Dict) -> ChatCompletion:
 
             response = model.generate_content(user_message, stream=stream)
             # ans = response.text
-            ans: str = response._result.candidates[0].content.parts[0].text
+            ans: Content = response._result.candidates[0].content
 
             prompt_tokens = model.count_tokens(user_message).total_tokens
-            completion_tokens = model.count_tokens(ans).total_tokens
+            completion_tokens = model.count_tokens(ans.parts[0].text).total_tokens
 
         # 3. convert output
-        message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None)
-        choices = [Choice(finish_reason="stop", index=0, message=message)]
+        choices = gemini_content_to_oai_choices(ans)
 
         response_oai = ChatCompletion(
             id=str(random.randint(0, 1000)),
@@ -223,16 +227,34 @@ def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str
     return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6
 
 
-def oai_content_to_gemini_content(content: Union[str, List]) -> List:
+def oai_content_to_gemini_content(message: Dict[str, Any]) -> List:
     """Convert content from OAI format to Gemini format"""
     rst = []
-    if isinstance(content, str):
-        rst.append(Part(text=content))
+    if isinstance(message, str):
+        rst.append(Part(text=message))
         return rst
 
+    if "tool_calls" in message:
+        rst.append(Part(function_call=FunctionCall(
+            name=message["tool_calls"][0]["function"]["name"],
+            args=json.loads(message["tool_calls"][0]["function"]["arguments"])
+        )))
+        return rst
+
+    if message["role"] == "tool":
+        rst.append(Part(function_response=FunctionResponse(
+            name=message["name"],
+            response=json.loads(message["content"])
+        )))
+        return rst
+
+    if isinstance(message["content"], str):
+        rst.append(Part(text=message["content"]))
+        return rst
+
-    assert isinstance(content, list)
+    assert isinstance(message["content"], list)
 
-    for msg in content:
+    for msg in message["content"]:
         if isinstance(msg, dict):
             assert "type" in msg, f"Missing 'type' field in message: {msg}"
             if msg["type"] == "text":
@@ -254,6 +276,9 @@ def concat_parts(parts: List[Part]) -> List:
     """
     if not parts:
         return []
 
+    if len(parts) == 1:
+        return parts
+
     concatenated_parts = []
     previous_part = parts[0]
@@ -281,18 +306,34 @@ def oai_messages_to_gemini_messages(messages: list[Dict[str, Any]]) -> list[dict
     rst = []
     curr_parts = []
     for i, message in enumerate(messages):
-        parts = oai_content_to_gemini_content(message["content"])
+
+        # The tool message does not carry a "name" field, so recover the function
+        # name from the assistant message whose tool_calls entry matches this
+        # message's tool_call_id.
+        if message["role"] == "tool":
+            message["name"] = [m for m in messages if "tool_calls" in m and m["tool_calls"][0]["id"] == message["tool_call_id"]][0]["tool_calls"][0]["function"]["name"]
+
+        parts = oai_content_to_gemini_content(message)
         role = "user" if message["role"] in ["user", "system"] else "model"
 
         if prev_role is None or role == prev_role:
-            curr_parts += parts
+            # A function call or a function response must be separated from the
+            # previous message.
+            if "function_call" in parts[0] or "function_response" in parts[0]:
+                if len(curr_parts) > 1:
+                    rst.append(Content(parts=concat_parts(curr_parts), role=prev_role))
+                elif len(curr_parts) == 1:
+                    rst.append(Content(parts=curr_parts, role=None if curr_parts[0].function_response else role))
+                rst.append(Content(parts=parts, role="user" if parts[0].function_response else role))
+                curr_parts = []
+            else:
+                curr_parts += parts
         elif role != prev_role:
-            rst.append(Content(parts=concat_parts(curr_parts), role=prev_role))
+            if len(curr_parts) > 0:
+                rst.append(Content(parts=concat_parts(curr_parts), role=prev_role))
             curr_parts = parts
         prev_role = role
 
     # handle the last message
-    rst.append(Content(parts=concat_parts(curr_parts), role=role))
+    if len(curr_parts) > 0:
+        rst.append(Content(parts=concat_parts(curr_parts), role=role))
 
     # Gemini is strict about the order of roles, such that
     # 1. the messages should be interleaved between user and model.
@@ -304,6 +345,56 @@ def oai_messages_to_gemini_messages(messages: list[Dict[str, Any]]) -> list[dict
     return rst
 
 
+def oai_tools_to_gemini_tools(tools: List[Dict[str, Any]]) -> List[Tool]:
+    """Convert tools from OAI format to Gemini format."""
+    function_declarations = []
+    for tool in tools:
+        function_declaration = FunctionDeclaration(
+            name=tool["function"]["name"],
+            description=tool["function"]["description"],
+            parameters=oai_function_parameters_to_gemini_function_parameters(copy.deepcopy(tool["function"]["parameters"]))
+        )
+        function_declarations.append(function_declaration)
+    return [Tool(function_declarations=function_declarations)]
+
+
+def oai_function_parameters_to_gemini_function_parameters(function_definition: dict[str, any]) -> dict[str, any]:
+    """
+    Convert OpenAPI function definition parameters to Gemini function parameter definitions.
+    The "type" key is renamed to "type_" and its value is uppercased.
+    """
+    function_definition["type_"] = function_definition["type"].upper()
+    del function_definition["type"]
+    if "properties" in function_definition:
+        for key in function_definition["properties"]:
+            function_definition["properties"][key] = oai_function_parameters_to_gemini_function_parameters(function_definition["properties"][key])
+    if "items" in function_definition:
+        function_definition["items"] = oai_function_parameters_to_gemini_function_parameters(function_definition["items"])
+    return function_definition
+
+
+def gemini_content_to_oai_choices(response: Content) -> List[Choice]:
+    """Convert response from Gemini format to OAI format."""
+    text = None
+    tool_calls = None
+    for part in response.parts:
+        if part.text:
+            text = part.text
+        elif part.function_call:
+            arguments = Part.to_dict(part)["function_call"]["args"]
+            tool_calls = [
+                ChatCompletionMessageToolCall(
+                    id=str(random.randint(0, 1000)),
+                    type="function",
+                    function=Function(
+                        name=part.function_call.name,
+                        arguments=json.dumps(arguments)
+                    )
+                )
+            ]
+    message = ChatCompletionMessage(role="assistant", content=text, function_call=None, tool_calls=tool_calls)
+    return [Choice(finish_reason="tool_calls" if tool_calls else "stop", index=0, message=message)]
+
+
 def _to_pil(data: str) -> Image.Image:
     """
     Converts a base64 encoded image data string to a PIL Image object.
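
A minimal sketch of what the tool-side conversion produces, assuming this branch is installed so the new helpers are importable from autogen.oai.gemini; the get_weather schema is illustrative, not part of the diff:

from autogen.oai.gemini import oai_tools_to_gemini_tools

# An OpenAI-style tool spec, as AutoGen passes it in params["tools"].
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

# The parameter schema is rewritten for the protobuf FunctionDeclaration, e.g.
#   {"type": "object", ...} -> {"type_": "OBJECT", "properties": {"city": {"type_": "STRING"}}, "required": ["city"]}
# and the declarations are wrapped in a single google.ai.generativelanguage Tool.
gemini_tools = oai_tools_to_gemini_tools(tools)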
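
And a sketch of the message-side conversion, again with illustrative values; note how the "tool" message's function name is recovered via its tool_call_id:

import json
from autogen.oai.gemini import oai_messages_to_gemini_messages

# An OpenAI-style history containing an assistant tool call and the tool's reply.
messages = [
    {"role": "user", "content": "What is the weather in Paris?"},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": json.dumps({"city": "Paris"})},
            }
        ],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": json.dumps({"temperature_c": 21})},
]

# Yields Content entries [user text, model function_call, user function_response],
# matching the role alternation the conversion loop above enforces (the function
# response is sent back to Gemini under the user role).
gemini_messages = oai_messages_to_gemini_messages(messages)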