From 77d420bdcd8ff49bfafc38969d150bb273fbe663 Mon Sep 17 00:00:00 2001 From: Anthony Bernabeu <64135631+brnaba-aws@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:15:48 +0100 Subject: [PATCH] updated supervisor_agent code --- examples/supervisor-mode/supervisor_agent.py | 29 +++++++++----------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/examples/supervisor-mode/supervisor_agent.py b/examples/supervisor-mode/supervisor_agent.py index 958f816..b2eb9cf 100644 --- a/examples/supervisor-mode/supervisor_agent.py +++ b/examples/supervisor-mode/supervisor_agent.py @@ -1,5 +1,6 @@ from typing import Optional, Any, AsyncIterable, Union +from dataclasses import dataclass, field from enum import Enum from concurrent.futures import ThreadPoolExecutor, as_completed import asyncio @@ -15,23 +16,16 @@ class SupervisorType(Enum): BEDROCK = "BEDROCK" ANTHROPIC = "ANTHROPIC" +@dataclass class SupervisorAgentOptions(AgentOptions): - def __init__( - self, - supervisor:Agent, - team: list[Agent], - storage: Optional[ChatStorage] = None, - trace: Optional[bool] = None, - **kwargs, - ): - kwargs['name'] = supervisor.name - kwargs['description'] = supervisor.description - super().__init__(**kwargs) - self.supervisor:Union[AnthropicAgent,BedrockLLMAgent] = supervisor - self.team: list[Agent] = team - self.storage = storage or InMemoryChatStorage() - self.trace = trace or False + supervisor:Agent = None + team: list[Agent] = field(default_factory=list) + storage: Optional[ChatStorage] = None + trace: Optional[bool] = None + # Hide inherited fields + name: str = field(init=False) + description: str = field(init=False) class SupervisorAgent(Agent): @@ -84,8 +78,11 @@ class SupervisorAgent(Agent): def __init__(self, options: SupervisorAgentOptions): + options.name = options.supervisor.name + options.description = options.supervisor.description super().__init__(options) self.supervisor:Union[AnthropicAgent,BedrockLLMAgent] = options.supervisor + self.team = options.team self.supervisor_type = SupervisorType.BEDROCK.value if isinstance(self.supervisor, BedrockLLMAgent) else SupervisorType.ANTHROPIC.value if not self.supervisor.tool_config: @@ -99,7 +96,7 @@ def __init__(self, options: SupervisorAgentOptions): self.user_id = '' self.session_id = '' - self.storage = options.storage + self.storage = options.storage or InMemoryChatStorage() self.trace = options.trace