diff --git a/examples/supervisor-mode/main.py b/examples/supervisor-mode/main.py index 8c8f4fc..eeeb156 100644 --- a/examples/supervisor-mode/main.py +++ b/examples/supervisor-mode/main.py @@ -1,3 +1,8 @@ +from typing import Any +import sys, asyncio, uuid +import os +from datetime import datetime, timezone +from multi_agent_orchestrator.utils import Logger from multi_agent_orchestrator.orchestrator import MultiAgentOrchestrator, OrchestratorConfig from multi_agent_orchestrator.agents import ( BedrockLLMAgent, BedrockLLMAgentOptions, @@ -9,9 +14,8 @@ from multi_agent_orchestrator.classifiers import ClassifierResult from multi_agent_orchestrator.types import ConversationMessage from multi_agent_orchestrator.storage import DynamoDbChatStorage -from typing import Any -import sys, asyncio, uuid -import os +from multi_agent_orchestrator.utils import Tool + from weather_tool import weather_tool_description, weather_tool_handler, weather_tool_prompt from supervisor_agent import SupervisorAgent, SupervisorAgentOptions from dotenv import load_dotenv @@ -52,6 +56,7 @@ weather_agent.set_system_prompt(weather_tool_prompt) + health_agent = BedrockLLMAgent(BedrockLLMAgentOptions( name="HealthAgent", description="You are a health agent. You are responsible for answering questions about health. You are only allowed to answer questions about health. You are not allowed to answer questions about anything else.", @@ -74,7 +79,7 @@ api_key=os.getenv('ANTHROPIC_API_KEY', None), name="SupervisorAgent", description="You are a supervisor agent. You are responsible for managing the flow of the conversation. You are only allowed to manage the flow of the conversation. You are not allowed to answer questions about anything else.", - model_id="claude-3-5-sonnet-latest" + model_id="claude-3-5-sonnet-latest", )) # supervisor_agent = BedrockLLMAgent(BedrockLLMAgentOptions( @@ -83,6 +88,13 @@ # description="You are a supervisor agent. You are responsible for managing the flow of the conversation. You are only allowed to manage the flow of the conversation. You are not allowed to answer questions about anything else.", # )) +async def get_current_date(): + """ + Get the current date in US format. + """ + Logger.info('Using Tool : get_current_date') + return datetime.now(timezone.utc).strftime('%m/%d/%Y') # from datetime import datetime, timezone + supervisor = SupervisorAgent( SupervisorAgentOptions( @@ -92,7 +104,11 @@ table_name=os.getenv('DYNAMODB_CHAT_HISTORY_TABLE_NAME', None), region='us-east-1' ), - trace=True + trace=True, + extra_tools=[Tool( + name="get_current_date", + func=get_current_date, + )] )) async def handle_request(_orchestrator: MultiAgentOrchestrator, _user_input:str, _user_id:str, _session_id:str): diff --git a/examples/supervisor-mode/supervisor_agent.py b/examples/supervisor-mode/supervisor_agent.py index a219b5d..ad8d476 100644 --- a/examples/supervisor-mode/supervisor_agent.py +++ b/examples/supervisor-mode/supervisor_agent.py @@ -1,20 +1,11 @@ from typing import Optional, Any, AsyncIterable, Union from dataclasses import dataclass, field -from enum import Enum -from concurrent.futures import ThreadPoolExecutor, as_completed import asyncio from multi_agent_orchestrator.agents import Agent, AgentOptions, BedrockLLMAgent, AnthropicAgent -from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole -from multi_agent_orchestrator.utils import Logger +from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole, AgentProviderType +from multi_agent_orchestrator.utils import Logger, Tools, Tool from multi_agent_orchestrator.storage import ChatStorage, InMemoryChatStorage -from tool import Tool, ToolResult -from datetime import datetime, timezone - - -class SupervisorType(Enum): - BEDROCK = "BEDROCK" - ANTHROPIC = "ANTHROPIC" @dataclass class SupervisorAgentOptions(AgentOptions): @@ -22,6 +13,7 @@ class SupervisorAgentOptions(AgentOptions): team: list[Agent] = field(default_factory=list) storage: Optional[ChatStorage] = None trace: Optional[bool] = None + extra_tools: Optional[Union[Tools, list[Tool]]] = None # allow for extra tools # Hide inherited fields name: str = field(init=False) @@ -46,46 +38,9 @@ class SupervisorAgent(Agent): __init__(self, options: SupervisorAgentOptions): Initializes a SupervisorAgent instance. send_message(self, agent: Agent, content: str, user_id: str, session_id: str, additionalParameters: dict) -> str: Sends a message to an agent. send_messages(self, messages: list[dict[str, str]]) -> str: Sends messages to multiple agents in parallel. - get_current_date(self) -> str: Gets the current date. - supervisor_tool_handler(self, response: Any, conversation: list[dict[str, Any]]) -> Any: Handles the response from a tool. - _process_tool(self, tool_name: str, input_data: dict) -> Any: Processes a tool based on its name. process_request(self, input_text: str, user_id: str, session_id: str, chat_history: list[ConversationMessage], additional_params: Optional[dict[str, str]] = None) -> Union[ConversationMessage, AsyncIterable[Any]]: Processes a user request. """ - supervisor_tools:list[Tool] = [Tool( - name='send_messages', - description='Send a message to a one or multiple agents in parallel.', - properties={ - "messages": { - "type": "array", - "items": { - "type": "object", - "properties": { - "recipient": { - "type": "string", - "description": "The name of the agent to send the message to." - }, - "content": { - "type": "string", - "description": "The content of the message to send." - } - }, - "required": ["recipient", "content"] - }, - "description": "Array of messages to send to different agents.", - "minItems": 1 - } - }, - required=["messages"] - ), - Tool( - name="get_current_date", - description="Get the date of today in US format.", - properties={}, - required=[] - )] - - def __init__(self, options: SupervisorAgentOptions): options.name = options.supervisor.name options.description = options.supervisor.description @@ -93,15 +48,49 @@ def __init__(self, options: SupervisorAgentOptions): self.supervisor:Union[AnthropicAgent,BedrockLLMAgent] = options.supervisor self.team = options.team - self.supervisor_type = SupervisorType.BEDROCK.value if isinstance(self.supervisor, BedrockLLMAgent) else SupervisorType.ANTHROPIC.value + self.supervisor_type = AgentProviderType.BEDROCK.value if isinstance(self.supervisor, BedrockLLMAgent) else AgentProviderType.ANTHROPIC.value + self.supervisor_tools:Tools = Tools([Tool( + name='send_messages', + description='Send a message to a one or multiple agents in parallel.', + properties={ + "messages": { + "type": "array", + "items": { + "type": "object", + "properties": { + "recipient": { + "type": "string", + "description": "The name of the agent to send the message to." + }, + "content": { + "type": "string", + "description": "The content of the message to send." + } + }, + "required": ["recipient", "content"] + }, + "description": "Array of messages to send to different agents.", + "minItems": 1 + } + }, + required=["messages"], + func=self.send_messages + )]) + if not self.supervisor.tool_config: + if options.extra_tools: + if isinstance(options.extra_tools, Tools): + self.supervisor_tools.tools.extend(options.extra_tools.tools) + elif isinstance(options.extra_tools, list): + self.supervisor_tools.tools.extend(options.extra_tools) + else: + raise RuntimeError('extra_tools must be a Tools object or a list of Tool objects.') self.supervisor.tool_config = { - 'tool': [tool.to_bedrock_format() if self.supervisor_type == SupervisorType.BEDROCK.value else tool.to_claude_format() for tool in SupervisorAgent.supervisor_tools], + 'tool': self.supervisor_tools, 'toolMaxRecursions': 40, - 'useToolHandler': self.supervisor_tool_handler } else: - raise RuntimeError('Supervisor tool config already set. Please do not set it manually.') + raise RuntimeError('supervisor tools are set by SupervisorAgent class. To add more tools, use the extra_tools parameter in the SupervisorAgentOptions class.') self.user_id = '' self.session_id = '' @@ -109,7 +98,7 @@ def __init__(self, options: SupervisorAgentOptions): self.trace = options.trace - tools_str = ",".join(f"{tool.name}:{tool.func_description}" for tool in SupervisorAgent.supervisor_tools) + tools_str = ",".join(f"{tool.name}:{tool.func_description}" for tool in self.supervisor_tools.tools) agent_list_str = "\n".join( f"{agent.name}: {agent.description}" for agent in self.team @@ -179,32 +168,6 @@ def send_message(self, agent:Agent, content: str, user_id: str, session_id: str, if self.trace else None return f"{agent.name}: {response.content[0].get('text')}" - # async def send_messages(self, messages: list[dict[str, str]]): - # """Process all messages for all agents in parallel.""" - # with ThreadPoolExecutor(max_workers=5) as executor: - # futures = [] - # for agent in self.team: - # for message in messages: - # if agent.name == message.get('recipient'): - # future = executor.submit( - # self.send_message, - # agent, - # message.get('content'), - # self.user_id, - # self.session_id, - # {} - # ) - # futures.append(future) - # responses = [] - - # for future in as_completed(futures): - # response = future.result() - # responses.append(response) - - # # Wait for all tasks to complete - # return ''.join(response for response in responses) - - async def send_messages(self, messages: list[dict[str, str]]): """Process all messages for all agents in parallel.""" tasks = [] @@ -232,86 +195,6 @@ async def send_messages(self, messages: list[dict[str, str]]): return ''.join(responses) return '' - - async def get_current_date(self): - print('Using Tool : get_current_date') - return datetime.now(timezone.utc).strftime('%m/%d/%Y') # from datetime import datetime, timezone - - - - async def supervisor_tool_handler(self, response: Any, conversation: list[dict[str, Any]],) -> Any: - if not response.content: - raise ValueError("No content blocks in response") - - tool_results = [] - content_blocks = response.content - - for block in content_blocks: - # Determine if it's a tool use block based on platform - tool_use_block = self._get_tool_use_block(block) - if not tool_use_block: - continue - - tool_name = ( - tool_use_block.get("name") - if self.supervisor_type == SupervisorType.BEDROCK.value - else tool_use_block.name - ) - - tool_id = ( - tool_use_block.get("toolUseId") - if self.supervisor_type == SupervisorType.BEDROCK.value - else tool_use_block.id - ) - - # Get input based on platform - input_data = ( - tool_use_block.get("input", {}) - if self.supervisor_type == SupervisorType.BEDROCK.value - else tool_use_block.input - ) - - # Process the tool use - result = await self._process_tool(tool_name, input_data) - - # Create tool result - tool_result = ToolResult(tool_id, result) - - # Format according to platform - formatted_result = ( - tool_result.to_bedrock_format() - if self.supervisor_type == SupervisorType.BEDROCK.value - else tool_result.to_anthropic_format() - ) - - tool_results.append(formatted_result) - - # Create and return appropriate message format - if self.supervisor_type == SupervisorType.BEDROCK.value: - return ConversationMessage( - role=ParticipantRole.USER.value, - content=tool_results - ) - else: - return { - 'role': ParticipantRole.USER.value, - 'content': tool_results - } - - - async def _process_tool(self, tool_name: str, input_data: dict) -> Any: - """Process tool use based on tool name.""" - if tool_name == "send_messages": - return await self.send_messages( - input_data.get('messages') - ) - elif tool_name == "get_current_date": - return await self.get_current_date() - else: - error_msg = f"Unknown tool use name: {tool_name}" - Logger.error(error_msg) - return error_msg - async def process_request( self, input_text: str, @@ -336,13 +219,8 @@ async def process_request( # update prompt with agents memory self.supervisor.set_system_prompt(self.prompt_template.replace('{AGENTS_MEMORY}', agents_memory)) # call the supervisor - response = await self.supervisor.process_request(input_text, user_id, session_id, chat_history, additional_params) - return response - - def _get_tool_use_block(self, block: dict) -> Union[dict, None]: - """Extract tool use block based on platform format.""" - if self.supervisor_type == SupervisorType.BEDROCK.value and "toolUse" in block: - return block["toolUse"] - elif self.supervisor_type == SupervisorType.ANTHROPIC.value and block.type == "tool_use": - return block - return None \ No newline at end of file + try: + response = await self.supervisor.process_request(input_text, user_id, session_id, chat_history, additional_params) + return response + except Exception as e: + Logger.error(f"Error in supervisor: {e}") diff --git a/examples/supervisor-mode/weather_tool.py b/examples/supervisor-mode/weather_tool.py index c7ab9b5..c50761a 100644 --- a/examples/supervisor-mode/weather_tool.py +++ b/examples/supervisor-mode/weather_tool.py @@ -1,6 +1,6 @@ import requests from requests.exceptions import RequestException -from typing import List, Dict, Any +from typing import Any from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole @@ -45,7 +45,7 @@ """ -async def weather_tool_handler(response: ConversationMessage, conversation: List[Dict[str, Any]]) -> ConversationMessage: +async def weather_tool_handler(response: ConversationMessage, conversation: list[dict[str, Any]]) -> ConversationMessage: response_content_blocks = response.content # Initialize an empty list of tool results