Skip to content

Commit

Permalink
updated supervisor Agent example
Browse files Browse the repository at this point in the history
  • Loading branch information
brnaba-aws committed Dec 26, 2024
1 parent b618fdc commit 8811d7c
Showing 1 changed file with 176 additions and 128 deletions.
304 changes: 176 additions & 128 deletions examples/supervisor-mode/supervisor_agent.py
Original file line number Diff line number Diff line change
@@ -1,110 +1,106 @@

from typing import Optional, Any, AsyncIterable, Union
from typing import Optional, Any, AsyncIterable, Union, TypeVar
from dataclasses import dataclass, field
import asyncio
from multi_agent_orchestrator.agents import Agent, AgentOptions, BedrockLLMAgent, AnthropicAgent
from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole, AgentProviderType
from multi_agent_orchestrator.utils import Logger, Tools, Tool
from multi_agent_orchestrator.storage import ChatStorage, InMemoryChatStorage

# T = TypeVar('T', bound='SupervisorAgent')

@dataclass
class SupervisorAgentOptions(AgentOptions):
supervisor:Agent = None
supervisor: Agent = None
team: list[Agent] = field(default_factory=list)
storage: Optional[ChatStorage] = None
trace: Optional[bool] = None
extra_tools: Optional[Union[Tools, list[Tool]]] = None # allow for extra tools

extra_tools: Optional[Union[Tools, list[Tool]]] = None
# Hide inherited fields
name: str = field(init=False)
description: str = field(init=False)

def validate(self) -> None:
if not isinstance(self.supervisor, (BedrockLLMAgent, AnthropicAgent)):
raise ValueError("Supervisor must be BedrockLLMAgent or AnthropicAgent")
if self.extra_tools and not isinstance(self.extra_tools, (Tools, list)):
raise ValueError('extra_tools must be Tools object or list of Tool objects')
if self.supervisor.tool_config:
raise ValueError('Supervisor tools are managed by SupervisorAgent. Use extra_tools for additional tools.')

class SupervisorAgent(Agent):
"""Supervisor agent that orchestrates interactions between multiple agents.
Manages communication, task delegation, and response aggregation between a team of agents.
Supports parallel processing of messages and maintains conversation history.
"""
SupervisorAgent class.
This class represents a supervisor agent that interacts with other agents in an environment. It inherits from the Agent class.
Attributes:
supervisor_tools (list[Tool]): List of tools available to the supervisor agent.
team (list[Agent]): List of agents in the environment.
supervisor_type (str): Type of supervisor agent (BEDROCK or ANTHROPIC).
user_id (str): User ID.
session_id (str): Session ID.
storage (ChatStorage): Chat storage for storing conversation history.
trace (bool): Flag indicating whether to enable tracing.
Methods:
__init__(self, options: SupervisorAgentOptions): Initializes a SupervisorAgent instance.
send_message(self, agent: Agent, content: str, user_id: str, session_id: str, additionalParameters: dict) -> str: Sends a message to an agent.
send_messages(self, messages: list[dict[str, str]]) -> str: Sends messages to multiple agents in parallel.
process_request(self, input_text: str, user_id: str, session_id: str, chat_history: list[ConversationMessage], additional_params: Optional[dict[str, str]] = None) -> Union[ConversationMessage, AsyncIterable[Any]]: Processes a user request.
"""

DEFAULT_TOOL_MAX_RECURSIONS = 40

def __init__(self, options: SupervisorAgentOptions):
options.validate()
options.name = options.supervisor.name
options.description = options.supervisor.description
super().__init__(options)
self.supervisor:Union[AnthropicAgent,BedrockLLMAgent] = options.supervisor

self.supervisor: Union[AnthropicAgent, BedrockLLMAgent] = options.supervisor
self.team = options.team
self.supervisor_type = AgentProviderType.BEDROCK.value if isinstance(self.supervisor, BedrockLLMAgent) else AgentProviderType.ANTHROPIC.value
self.supervisor_tools:Tools = Tools([Tool(
self.storage = options.storage or InMemoryChatStorage()
self.trace = options.trace
self.user_id = ''
self.session_id = ''

self._configure_supervisor_tools(options.extra_tools)
self._configure_prompt()

def _configure_supervisor_tools(self, extra_tools: Optional[Union[Tools, list[Tool]]]) -> None:
"""Configure the tools available to the supervisor."""
self.supervisor_tools = Tools([Tool(
name='send_messages',
description='Send a message to a one or multiple agents in parallel.',
description='Send messages to multiple agents in parallel.',
properties={
"messages": {
"messages": {
"type": "array",
"items": {
"type": "object",
"properties": {
"recipient": {
"type": "string",
"description": "The name of the agent to send the message to."
},
"content": {
"type": "string",
"description": "The content of the message to send."
}
"recipient": {
"type": "string",
"description": "Agent name to send message to."
},
"content": {
"type": "string",
"description": "Message content."
}
},
"required": ["recipient", "content"]
},
"description": "Array of messages to send to different agents.",
"description": "Array of messages for different agents.",
"minItems": 1
}
},
}
},
required=["messages"],
func=self.send_messages
)])

if not self.supervisor.tool_config:
if options.extra_tools:
if isinstance(options.extra_tools, Tools):
self.supervisor_tools.tools.extend(options.extra_tools.tools)
elif isinstance(options.extra_tools, list):
self.supervisor_tools.tools.extend(options.extra_tools)
else:
raise RuntimeError('extra_tools must be a Tools object or a list of Tool objects.')
self.supervisor.tool_config = {
'tool': self.supervisor_tools,
'toolMaxRecursions': 40,
}
else:
raise RuntimeError('supervisor tools are set by SupervisorAgent class. To add more tools, use the extra_tools parameter in the SupervisorAgentOptions class.')
if extra_tools:
if isinstance(extra_tools, Tools):
self.supervisor_tools.tools.extend(extra_tools.tools)
else:
self.supervisor_tools.tools.extend(extra_tools)

self.user_id = ''
self.session_id = ''
self.storage = options.storage or InMemoryChatStorage()
self.trace = options.trace
self.supervisor.tool_config = {
'tool': self.supervisor_tools,
'toolMaxRecursions': self.DEFAULT_TOOL_MAX_RECURSIONS,
}

def _configure_prompt(self) -> None:
"""Configure the supervisor's prompt template."""
tools_str = "\n".join(f"{tool.name}:{tool.func_description}"
for tool in self.supervisor_tools.tools)
agent_list_str = "\n".join(f"{agent.name}: {agent.description}"
for agent in self.team)

tools_str = ",".join(f"{tool.name}:{tool.func_description}" for tool in self.supervisor_tools.tools)
agent_list_str = "\n".join(
f"{agent.name}: {agent.description}"
for agent in self.team
)

self.prompt_template: str = f"""\n
self.prompt_template = f"""\n
You are a {self.name}.
{self.description}
Expand All @@ -115,7 +111,7 @@ def __init__(self, options: SupervisorAgentOptions):
Here are the tools you can use:
<tools>
{tools_str}:
{tools_str}
</tools>
When communicating with other agents, including the User, please follow these guidelines:
Expand Down Expand Up @@ -146,54 +142,108 @@ def __init__(self, options: SupervisorAgentOptions):
"""
self.supervisor.set_system_prompt(self.prompt_template)

if isinstance(self.supervisor, BedrockLLMAgent):
Logger.debug("Supervisor is a BedrockLLMAgent")
Logger.debug('converting tool to Bedrock format')
elif isinstance(self.supervisor, AnthropicAgent):
Logger.debug("Supervisor is a AnthropicAgent")
Logger.debug('converting tool to Anthropic format')
else:
Logger.debug(f"Supervisor {self.supervisor.__class__} is not supported")
raise RuntimeError("Supervisor must be a BedrockLLMAgent or AnthropicAgent")


def send_message(self, agent:Agent, content: str, user_id: str, session_id: str, additionalParameters: dict) -> 'str':
Logger.info(f"\n===>>>>> Supervisor sending {agent.name}: {content}")\
if self.trace else None
agent_chat_history = asyncio.run(self.storage.fetch_chat(user_id, session_id, agent.id)) if agent.save_chat else []
response = asyncio.run(agent.process_request(content, user_id, session_id, agent_chat_history, additionalParameters))
asyncio.run (self.storage.save_chat_message(user_id, session_id, agent.id, ConversationMessage(role=ParticipantRole.USER.value, content=[{'text':content}]))) if agent.save_chat else None
asyncio.run(self.storage.save_chat_message(user_id, session_id, agent.id, ConversationMessage(role=ParticipantRole.ASSISTANT.value, content=[{'text':f"{response.content[0].get('text', '')}"}]))) if agent.save_chat else None
Logger.info(f"\n<<<<<===Supervisor received this response from {agent.name}:\n{response.content[0].get('text','')[:500]}...") \
if self.trace else None
return f"{agent.name}: {response.content[0].get('text')}"

async def send_messages(self, messages: list[dict[str, str]]):
"""Process all messages for all agents in parallel."""
tasks = []

# Create tasks for each matching agent/message pair
for agent in self.team:
for message in messages:
if agent.name == message.get('recipient'):
# Wrap the entire send_message call in to_thread
task = asyncio.create_task(
asyncio.to_thread(
self.send_message,
agent,
message.get('content'),
self.user_id,
self.session_id,
{}
)
print(self.supervisor.prompt_template)

def send_message(
self,
agent: Agent,
content: str,
user_id: str,
session_id: str,
additional_params: dict[str, Any]
) -> str:
"""Send a message to a specific agent and process the response."""
try:
if self.trace:
Logger.info(f"\n===>>>>> Supervisor sending {agent.name}: {content}")

agent_chat_history = (
asyncio.run(self.storage.fetch_chat(user_id, session_id, agent.id))
if agent.save_chat else []
)

response = asyncio.run(agent.process_request(
content, user_id, session_id, agent_chat_history, additional_params
))

if agent.save_chat:
asyncio.run(self._save_chat_messages(
user_id, session_id, agent.id, content, response
))

if self.trace:
Logger.info(
f"\n<<<<<===Supervisor received from {agent.name}:\n"
f"{response.content[0].get('text','')[:500]}..."
)

return f"{agent.name}: {response.content[0].get('text', '')}"

except Exception as e:
Logger.error(f"Error in send_message: {e}")
raise e

async def _save_chat_messages(
self,
user_id: str,
session_id: str,
agent_id: str,
content: str,
response: Any
) -> None:
"""Save chat messages to storage."""
await self.storage.save_chat_message(
user_id, session_id, agent_id,
ConversationMessage(
role=ParticipantRole.USER.value,
content=[{'text': content}]
)
)
await self.storage.save_chat_message(
user_id, session_id, agent_id,
ConversationMessage(
role=ParticipantRole.ASSISTANT.value,
content=[{'text': response.content[0].get('text', '')}]
)
)

async def send_messages(self, messages: list[dict[str, str]]) -> str:
"""Process messages for agents in parallel."""
try:
tasks = [
asyncio.create_task(
asyncio.to_thread(
self.send_message,
agent,
message.get('content'),
self.user_id,
self.session_id,
{}
)
tasks.append(task)
)
for agent in self.team
for message in messages
if agent.name == message.get('recipient')
]

if not tasks:
return ''

# Gather and wait for all tasks to complete
if tasks:
responses = await asyncio.gather(*tasks)
return ''.join(responses)
return ''

except Exception as e:
Logger.error(f"Error in send_messages: {e}")
raise e

def _format_agents_memory(self, agents_history: list[ConversationMessage]) -> str:
"""Format agent conversation history."""
return ''.join(
f"{user_msg.role}:{user_msg.content[0].get('text','')}\n"
f"{asst_msg.role}:{asst_msg.content[0].get('text','')}\n"
for user_msg, asst_msg in zip(agents_history[::2], agents_history[1::2])
if self.id not in asst_msg.content[0].get('text', '')
)

async def process_request(
self,
Expand All @@ -203,24 +253,22 @@ async def process_request(
chat_history: list[ConversationMessage],
additional_params: Optional[dict[str, str]] = None
) -> Union[ConversationMessage, AsyncIterable[Any]]:
"""Process a user request through the supervisor agent."""
try:
self.user_id = user_id
self.session_id = session_id

self.user_id = user_id
self.session_id = session_id
agents_history = await self.storage.fetch_all_chats(user_id, session_id)
agents_memory = self._format_agents_memory(agents_history)

# fetch history from all agents (including supervisor)
agents_history = await self.storage.fetch_all_chats(user_id, session_id)
agents_memory = ''.join(
f"{user_msg.role}:{user_msg.content[0].get('text','')}\n"
f"{asst_msg.role}:{asst_msg.content[0].get('text','')}\n"
for user_msg, asst_msg in zip(agents_history[::2], agents_history[1::2])
if self.id not in asst_msg.content[0].get('text', '') # removing supervisor history from agents_memory (already part of chat_history)
)
self.supervisor.set_system_prompt(
self.prompt_template.replace('{AGENTS_MEMORY}', agents_memory)
)

# update prompt with agents memory
self.supervisor.set_system_prompt(self.prompt_template.replace('{AGENTS_MEMORY}', agents_memory))
# call the supervisor
try:
response = await self.supervisor.process_request(input_text, user_id, session_id, chat_history, additional_params)
return response
except Exception as e:
Logger.error(f"Error in supervisor: {e}")
return await self.supervisor.process_request(
input_text, user_id, session_id, chat_history, additional_params
)

except Exception as e:
Logger.error(f"Error in process_request: {e}")
raise e

0 comments on commit 8811d7c

Please sign in to comment.