Skip to content

Commit

Permalink
updated supervisor_agent code
Browse files Browse the repository at this point in the history
  • Loading branch information
brnaba-aws committed Dec 18, 2024
1 parent 24b1ada commit 77d420b
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions examples/supervisor-mode/supervisor_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from typing import Optional, Any, AsyncIterable, Union
from dataclasses import dataclass, field
from enum import Enum
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
Expand All @@ -15,23 +16,16 @@ class SupervisorType(Enum):
BEDROCK = "BEDROCK"
ANTHROPIC = "ANTHROPIC"

@dataclass
class SupervisorAgentOptions(AgentOptions):
def __init__(
self,
supervisor:Agent,
team: list[Agent],
storage: Optional[ChatStorage] = None,
trace: Optional[bool] = None,
**kwargs,
):
kwargs['name'] = supervisor.name
kwargs['description'] = supervisor.description
super().__init__(**kwargs)
self.supervisor:Union[AnthropicAgent,BedrockLLMAgent] = supervisor
self.team: list[Agent] = team
self.storage = storage or InMemoryChatStorage()
self.trace = trace or False
supervisor:Agent = None
team: list[Agent] = field(default_factory=list)
storage: Optional[ChatStorage] = None
trace: Optional[bool] = None

# Hide inherited fields
name: str = field(init=False)
description: str = field(init=False)

class SupervisorAgent(Agent):

Expand Down Expand Up @@ -84,8 +78,11 @@ class SupervisorAgent(Agent):


def __init__(self, options: SupervisorAgentOptions):
options.name = options.supervisor.name
options.description = options.supervisor.description
super().__init__(options)
self.supervisor:Union[AnthropicAgent,BedrockLLMAgent] = options.supervisor

self.team = options.team
self.supervisor_type = SupervisorType.BEDROCK.value if isinstance(self.supervisor, BedrockLLMAgent) else SupervisorType.ANTHROPIC.value
if not self.supervisor.tool_config:
Expand All @@ -99,7 +96,7 @@ def __init__(self, options: SupervisorAgentOptions):

self.user_id = ''
self.session_id = ''
self.storage = options.storage
self.storage = options.storage or InMemoryChatStorage()
self.trace = options.trace


Expand Down

0 comments on commit 77d420b

Please sign in to comment.