Add ParallelAgent class #103
Open

carr324 wants to merge 12 commits into awslabs:main from carr324:parallel-agent
+672 −0
Changes from 4 commits
Commits (12, all by carr324):
019aab6 adding ParallelAgent code
42f064e Update parallel_agent.py after initial PR review
49c55e9 Removing unneeded BedrockLLM import
52d9d74 Merge branch 'awslabs:main' into parallel-agent
8e8eb85 Merge branch 'awslabs:main' into parallel-agent
952f412 Create init
5c81610 Adding example workflow for parallel agent
488d3e8 Delete examples/parallel-chain-agent/init
bbc9d8e Merge branch 'awslabs:main' into parallel-agent
8980eb6 Merge branch 'awslabs:main' into parallel-agent
62d4922 Updated parallel_agent agent with concurrent.futures
0dee77c Updated parallel_agent example with concurrent.futures
python/src/multi_agent_orchestrator/agents/parallel_agent.py (130 additions, 0 deletions)
import asyncio
from typing import Any, AsyncIterable

from multi_agent_orchestrator.agents import (
    Agent,
    AgentOptions,
)
from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole
from multi_agent_orchestrator.utils.logger import Logger


# Extend AgentOptions for ParallelAgent class:
class ParallelAgentOptions(AgentOptions):
    def __init__(
        self,
        agents: list[Agent],
        default_output: str | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.agents = agents
        self.default_output = default_output


# Create a new custom agent that allows for parallel processing:
class ParallelAgent(Agent):
    def __init__(self, options: ParallelAgentOptions):
        super().__init__(options)
        self.agents = options.agents
        self.default_output = (
            options.default_output or "No output generated from the ParallelAgent."
        )
        if len(self.agents) == 0:
            raise ValueError("ParallelAgent requires at least 1 agent to initiate!")

    async def _get_llm_response(
        self,
        agent: Agent,
        input_text: str,
        user_id: str,
        session_id: str,
        chat_history: list[ConversationMessage],
        additional_params: dict[str, str] | None = None,
    ) -> ConversationMessage:
        # Get response from LLM agent:
        final_response: ConversationMessage | AsyncIterable[Any]

        try:
            response = await agent.process_request(
                input_text, user_id, session_id, chat_history, additional_params
            )
            if self.is_conversation_message(response):
                if response.content and "text" in response.content[0]:
                    final_response = response
                else:
                    Logger.warn(f"Agent {agent.name} returned no text content.")
                    return self.create_default_response()
            elif self.is_async_iterable(response):
                Logger.warn("Streaming is not allowed for ParallelAgents!")
                return self.create_default_response()
            else:
                Logger.warn(f"Agent {agent.name} returned an invalid response type.")
                return self.create_default_response()

        except Exception as error:
            Logger.error(
                f"Error processing request with agent {agent.name}: {str(error)}"
            )
            # Raise a real exception so asyncio.gather() propagates the failure:
            raise RuntimeError(
                f"Error processing request with agent {agent.name}: {str(error)}"
            ) from error

        return final_response

    async def process_request(
        self,
        input_text: str,
        user_id: str,
        session_id: str,
        chat_history: list[ConversationMessage],
        additional_params: dict[str, str] | None = None,
    ) -> ConversationMessage:
        # Create tasks for all LLMs to run in parallel:
        tasks = []
        for agent in self.agents:
            tasks.append(
                self._get_llm_response(
                    agent,
                    input_text,
                    user_id,
                    session_id,
                    chat_history,
                    additional_params,
                )
            )

        # Run all tasks concurrently and wait for results:
        responses = await asyncio.gather(*tasks)

        # Create dictionary of responses:
        response_dict = {
            agent.name: response.content[0]["text"]
            for agent, response in zip(self.agents, responses)
            if response  # Only include non-empty responses!
        }

        # Convert dictionary to string representation:
        combined_response = str(response_dict)

        return ConversationMessage(
            role=ParticipantRole.ASSISTANT.value,
            content=[{"text": combined_response}],
        )

    @staticmethod
    def is_async_iterable(obj: Any) -> bool:
        return hasattr(obj, "__aiter__")

    @staticmethod
    def is_conversation_message(response: Any) -> bool:
        return (
            isinstance(response, ConversationMessage)
            and hasattr(response, "role")
            and hasattr(response, "content")
            and isinstance(response.content, list)
        )

    def create_default_response(self) -> ConversationMessage:
        return ConversationMessage(
            role=ParticipantRole.ASSISTANT.value,
            content=[{"text": self.default_output}],
        )
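For reference, a minimal usage sketch of the new class (the BedrockLLMAgent sub-agents, their names, the prompt, and the import path are illustrative assumptions, not part of this PR):

import asyncio

from multi_agent_orchestrator.agents import BedrockLLMAgent, BedrockLLMAgentOptions
from multi_agent_orchestrator.agents.parallel_agent import (
    ParallelAgent,
    ParallelAgentOptions,
)

# Two independent sub-agents; ParallelAgent fans the same input out to both.
summarizer = BedrockLLMAgent(BedrockLLMAgentOptions(
    name="summarizer",
    description="Summarizes the user's request.",
))
critic = BedrockLLMAgent(BedrockLLMAgentOptions(
    name="critic",
    description="Critiques the user's request.",
))

parallel = ParallelAgent(ParallelAgentOptions(
    name="parallel",
    description="Runs sub-agents concurrently and merges their replies.",
    agents=[summarizer, critic],
))

async def main():
    result = await parallel.process_request(
        "Explain the trade-offs of microservices.",
        user_id="user-1",
        session_id="session-1",
        chat_history=[],
    )
    # content[0]["text"] holds the str() of a dict keyed by sub-agent name.
    print(result.content[0]["text"])

asyncio.run(main())

Each sub-agent receives the same input and chat history; the merged reply is the string form of a dict keyed by agent name, as built in process_request above.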
Review discussion:

Reviewer: Why add another method? Can't you just call agent.process_request()?

carr324: Wanted to include some of the ChainAgent logic (its lines 41-65) here, in the async function that gets run for each individual agent within the ParallelAgent, which seemed easier and cleaner as a new internal method. If we think it's unnecessary, fine to revise or remove.

Reviewer: I see, ok. Well, I'd suggest changing the method name from _get_llm_response to agent_process_request(). The framework is not only about LLMs.
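Applied to the diff above, the suggested rename would touch the method definition and its call site in process_request, roughly as follows (a sketch of the reviewer's proposal, not code from this PR):

    # Definition, previously _get_llm_response:
    async def agent_process_request(
        self,
        agent: Agent,
        input_text: str,
        user_id: str,
        session_id: str,
        chat_history: list[ConversationMessage],
        additional_params: dict[str, str] | None = None,
    ) -> ConversationMessage:
        ...

    # Call site inside process_request:
    tasks.append(
        self.agent_process_request(
            agent, input_text, user_id, session_id, chat_history, additional_params
        )
    )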