Skip to content

Commit

Permalink
Updated supervisor example after merging new Tools definition
Browse files Browse the repository at this point in the history
  • Loading branch information
brnaba-aws committed Dec 26, 2024
1 parent 000a492 commit 7b46e0f
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 176 deletions.
26 changes: 21 additions & 5 deletions examples/supervisor-mode/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from typing import Any
import sys, asyncio, uuid
import os
from datetime import datetime, timezone
from multi_agent_orchestrator.utils import Logger
from multi_agent_orchestrator.orchestrator import MultiAgentOrchestrator, OrchestratorConfig
from multi_agent_orchestrator.agents import (
BedrockLLMAgent, BedrockLLMAgentOptions,
Expand All @@ -9,9 +14,8 @@
from multi_agent_orchestrator.classifiers import ClassifierResult
from multi_agent_orchestrator.types import ConversationMessage
from multi_agent_orchestrator.storage import DynamoDbChatStorage
from typing import Any
import sys, asyncio, uuid
import os
from multi_agent_orchestrator.utils import Tool

from weather_tool import weather_tool_description, weather_tool_handler, weather_tool_prompt
from supervisor_agent import SupervisorAgent, SupervisorAgentOptions
from dotenv import load_dotenv
Expand Down Expand Up @@ -52,6 +56,7 @@
weather_agent.set_system_prompt(weather_tool_prompt)



health_agent = BedrockLLMAgent(BedrockLLMAgentOptions(
name="HealthAgent",
description="You are a health agent. You are responsible for answering questions about health. You are only allowed to answer questions about health. You are not allowed to answer questions about anything else.",
Expand All @@ -74,7 +79,7 @@
api_key=os.getenv('ANTHROPIC_API_KEY', None),
name="SupervisorAgent",
description="You are a supervisor agent. You are responsible for managing the flow of the conversation. You are only allowed to manage the flow of the conversation. You are not allowed to answer questions about anything else.",
model_id="claude-3-5-sonnet-latest"
model_id="claude-3-5-sonnet-latest",
))

# supervisor_agent = BedrockLLMAgent(BedrockLLMAgentOptions(
Expand All @@ -83,6 +88,13 @@
# description="You are a supervisor agent. You are responsible for managing the flow of the conversation. You are only allowed to manage the flow of the conversation. You are not allowed to answer questions about anything else.",
# ))

async def get_current_date():
"""
Get the current date in US format.
"""
Logger.info('Using Tool : get_current_date')
return datetime.now(timezone.utc).strftime('%m/%d/%Y') # from datetime import datetime, timezone


supervisor = SupervisorAgent(
SupervisorAgentOptions(
Expand All @@ -92,7 +104,11 @@
table_name=os.getenv('DYNAMODB_CHAT_HISTORY_TABLE_NAME', None),
region='us-east-1'
),
trace=True
trace=True,
extra_tools=[Tool(
name="get_current_date",
func=get_current_date,
)]
))

async def handle_request(_orchestrator: MultiAgentOrchestrator, _user_input:str, _user_id:str, _session_id:str):
Expand Down
216 changes: 47 additions & 169 deletions examples/supervisor-mode/supervisor_agent.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,19 @@

from typing import Optional, Any, AsyncIterable, Union
from dataclasses import dataclass, field
from enum import Enum
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
from multi_agent_orchestrator.agents import Agent, AgentOptions, BedrockLLMAgent, AnthropicAgent
from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole
from multi_agent_orchestrator.utils import Logger
from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole, AgentProviderType
from multi_agent_orchestrator.utils import Logger, Tools, Tool
from multi_agent_orchestrator.storage import ChatStorage, InMemoryChatStorage
from tool import Tool, ToolResult
from datetime import datetime, timezone


class SupervisorType(Enum):
BEDROCK = "BEDROCK"
ANTHROPIC = "ANTHROPIC"

@dataclass
class SupervisorAgentOptions(AgentOptions):
supervisor:Agent = None
team: list[Agent] = field(default_factory=list)
storage: Optional[ChatStorage] = None
trace: Optional[bool] = None
extra_tools: Optional[Union[Tools, list[Tool]]] = None # allow for extra tools

# Hide inherited fields
name: str = field(init=False)
Expand All @@ -46,70 +38,67 @@ class SupervisorAgent(Agent):
__init__(self, options: SupervisorAgentOptions): Initializes a SupervisorAgent instance.
send_message(self, agent: Agent, content: str, user_id: str, session_id: str, additionalParameters: dict) -> str: Sends a message to an agent.
send_messages(self, messages: list[dict[str, str]]) -> str: Sends messages to multiple agents in parallel.
get_current_date(self) -> str: Gets the current date.
supervisor_tool_handler(self, response: Any, conversation: list[dict[str, Any]]) -> Any: Handles the response from a tool.
_process_tool(self, tool_name: str, input_data: dict) -> Any: Processes a tool based on its name.
process_request(self, input_text: str, user_id: str, session_id: str, chat_history: list[ConversationMessage], additional_params: Optional[dict[str, str]] = None) -> Union[ConversationMessage, AsyncIterable[Any]]: Processes a user request.
"""

supervisor_tools:list[Tool] = [Tool(
name='send_messages',
description='Send a message to a one or multiple agents in parallel.',
properties={
"messages": {
"type": "array",
"items": {
"type": "object",
"properties": {
"recipient": {
"type": "string",
"description": "The name of the agent to send the message to."
},
"content": {
"type": "string",
"description": "The content of the message to send."
}
},
"required": ["recipient", "content"]
},
"description": "Array of messages to send to different agents.",
"minItems": 1
}
},
required=["messages"]
),
Tool(
name="get_current_date",
description="Get the date of today in US format.",
properties={},
required=[]
)]


def __init__(self, options: SupervisorAgentOptions):
options.name = options.supervisor.name
options.description = options.supervisor.description
super().__init__(options)
self.supervisor:Union[AnthropicAgent,BedrockLLMAgent] = options.supervisor

self.team = options.team
self.supervisor_type = SupervisorType.BEDROCK.value if isinstance(self.supervisor, BedrockLLMAgent) else SupervisorType.ANTHROPIC.value
self.supervisor_type = AgentProviderType.BEDROCK.value if isinstance(self.supervisor, BedrockLLMAgent) else AgentProviderType.ANTHROPIC.value
self.supervisor_tools:Tools = Tools([Tool(
name='send_messages',
description='Send a message to a one or multiple agents in parallel.',
properties={
"messages": {
"type": "array",
"items": {
"type": "object",
"properties": {
"recipient": {
"type": "string",
"description": "The name of the agent to send the message to."
},
"content": {
"type": "string",
"description": "The content of the message to send."
}
},
"required": ["recipient", "content"]
},
"description": "Array of messages to send to different agents.",
"minItems": 1
}
},
required=["messages"],
func=self.send_messages
)])

if not self.supervisor.tool_config:
if options.extra_tools:
if isinstance(options.extra_tools, Tools):
self.supervisor_tools.tools.extend(options.extra_tools.tools)
elif isinstance(options.extra_tools, list):
self.supervisor_tools.tools.extend(options.extra_tools)
else:
raise RuntimeError('extra_tools must be a Tools object or a list of Tool objects.')
self.supervisor.tool_config = {
'tool': [tool.to_bedrock_format() if self.supervisor_type == SupervisorType.BEDROCK.value else tool.to_claude_format() for tool in SupervisorAgent.supervisor_tools],
'tool': self.supervisor_tools,
'toolMaxRecursions': 40,
'useToolHandler': self.supervisor_tool_handler
}
else:
raise RuntimeError('Supervisor tool config already set. Please do not set it manually.')
raise RuntimeError('supervisor tools are set by SupervisorAgent class. To add more tools, use the extra_tools parameter in the SupervisorAgentOptions class.')

self.user_id = ''
self.session_id = ''
self.storage = options.storage or InMemoryChatStorage()
self.trace = options.trace


tools_str = ",".join(f"{tool.name}:{tool.func_description}" for tool in SupervisorAgent.supervisor_tools)
tools_str = ",".join(f"{tool.name}:{tool.func_description}" for tool in self.supervisor_tools.tools)
agent_list_str = "\n".join(
f"{agent.name}: {agent.description}"
for agent in self.team
Expand Down Expand Up @@ -179,32 +168,6 @@ def send_message(self, agent:Agent, content: str, user_id: str, session_id: str,
if self.trace else None
return f"{agent.name}: {response.content[0].get('text')}"

# async def send_messages(self, messages: list[dict[str, str]]):
# """Process all messages for all agents in parallel."""
# with ThreadPoolExecutor(max_workers=5) as executor:
# futures = []
# for agent in self.team:
# for message in messages:
# if agent.name == message.get('recipient'):
# future = executor.submit(
# self.send_message,
# agent,
# message.get('content'),
# self.user_id,
# self.session_id,
# {}
# )
# futures.append(future)
# responses = []

# for future in as_completed(futures):
# response = future.result()
# responses.append(response)

# # Wait for all tasks to complete
# return ''.join(response for response in responses)


async def send_messages(self, messages: list[dict[str, str]]):
"""Process all messages for all agents in parallel."""
tasks = []
Expand Down Expand Up @@ -232,86 +195,6 @@ async def send_messages(self, messages: list[dict[str, str]]):
return ''.join(responses)
return ''


async def get_current_date(self):
print('Using Tool : get_current_date')
return datetime.now(timezone.utc).strftime('%m/%d/%Y') # from datetime import datetime, timezone



async def supervisor_tool_handler(self, response: Any, conversation: list[dict[str, Any]],) -> Any:
if not response.content:
raise ValueError("No content blocks in response")

tool_results = []
content_blocks = response.content

for block in content_blocks:
# Determine if it's a tool use block based on platform
tool_use_block = self._get_tool_use_block(block)
if not tool_use_block:
continue

tool_name = (
tool_use_block.get("name")
if self.supervisor_type == SupervisorType.BEDROCK.value
else tool_use_block.name
)

tool_id = (
tool_use_block.get("toolUseId")
if self.supervisor_type == SupervisorType.BEDROCK.value
else tool_use_block.id
)

# Get input based on platform
input_data = (
tool_use_block.get("input", {})
if self.supervisor_type == SupervisorType.BEDROCK.value
else tool_use_block.input
)

# Process the tool use
result = await self._process_tool(tool_name, input_data)

# Create tool result
tool_result = ToolResult(tool_id, result)

# Format according to platform
formatted_result = (
tool_result.to_bedrock_format()
if self.supervisor_type == SupervisorType.BEDROCK.value
else tool_result.to_anthropic_format()
)

tool_results.append(formatted_result)

# Create and return appropriate message format
if self.supervisor_type == SupervisorType.BEDROCK.value:
return ConversationMessage(
role=ParticipantRole.USER.value,
content=tool_results
)
else:
return {
'role': ParticipantRole.USER.value,
'content': tool_results
}


async def _process_tool(self, tool_name: str, input_data: dict) -> Any:
"""Process tool use based on tool name."""
if tool_name == "send_messages":
return await self.send_messages(
input_data.get('messages')
)
elif tool_name == "get_current_date":
return await self.get_current_date()
else:
error_msg = f"Unknown tool use name: {tool_name}"
Logger.error(error_msg)
return error_msg

async def process_request(
self,
input_text: str,
Expand All @@ -336,13 +219,8 @@ async def process_request(
# update prompt with agents memory
self.supervisor.set_system_prompt(self.prompt_template.replace('{AGENTS_MEMORY}', agents_memory))
# call the supervisor
response = await self.supervisor.process_request(input_text, user_id, session_id, chat_history, additional_params)
return response

def _get_tool_use_block(self, block: dict) -> Union[dict, None]:
"""Extract tool use block based on platform format."""
if self.supervisor_type == SupervisorType.BEDROCK.value and "toolUse" in block:
return block["toolUse"]
elif self.supervisor_type == SupervisorType.ANTHROPIC.value and block.type == "tool_use":
return block
return None
try:
response = await self.supervisor.process_request(input_text, user_id, session_id, chat_history, additional_params)
return response
except Exception as e:
Logger.error(f"Error in supervisor: {e}")
4 changes: 2 additions & 2 deletions examples/supervisor-mode/weather_tool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import requests
from requests.exceptions import RequestException
from typing import List, Dict, Any
from typing import Any
from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole


Expand Down Expand Up @@ -45,7 +45,7 @@
"""


async def weather_tool_handler(response: ConversationMessage, conversation: List[Dict[str, Any]]) -> ConversationMessage:
async def weather_tool_handler(response: ConversationMessage, conversation: list[dict[str, Any]]) -> ConversationMessage:
response_content_blocks = response.content

# Initialize an empty list of tool results
Expand Down

0 comments on commit 7b46e0f

Please sign in to comment.