diff --git a/.gitignore b/.gitignore index bbb3e01a..cf3e7bce 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,11 @@ typescript/*.tgz *aws-exports.json !download.js +docs/*.py +docs/*.py +docs/*.txt +docs/code_execution_env + examples/local-demo/.env typescript/coverage/**/* diff --git a/docs/src/content/docs/agents/built-in/bedrock-translator-agent.mdx b/docs/src/content/docs/agents/built-in/bedrock-translator-agent.mdx index 06123e6d..b4911f53 100644 --- a/docs/src/content/docs/agents/built-in/bedrock-translator-agent.mdx +++ b/docs/src/content/docs/agents/built-in/bedrock-translator-agent.mdx @@ -12,6 +12,7 @@ The `BedrockTranslatorAgent` uses Amazon Bedrock's language models to translate - Allows dynamic setting of source and target languages - Can be used standalone or as part of a [ChainAgent](/multi-agent-orchestrator/agents/built-in/chain-agent) - Configurable inference parameters for fine-tuned control +- Supports both streaming and non-streaming responses ## Creating a Bedrock Translator Agent @@ -29,7 +30,8 @@ import { Tabs, TabItem } from '@astrojs/starlight/components'; const agent = new BedrockTranslatorAgent({ name: 'BasicTranslator', description: 'Translates text to English', - targetLanguage: 'English' + targetLanguage: 'English', + streaming: false // Set to true for streaming responses }); ``` @@ -40,7 +42,8 @@ import { Tabs, TabItem } from '@astrojs/starlight/components'; agent = BedrockTranslatorAgent(BedrockTranslatorAgentOptions( name='BasicTranslator', description='Translates text to English', - target_language='English' + target_language='English', + streaming=False # Set to True for streaming responses )) ``` @@ -62,6 +65,7 @@ For more complex use cases, you can create a BedrockTranslatorAgent with custom targetLanguage: 'German', modelId: BEDROCK_MODEL_ID_CLAUDE_3_SONNET, region: 'us-west-2', + streaming: true, // Enable streaming responses inferenceConfig: { maxTokens: 2000, temperature: 0.1, @@ -85,6 +89,7 @@ For more complex use cases, you can create a BedrockTranslatorAgent with custom target_language='German', model_id=BEDROCK_MODEL_ID_CLAUDE_3_SONNET, region='us-west-2', + streaming=True, # Enable streaming responses inference_config={ 'maxTokens': 2000, 'temperature': 0.1, @@ -98,6 +103,56 @@ For more complex use cases, you can create a BedrockTranslatorAgent with custom +## Streaming Responses + +The `streaming` parameter allows you to choose between receiving the entire translation at once or as a stream of partial responses. When set to `true`, the agent will return an asynchronous iterable of string chunks, which can be useful for real-time display of translations or processing very large texts. + +### Example of using streaming responses: + + + + ```typescript + import { BedrockTranslatorAgent, BedrockTranslatorAgentOptions } from 'multi-agent-orchestrator'; + + const agent = new BedrockTranslatorAgent({ + name: 'StreamingTranslator', + description: 'Translates text with streaming responses', + targetLanguage: 'French', + streaming: true + }); + + async function translateWithStreaming(text: string) { + const response = await agent.processRequest(text, 'user123', 'session456'); + for await (const chunk of response) { + console.log('Partial translation:', chunk); + } + } + + translateWithStreaming("Hello, world!"); + ``` + + + ```python + from multi_agent_orchestrator.agents import BedrockTranslatorAgent, BedrockTranslatorAgentOptions + + agent = BedrockTranslatorAgent(BedrockTranslatorAgentOptions( + name='StreamingTranslator', + description='Translates text with streaming responses', + target_language='French', + streaming=True + )) + + async def translate_with_streaming(text: str): + response = await agent.process_request(text, 'user123', 'session456') + async for chunk in response: + print('Partial translation:', chunk) + + import asyncio + asyncio.run(translate_with_streaming("Hello, world!")) + ``` + + + ## Dynamic Language Setting To set the language during the invocation: @@ -109,7 +164,8 @@ To set the language during the invocation: const translator = new BedrockTranslatorAgent({ name: 'DynamicTranslator', - description: 'Translator with dynamically set languages' + description: 'Translator with dynamically set languages', + streaming: false // Set to true if you want streaming responses }); const orchestrator = new MultiAgentOrchestrator(); @@ -140,7 +196,8 @@ To set the language during the invocation: translator = BedrockTranslatorAgent(BedrockTranslatorAgentOptions( name='DynamicTranslator', - description='Translator with dynamically set languages' + description='Translator with dynamically set languages', + streaming=False # Set to True if you want streaming responses )) orchestrator = MultiAgentOrchestrator() @@ -180,7 +237,8 @@ The `BedrockTranslatorAgent` can be effectively used within a `ChainAgent` for c const translatorToEnglish = new BedrockTranslatorAgent({ name: 'TranslatorToEnglish', description: 'Translates input to English', - targetLanguage: 'English' + targetLanguage: 'English', + streaming: false // Set to true for streaming responses }); // Create a processing agent (e.g., a BedrockLLMAgent) @@ -226,7 +284,8 @@ The `BedrockTranslatorAgent` can be effectively used within a `ChainAgent` for c translator_to_english = BedrockTranslatorAgent(BedrockTranslatorAgentOptions( name='TranslatorToEnglish', description='Translates input to English', - target_language='English' + target_language='English', + streaming=False # Set to True for streaming responses )) # Create a processing agent (e.g., a BedrockLLMAgent) @@ -273,4 +332,4 @@ This setup allows for seamless multilingual processing, where the core logic can --- -By leveraging the `BedrockTranslatorAgent`, you can create sophisticated multilingual applications and workflows, enabling seamless communication and processing across language barriers in your Multi-Agent Orchestrator system. \ No newline at end of file +By leveraging the `BedrockTranslatorAgent`, you can create sophisticated multilingual applications and workflows, enabling seamless communication and processing across language barriers in your Multi-Agent Orchestrator system. The streaming capability allows for real-time translation of large texts or integration into applications that require immediate feedback. \ No newline at end of file diff --git a/python/src/multi_agent_orchestrator/agents/agent.py b/python/src/multi_agent_orchestrator/agents/agent.py index 0a54e94d..01a9122c 100644 --- a/python/src/multi_agent_orchestrator/agents/agent.py +++ b/python/src/multi_agent_orchestrator/agents/agent.py @@ -1,7 +1,9 @@ from typing import Dict, List, Union, AsyncIterable, Optional, Any from abc import ABC, abstractmethod from dataclasses import dataclass, field -from multi_agent_orchestrator.types import ConversationMessage +from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole +from multi_agent_orchestrator.utils import Logger + @dataclass class AgentProcessingResult: @@ -60,3 +62,17 @@ async def process_request( additional_params: Optional[Dict[str, str]] = None ) -> Union[ConversationMessage, AsyncIterable[any]]: pass + + def create_error_response(self, message: str, error: Optional[Exception] = None) -> ConversationMessage: + error_message = "Sorry, I encountered an error while processing your request." + if error is not None: + error_message += f" Error details: {str(error)}" + else: + error_message += f" {message}" + + Logger.error(f"{self.name} Error: {error_message}") + + return ConversationMessage( + role=ParticipantRole.ASSISTANT, + content=[{"text": error_message}] + ) diff --git a/python/src/multi_agent_orchestrator/agents/amazon_bedrock_agent.py b/python/src/multi_agent_orchestrator/agents/amazon_bedrock_agent.py index 6d5bd3f7..0fb3e05b 100644 --- a/python/src/multi_agent_orchestrator/agents/amazon_bedrock_agent.py +++ b/python/src/multi_agent_orchestrator/agents/amazon_bedrock_agent.py @@ -9,25 +9,21 @@ from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole from multi_agent_orchestrator.utils import Logger - @dataclass class AmazonBedrockAgentOptions(AgentOptions): """Options for Amazon Bedrock Agent.""" agent_id: str = None agent_alias_id: str = None - class AmazonBedrockAgent(Agent): """ Represents an Amazon Bedrock agent that interacts with a runtime client. Extends base Agent class and implements specific methods for Amazon Bedrock. """ - def __init__(self, options: AmazonBedrockAgentOptions): """ Constructs an instance of AmazonBedrockAgent with the specified options. Initializes the agent ID, agent alias ID, and creates a new Bedrock agent runtime client. - :param options: Options to configure the Amazon Bedrock agent. """ super().__init__(options) @@ -46,12 +42,11 @@ async def process_request( ) -> ConversationMessage: """ Processes a user request by sending it to the Amazon Bedrock agent for processing. - :param input_text: The user input as a string. :param user_id: The ID of the user sending the request. :param session_id: The ID of the session associated with the conversation. :param chat_history: A list of ConversationMessage objects representing - the conversation history. + the conversation history. :param additional_params: Optional additional parameters as key-value pairs. :return: A ConversationMessage object containing the agent's response. """ @@ -62,7 +57,6 @@ async def process_request( sessionId=session_id, inputText=input_text ) - completion = "" for event in response['completion']: if 'chunk' in event: @@ -71,15 +65,10 @@ async def process_request( completion += decoded_response else: Logger.warn("Received a chunk event with no chunk data") - return ConversationMessage( role=ParticipantRole.ASSISTANT, content=[{"text": completion}] ) - except (BotoCoreError, ClientError) as error: Logger.error(f"Error processing request: {error}") - return ConversationMessage( - role=ParticipantRole.ASSISTANT, - content=[{"text": "Sorry, I encountered an error while processing your request."}] - ) + return self.createErrorResponse("An error occurred while processing your request.", error) \ No newline at end of file diff --git a/python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py b/python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py index 1a890e9d..936cba3c 100644 --- a/python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py +++ b/python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py @@ -81,65 +81,68 @@ async def process_request( chat_history: List[ConversationMessage], additional_params: Optional[Dict[str, str]] = None ) -> Union[ConversationMessage, AsyncIterable[Any]]: - - user_message =ConversationMessage( - role=ParticipantRole.USER.value, - content=[{'text': input_text}] - ) + try: + user_message = ConversationMessage( + role=ParticipantRole.USER.value, + content=[{'text': input_text}] + ) - conversation = [*chat_history, user_message] + conversation = [*chat_history, user_message] - self.update_system_prompt() + self.update_system_prompt() + + system_prompt = self.system_prompt - system_prompt = self.system_prompt - - if self.retriever: - response = await self.retriever.retrieve_and_combine_results(input_text) - context_prompt = "\nHere is the context to use to answer the user's question:\n" + response - system_prompt += context_prompt - - converse_cmd = { - 'modelId': self.model_id, - 'messages': conversation_to_dict(conversation), - 'system': [{'text': system_prompt}], - 'inferenceConfig': { - 'maxTokens': self.inference_config.get('maxTokens'), - 'temperature': self.inference_config.get('temperature'), - 'topP': self.inference_config.get('topP'), - 'stopSequences': self.inference_config.get('stopSequences'), + if self.retriever: + response = await self.retriever.retrieve_and_combine_results(input_text) + context_prompt = "\nHere is the context to use to answer the user's question:\n" + response + system_prompt += context_prompt + + converse_cmd = { + 'modelId': self.model_id, + 'messages': conversation_to_dict(conversation), + 'system': [{'text': system_prompt}], + 'inferenceConfig': { + 'maxTokens': self.inference_config.get('maxTokens'), + 'temperature': self.inference_config.get('temperature'), + 'topP': self.inference_config.get('topP'), + 'stopSequences': self.inference_config.get('stopSequences'), + } } - } - if self.guardrail_config: - converse_cmd["guardrailConfig"] = self.guardrail_config + if self.guardrail_config: + converse_cmd["guardrailConfig"] = self.guardrail_config - if self.tool_config: - converse_cmd["toolConfig"] = self.tool_config["tool"] + if self.tool_config: + converse_cmd["toolConfig"] = self.tool_config["tool"] - if self.tool_config: - continue_with_tools = True - final_message: ConversationMessage = {'role': ParticipantRole.USER.value, 'content': []} - max_recursions = self.tool_config.get('toolMaxRecursions', self.default_max_recursions) + if self.tool_config: + continue_with_tools = True + final_message: ConversationMessage = {'role': ParticipantRole.USER.value, 'content': []} + max_recursions = self.tool_config.get('toolMaxRecursions', self.default_max_recursions) - while continue_with_tools and max_recursions > 0: - bedrock_response = await self.handle_single_response(converse_cmd) - conversation.append(bedrock_response) + while continue_with_tools and max_recursions > 0: + bedrock_response = await self.handle_single_response(converse_cmd) + conversation.append(bedrock_response) - if any('toolUse' in content for content in bedrock_response.content): - await self.tool_config['useToolHandler'](bedrock_response, conversation) - else: - continue_with_tools = False - final_message = bedrock_response + if any('toolUse' in content for content in bedrock_response.content): + await self.tool_config['useToolHandler'](bedrock_response, conversation) + else: + continue_with_tools = False + final_message = bedrock_response - max_recursions -= 1 - converse_cmd['messages'] = conversation + max_recursions -= 1 + converse_cmd['messages'] = conversation - return final_message + return final_message - if self.streaming: - return await self.handle_streaming_response(converse_cmd) + if self.streaming: + return await self.handle_streaming_response(converse_cmd) - return await self.handle_single_response(converse_cmd) + return await self.handle_single_response(converse_cmd) + except Exception as error: + Logger.error("Error in BedrockLLMAgent.process_request:", error) + return self.createErrorResponse("An error occurred while processing your request.", error) async def handle_single_response(self, converse_input: Dict[str, Any]) -> ConversationMessage: try: @@ -152,7 +155,7 @@ async def handle_single_response(self, converse_input: Dict[str, Any]) -> Conver ) except Exception as error: Logger.error("Error invoking Bedrock model:", error) - raise + return self.createErrorResponse("An error occurred while processing your request with the Bedrock model.", error) async def handle_streaming_response(self, converse_input: Dict[str, Any]) -> ConversationMessage: try: @@ -168,7 +171,7 @@ async def handle_streaming_response(self, converse_input: Dict[str, Any]) -> Con ) except Exception as error: Logger.error("Error getting stream from Bedrock model:", error) - raise + return self.createErrorResponse("An error occurred while streaming the response from the Bedrock model.", error) def set_system_prompt(self, template: Optional[str] = None, @@ -192,4 +195,4 @@ def replace(match): return '\n'.join(value) if isinstance(value, list) else str(value) return match.group(0) - return re.sub(r'{{(\w+)}}', replace, template) + return re.sub(r'{{(\w+)}}', replace, template) \ No newline at end of file diff --git a/python/src/multi_agent_orchestrator/agents/bedrock_translator_agent.py b/python/src/multi_agent_orchestrator/agents/bedrock_translator_agent.py index a0dcc360..94fb7b00 100644 --- a/python/src/multi_agent_orchestrator/agents/bedrock_translator_agent.py +++ b/python/src/multi_agent_orchestrator/agents/bedrock_translator_agent.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Optional, Any +from typing import List, Dict, Optional, Any, Union, AsyncIterable from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole, BEDROCK_MODEL_ID_CLAUDE_3_HAIKU from multi_agent_orchestrator.utils import conversation_to_dict, Logger from dataclasses import dataclass @@ -12,6 +12,7 @@ class BedrockTranslatorAgentOptions(AgentOptions): inference_config: Optional[Dict[str, Any]] = None model_id: Optional[str] = None region: Optional[str] = None + streaming: Optional[bool] = None class BedrockTranslatorAgent(Agent): def __init__(self, options: BedrockTranslatorAgentOptions): @@ -20,7 +21,8 @@ def __init__(self, options: BedrockTranslatorAgentOptions): self.target_language = options.target_language or 'English' self.model_id = options.model_id or BEDROCK_MODEL_ID_CLAUDE_3_HAIKU self.client = boto3.client('bedrock-runtime', region_name=options.region) - + self.streaming = options.streaming or False + # Default inference configuration self.inference_config: Dict[str, Any] = options.inference_config or { 'maxTokens': 1000, @@ -54,77 +56,95 @@ async def process_request(self, user_id: str, session_id: str, chat_history: List[ConversationMessage], - additional_params: Optional[Dict[str, str]] = None) -> ConversationMessage: - # Check if input is a number and return it as-is if true - if input_text.isdigit(): - return ConversationMessage( - role=ParticipantRole.ASSISTANT, - content=[{"text": input_text}] + additional_params: Optional[Dict[str, str]] = None) -> Union[ConversationMessage, AsyncIterable[Any]]: + try: + # Check if input is a number and return it as-is if true + if input_text.isdigit(): + return ConversationMessage( + role=ParticipantRole.ASSISTANT, + content=[{"text": input_text}] + ) + + # Prepare user message + user_message = ConversationMessage( + role=ParticipantRole.USER, + content=[{"text": f"{input_text}"}] ) - # Prepare user message - user_message = ConversationMessage( - role=ParticipantRole.USER, - content=[{"text": f"{input_text}"}] - ) - - # Construct system prompt - system_prompt = "You are a translator. Translate the text within the tags" - if self.source_language: - system_prompt += f" from {self.source_language} to {self.target_language}" - else: - system_prompt += f" to {self.target_language}" - system_prompt += ". Only provide the translation using the Translate tool." - - # Prepare the converse command for Bedrock - converse_cmd = { - "modelId": self.model_id, - "messages": [conversation_to_dict(user_message)], - "system": [{"text": system_prompt}], - "toolConfig": { - "tools": self.tools, - "toolChoice": { - "tool": { - "name": "Translate", + # Construct system prompt + system_prompt = "You are a translator. Translate the text within the tags" + if self.source_language: + system_prompt += f" from {self.source_language} to {self.target_language}" + else: + system_prompt += f" to {self.target_language}" + system_prompt += ". Only provide the translation using the Translate tool." + + # Prepare the converse command for Bedrock + converse_cmd = { + "modelId": self.model_id, + "messages": [conversation_to_dict(user_message)], + "system": [{"text": system_prompt}], + "toolConfig": { + "tools": self.tools, + "toolChoice": { + "tool": { + "name": "Translate", + }, }, - }, - }, - 'inferenceConfig': self.inference_config - } + }, + 'inferenceConfig': self.inference_config + } - try: - # Send request to Bedrock - response = self.client.converse(**converse_cmd) + if self.streaming: + return await self.handle_streaming_response(converse_cmd) + else: + return await self.handle_single_response(converse_cmd) - if 'output' not in response: - raise ValueError("No output received from Bedrock model") + except Exception as error: + Logger.error("Error processing translation request:", error) + return self.createErrorResponse("An error occurred while processing your translation request.", error) + + async def handle_single_response(self, converse_cmd: Dict[str, Any]) -> ConversationMessage: + response = self.client.converse(**converse_cmd) + + if 'output' not in response: + raise ValueError("No output received from Bedrock model") - if response['output'].get('message', {}).get('content'): - response_content_blocks = response['output']['message']['content'] + if response['output'].get('message', {}).get('content'): + response_content_blocks = response['output']['message']['content'] - for content_block in response_content_blocks: - if "toolUse" in content_block: - tool_use = content_block["toolUse"] - if not tool_use: - raise ValueError("No tool use found in the response") + for content_block in response_content_blocks: + if "toolUse" in content_block: + tool_use = content_block["toolUse"] + if not tool_use: + raise ValueError("No tool use found in the response") - if not isinstance(tool_use.get('input'), dict) or 'translation' not in tool_use['input']: - raise ValueError("Tool input does not match expected structure") + if not isinstance(tool_use.get('input'), dict) or 'translation' not in tool_use['input']: + raise ValueError("Tool input does not match expected structure") - translation = tool_use['input']['translation'] - if not isinstance(translation, str): - raise ValueError("Translation is not a string") + translation = tool_use['input']['translation'] + if not isinstance(translation, str): + raise ValueError("Translation is not a string") - # Return the translated text - return ConversationMessage( - role=ParticipantRole.ASSISTANT, - content=[{"text": translation}] - ) + # Return the translated text + return ConversationMessage( + role=ParticipantRole.ASSISTANT, + content=[{"text": translation}] + ) - raise ValueError("No valid tool use found in the response") + raise ValueError("No valid tool use found in the response") + + async def handle_streaming_response(self, converse_cmd: Dict[str, Any]) -> AsyncIterable[str]: + try: + response = self.client.converse_stream(**converse_cmd) + async for chunk in response["stream"]: + if "contentBlockDelta" in chunk: + content = chunk.get("contentBlockDelta", {}).get("delta", {}).get("text") + if content: + yield content except Exception as error: - Logger.error("Error processing translation request:", error) - raise + Logger.error("Error getting stream from Bedrock model:", error) + yield self.createErrorResponse("An error occurred while streaming the response from the Bedrock model.", error).content[0].text def set_source_language(self, language: Optional[str]): """Set the source language for translation""" diff --git a/python/src/multi_agent_orchestrator/agents/chain_agent.py b/python/src/multi_agent_orchestrator/agents/chain_agent.py index 2f233f84..484677c9 100644 --- a/python/src/multi_agent_orchestrator/agents/chain_agent.py +++ b/python/src/multi_agent_orchestrator/agents/chain_agent.py @@ -27,7 +27,6 @@ async def process_request( ) -> Union[ConversationMessage, AsyncIterable[any]]: current_input = input_text final_response: Union[ConversationMessage, AsyncIterable[any]] - for i, agent in enumerate(self.agents): is_last_agent = i == len(self.agents) - 1 try: @@ -56,16 +55,13 @@ async def process_request( else: Logger.logger.warning(f"Agent {agent.name} returned an invalid response type.") return self.create_default_response() - # If it's not the last agent, ensure we have a non-streaming response to pass to the next agent if not is_last_agent and not self.is_conversation_message(final_response): Logger.logger.error(f"Expected non-streaming response from intermediate agent {agent.name}") return self.create_default_response() - except Exception as error: Logger.logger.error(f"Error processing request with agent {agent.name}:", error) - return self.create_default_response() - + return self.createErrorResponse(f"An error occurred while processing your request with agent {agent.name}.", error) return final_response @staticmethod diff --git a/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py b/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py index 04eb25ad..bc965784 100644 --- a/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py +++ b/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py @@ -30,12 +30,12 @@ def __init__(self, class ComprehendFilterAgent(Agent): def __init__(self, options: ComprehendFilterAgentOptions): super().__init__(options) - + config = Config(region_name=options.region) if options.region else None self.comprehend_client = boto3.client('comprehend', config=config) - + self.custom_checks: List[CheckFunction] = [] - + self.enable_sentiment_check = options.enable_sentiment_check self.enable_pii_check = options.enable_pii_check self.enable_toxicity_check = options.enable_toxicity_check @@ -43,7 +43,7 @@ def __init__(self, options: ComprehendFilterAgentOptions): self.toxicity_threshold = options.toxicity_threshold self.allow_pii = options.allow_pii self.language_code = self.validate_language_code(options.language_code) or 'en' - + # Ensure at least one check is enabled if not any([self.enable_sentiment_check, self.enable_pii_check, self.enable_toxicity_check]): self.enable_toxicity_check = True @@ -96,7 +96,7 @@ async def process_request(self, except Exception as error: Logger.logger.error("Error in ComprehendContentFilterAgent:", error) - raise + return self.createErrorResponse("An error occurred while processing your request in the ComprehendContentFilterAgent.", error) def add_custom_check(self, check: CheckFunction): self.custom_checks.append(check) diff --git a/python/src/multi_agent_orchestrator/agents/lambda_agent.py b/python/src/multi_agent_orchestrator/agents/lambda_agent.py index d8175beb..44f17b10 100644 --- a/python/src/multi_agent_orchestrator/agents/lambda_agent.py +++ b/python/src/multi_agent_orchestrator/agents/lambda_agent.py @@ -20,7 +20,6 @@ class LambdaAgentOptions(AgentOptions): ConversationMessage ]] = None - class LambdaAgent(Agent): def __init__(self, options: LambdaAgentOptions): super().__init__(options) @@ -30,19 +29,18 @@ def __init__(self, options: LambdaAgentOptions): self.encoder = self.__default_input_payload_encoder else: self.encoder = self.options.input_payload_encoder - if self.options.output_payload_decoder is None: self.decoder = self.__default_output_payload_decoder else: self.decoder = self.options.output_payload_decoder def __default_input_payload_encoder(self, - input_text: str, - chat_history: List[ConversationMessage], - user_id: str, - session_id: str, - additional_params: Optional[Dict[str, str]] = None - ) -> str: + input_text: str, + chat_history: List[ConversationMessage], + user_id: str, + session_id: str, + additional_params: Optional[Dict[str, str]] = None + ) -> str: """Encode input payload as JSON string.""" return json.dumps({ 'query': input_text, @@ -52,12 +50,11 @@ def __default_input_payload_encoder(self, 'sessionId': session_id, }) - def __default_output_payload_decoder(self, response: Dict[str, Any]) -> ConversationMessage: """Decode Lambda response and create ConversationMessage.""" decoded_response = json.loads( json.loads(response['Payload'].read().decode('utf-8'))['body'] - )['response'] + )['response'] return ConversationMessage( role=ParticipantRole.ASSISTANT.value, content=[{'text': decoded_response}] @@ -72,10 +69,12 @@ async def process_request( additional_params: Optional[Dict[str, str]] = None ) -> ConversationMessage: """Process the request by invoking Lambda function and decoding the response.""" - payload = self.encoder(input_text, chat_history, user_id, session_id, additional_params) - - response = self.lambda_client.invoke( - FunctionName=self.options.function_name, - Payload=payload - ) - return self.decoder(response) + try: + payload = self.encoder(input_text, chat_history, user_id, session_id, additional_params) + response = self.lambda_client.invoke( + FunctionName=self.options.function_name, + Payload=payload + ) + return self.decoder(response) + except Exception as error: + return self.createErrorResponse("An error occurred while processing your request in the LambdaAgent.", error) \ No newline at end of file diff --git a/python/src/multi_agent_orchestrator/agents/lex_bot_agent.py b/python/src/multi_agent_orchestrator/agents/lex_bot_agent.py index bc2d818d..4f1a154d 100644 --- a/python/src/multi_agent_orchestrator/agents/lex_bot_agent.py +++ b/python/src/multi_agent_orchestrator/agents/lex_bot_agent.py @@ -24,13 +24,12 @@ def __init__(self, options: LexBotAgentOptions): self.bot_id = options.bot_id self.bot_alias_id = options.bot_alias_id self.locale_id = options.locale_id - if not all([self.bot_id, self.bot_alias_id, self.locale_id]): raise ValueError("bot_id, bot_alias_id, and locale_id are required for LexBotAgent") async def process_request(self, input_text: str, user_id: str, session_id: str, - chat_history: List[ConversationMessage], - additional_params: Optional[Dict[str, str]] = None) -> ConversationMessage: + chat_history: List[ConversationMessage], + additional_params: Optional[Dict[str, str]] = None) -> ConversationMessage: try: params = { 'botId': self.bot_id, @@ -40,20 +39,15 @@ async def process_request(self, input_text: str, user_id: str, session_id: str, 'text': input_text, 'sessionState': {} # You might want to maintain session state if needed } - response = self.lex_client.recognize_text(**params) - concatenated_content = ' '.join( message.get('content', '') for message in response.get('messages', []) if message.get('content') ) - return ConversationMessage( role=ParticipantRole.ASSISTANT, content=[{"text": concatenated_content or "No response from Lex bot."}] ) - except (BotoCoreError, ClientError) as error: Logger.error(f"Error processing request: {error}") - raise - + return self.createErrorResponse("An error occurred while processing your request in the LexBotAgent.", error) \ No newline at end of file diff --git a/typescript/src/agents/agent.ts b/typescript/src/agents/agent.ts index bf2c1ca6..63ae34af 100644 --- a/typescript/src/agents/agent.ts +++ b/typescript/src/agents/agent.ts @@ -1,5 +1,7 @@ -import { ConversationMessage } from "../types"; import { AccumulatorTransform } from "../utils/helpers"; +import { Logger } from "../utils/logger"; +import { ConversationMessage, ParticipantRole } from "../types"; + export interface AgentProcessingResult { // The original input provided by the user @@ -117,4 +119,18 @@ abstract processRequest( additionalParams?: Record ): Promise>; +protected createErrorResponse(message: string, error?: unknown): ConversationMessage { + let errorMessage = `Sorry, I encountered an error while processing your request.`; + if (error instanceof Error) { + errorMessage += ` Error details: ${error.message}`; + } else { + errorMessage += ` ${message}`; + } + Logger.logger.error(`${this.name} Error:`, errorMessage); + return { + role: ParticipantRole.ASSISTANT, + content: [{ text: errorMessage }], + }; +} + } diff --git a/typescript/src/agents/amazonBedrockAgent.ts b/typescript/src/agents/amazonBedrockAgent.ts index cacb8263..90347080 100644 --- a/typescript/src/agents/amazonBedrockAgent.ts +++ b/typescript/src/agents/amazonBedrockAgent.ts @@ -8,19 +8,18 @@ import { Logger } from "../utils/logger"; * Extends base AgentOptions with specific parameters required for Amazon Bedrock. */ export interface AmazonBedrockAgentOptions extends AgentOptions { - agentId: string; // The ID of the Amazon Bedrock agent. - agentAliasId: string; // The alias ID of the Amazon Bedrock agent. + agentId: string; // The ID of the Amazon Bedrock agent. + agentAliasId: string; // The alias ID of the Amazon Bedrock agent. } - /** * Represents an Amazon Bedrock agent that interacts with a runtime client. * Extends base Agent class and implements specific methods for Amazon Bedrock. */ export class AmazonBedrockAgent extends Agent { - private agentId: string; // The ID of the Amazon Bedrock agent. - private agentAliasId: string; // The alias ID of the Amazon Bedrock agent. - private client: BedrockAgentRuntimeClient; // Client for interacting with the Bedrock agent runtime. + private agentId: string; // The ID of the Amazon Bedrock agent. + private agentAliasId: string; // The alias ID of the Amazon Bedrock agent. + private client: BedrockAgentRuntimeClient; // Client for interacting with the Bedrock agent runtime. /** * Constructs an instance of AmazonBedrockAgent with the specified options. @@ -32,8 +31,8 @@ export class AmazonBedrockAgent extends Agent { this.agentId = options.agentId; this.agentAliasId = options.agentAliasId; this.client = options.region - ? new BedrockAgentRuntimeClient({ region: options.region }) - : new BedrockAgentRuntimeClient(); + ? new BedrockAgentRuntimeClient({ region: options.region }) + : new BedrockAgentRuntimeClient(); } /** @@ -53,15 +52,15 @@ export class AmazonBedrockAgent extends Agent { chatHistory: ConversationMessage[], additionalParams?: Record ): Promise { - // Construct the command to invoke the Amazon Bedrock agent with user input - const command = new InvokeAgentCommand({ - agentId: this.agentId, - agentAliasId: this.agentAliasId, - sessionId, - inputText - }); - try { + // Construct the command to invoke the Amazon Bedrock agent with user input + const command = new InvokeAgentCommand({ + agentId: this.agentId, + agentAliasId: this.agentAliasId, + sessionId, + inputText + }); + let completion = ""; const response = await this.client.send(command); @@ -86,20 +85,10 @@ export class AmazonBedrockAgent extends Agent { role: ParticipantRole.ASSISTANT, content: [{ text: completion }], }; - } catch (err) { + } catch (error) { // Handle errors encountered while invoking the Amazon Bedrock agent - Logger.logger.error(err); - - // Return a default error message as a fallback response - return { - role: ParticipantRole.ASSISTANT, - content: [ - { - text: "Sorry, I encountered an error while processing your request.", - }, - ], - }; + Logger.logger.error("Error in AmazonBedrockAgent.processRequest:", error); + return this.createErrorResponse("An error occurred while processing your request with the Amazon Bedrock agent.", error); } } -} - +} \ No newline at end of file diff --git a/typescript/src/agents/bedrockLLMAgent.ts b/typescript/src/agents/bedrockLLMAgent.ts index f36d33cd..0357382c 100644 --- a/typescript/src/agents/bedrockLLMAgent.ts +++ b/typescript/src/agents/bedrockLLMAgent.ts @@ -37,72 +37,43 @@ export interface BedrockLLMAgentOptions extends AgentOptions { }; } -/** - * BedrockAgent class represents an agent that uses Amazon Bedrock for natural language processing. - * It extends the base Agent class and implements the processRequest method using Bedrock's API. - */ export class BedrockLLMAgent extends Agent { - /** AWS Bedrock Runtime Client for making API calls */ protected client: BedrockRuntimeClient; - protected customSystemPrompt?: string; - protected streaming: boolean; - protected inferenceConfig: { maxTokens?: number; temperature?: number; topP?: number; stopSequences?: string[]; }; - - /** - * The ID of the model used by this agent. - */ protected modelId?: string; - protected guardrailConfig?: { guardrailIdentifier: string; guardrailVersion: string; }; - protected retriever?: Retriever; - private toolConfig?: { tool: any[]; useToolHandler: (response: any, conversation: ConversationMessage[]) => void; toolMaxRecursions?: number; }; - private promptTemplate: string; private systemPrompt: string; private customVariables: TemplateVariables; private defaultMaxRecursions: number = 20; - /** - * Constructs a new BedrockAgent instance. - * @param options - Configuration options for the agent, inherited from AgentOptions. - */ constructor(options: BedrockLLMAgentOptions) { super(options); - this.client = options.region ? new BedrockRuntimeClient({ region: options.region }) : new BedrockRuntimeClient(); - - // Initialize the modelId this.modelId = options.modelId ?? BEDROCK_MODEL_ID_CLAUDE_3_HAIKU; - this.streaming = options.streaming ?? false; - this.inferenceConfig = options.inferenceConfig ?? {}; - this.guardrailConfig = options.guardrailConfig ?? null; - this.retriever = options.retriever ?? null; - this.toolConfig = options.toolConfig ?? null; - this.promptTemplate = `You are a ${this.name}. ${this.description} Provide helpful and accurate information based on your expertise. You will engage in an open-ended conversation, providing helpful and accurate information based on your expertise. The conversation will proceed as follows: @@ -118,25 +89,24 @@ export class BedrockLLMAgent extends Agent { - Ask for clarification if any part of the question or prompt is ambiguous. - Maintain a consistent, respectful, and engaging tone tailored to the human's communication style. - Seamlessly transition between topics as the human introduces new subjects.` - + if (options.customSystemPrompt) { this.setSystemPrompt( options.customSystemPrompt.template, options.customSystemPrompt.variables ); } - } - - /** - * Abstract method to process a request. - * This method must be implemented by all concrete agent classes. - * - * @param inputText - The user input as a string. - * @param chatHistory - An array of Message objects representing the conversation history. - * @param additionalParams - Optional additional parameters as key-value pairs. - * @returns A Promise that resolves to a Message object containing the agent's response. - */ + +/** + * Processes a user request by sending it to the Amazon Bedrock agent for processing. + * @param inputText - The user input as a string. + * @param userId - The ID of the user sending the request. + * @param sessionId - The ID of the session associated with the conversation. + * @param chatHistory - An array of Message objects representing the conversation history. + * @param additionalParams - Optional additional parameters as key-value pairs. + * @returns A Promise that resolves to a Message object containing the agent's response. + */ /* eslint-disable @typescript-eslint/no-unused-vars */ async processRequest( inputText: string, @@ -145,81 +115,75 @@ export class BedrockLLMAgent extends Agent { chatHistory: ConversationMessage[], additionalParams?: Record ): Promise> { - // Construct the user's message based on the provided inputText - const userMessage: ConversationMessage = { - role: ParticipantRole.USER, - content: [{ text: `${inputText}` }], - }; - - // Combine the existing chat history with the user's message - const conversation: ConversationMessage[] = [ - ...chatHistory, - userMessage, - ]; - - this.updateSystemPrompt(); - - let systemPrompt = this.systemPrompt; - - // Update the system prompt with the latest history, agent descriptions, and custom variables - if (this.retriever) { - // retrieve from Vector store - const response = await this.retriever.retrieveAndCombineResults(inputText); - const contextPrompt = - "\nHere is the context to use to answer the user's question:\n" + - response; - systemPrompt = systemPrompt + contextPrompt; - } - - // Prepare the command to converse with the Bedrock API - const converseCmd = { - modelId: this.modelId, - messages: conversation, //Include the updated conversation history - system: [{ text: systemPrompt }], - inferenceConfig: { - maxTokens: this.inferenceConfig.maxTokens, - temperature: this.inferenceConfig.temperature, - topP: this.inferenceConfig.topP, - stopSequences: this.inferenceConfig.stopSequences, - }, - guardrailConfig: this.guardrailConfig? this.guardrailConfig:undefined, - toolConfig: (this.toolConfig ? { tools:this.toolConfig.tool}:undefined) - }; - - if (this.toolConfig){ - let continueWithTools = true; - let finalMessage:ConversationMessage = { role: ParticipantRole.USER, content:[]}; - let maxRecursions = this.toolConfig.toolMaxRecursions || this.defaultMaxRecursions; - - while (continueWithTools && maxRecursions > 0){ - // send the conversation to Amazon Bedrock - const bedrockResponse = await this.handleSingleResponse(converseCmd); + try { + const userMessage: ConversationMessage = { + role: ParticipantRole.USER, + content: [{ text: `${inputText}` }], + }; + + const conversation: ConversationMessage[] = [ + ...chatHistory, + userMessage, + ]; + + this.updateSystemPrompt(); + + let systemPrompt = this.systemPrompt; + + if (this.retriever) { + const response = await this.retriever.retrieveAndCombineResults(inputText); + const contextPrompt = + "\nHere is the context to use to answer the user's question:\n" + + response; + systemPrompt = systemPrompt + contextPrompt; + } - // Append the model's response to the ongoing conversation - conversation.push(bedrockResponse); - - // process model response - if (bedrockResponse.content.some((content) => 'toolUse' in content)){ - // forward everything to the tool use handler - await this.toolConfig.useToolHandler(bedrockResponse, conversation); + const converseCmd = { + modelId: this.modelId, + messages: conversation, + system: [{ text: systemPrompt }], + inferenceConfig: { + maxTokens: this.inferenceConfig.maxTokens, + temperature: this.inferenceConfig.temperature, + topP: this.inferenceConfig.topP, + stopSequences: this.inferenceConfig.stopSequences, + }, + guardrailConfig: this.guardrailConfig? this.guardrailConfig:undefined, + toolConfig: (this.toolConfig ? { tools:this.toolConfig.tool}:undefined) + }; + + if (this.toolConfig){ + let continueWithTools = true; + let finalMessage:ConversationMessage = { role: ParticipantRole.USER, content:[]}; + let maxRecursions = this.toolConfig.toolMaxRecursions || this.defaultMaxRecursions; + + while (continueWithTools && maxRecursions > 0){ + const bedrockResponse = await this.handleSingleResponse(converseCmd); + conversation.push(bedrockResponse); + + if (bedrockResponse.content.some((content) => 'toolUse' in content)){ + await this.toolConfig.useToolHandler(bedrockResponse, conversation); + } + else { + continueWithTools = false; + finalMessage = bedrockResponse; + } + maxRecursions--; + + converseCmd.messages = conversation; } - else { - continueWithTools = false; - finalMessage = bedrockResponse; - } - maxRecursions--; - - converseCmd.messages = conversation; - + return finalMessage; } - return finalMessage; - } - else { - if (this.streaming) { - return this.handleStreamingResponse(converseCmd); - } else { - return this.handleSingleResponse(converseCmd); + else { + if (this.streaming) { + return this.handleStreamingResponse(converseCmd); + } else { + return this.handleSingleResponse(converseCmd); + } } + } catch (error) { + Logger.logger.error("Error in BedrockLLMAgent.processRequest:", error); + return this.createErrorResponse("An error occurred while processing your request.", error); } } @@ -234,7 +198,7 @@ export class BedrockLLMAgent extends Agent { return response.output.message as ConversationMessage; } catch (error) { Logger.logger.error("Error invoking Bedrock model:", error); - throw error; + return this.createErrorResponse("An error occurred while processing your request with the Bedrock model.", error); } } @@ -244,17 +208,15 @@ export class BedrockLLMAgent extends Agent { const response = await this.client.send(command); for await (const chunk of response.stream) { const content = chunk.contentBlockDelta?.delta?.text; - if (chunk.contentBlockDelta && chunk.contentBlockDelta.delta && chunk.contentBlockDelta.delta.text) { - yield content; - } - + if (chunk.contentBlockDelta && chunk.contentBlockDelta.delta && chunk.contentBlockDelta.delta.text) { + yield content; + } } } catch (error) { Logger.logger.error("Error getting stream from Bedrock model:", error); - throw error; + yield this.createErrorResponse("An error occurred while streaming the response from the Bedrock model.", error).content[0].text; } } - setSystemPrompt(template?: string, variables?: TemplateVariables): void { if (template) { @@ -294,6 +256,4 @@ export class BedrockLLMAgent extends Agent { return match; // If no replacement found, leave the placeholder as is }); } - - -} +} \ No newline at end of file diff --git a/typescript/src/agents/bedrockTranslatorAgent.ts b/typescript/src/agents/bedrockTranslatorAgent.ts index 5c94cf02..3a078495 100644 --- a/typescript/src/agents/bedrockTranslatorAgent.ts +++ b/typescript/src/agents/bedrockTranslatorAgent.ts @@ -1,12 +1,13 @@ import { Agent, AgentOptions } from "./agent"; import { ConversationMessage, ParticipantRole, BEDROCK_MODEL_ID_CLAUDE_3_HAIKU } from "../types"; -import { BedrockRuntimeClient, ConverseCommand, ContentBlock } from "@aws-sdk/client-bedrock-runtime"; +import { BedrockRuntimeClient, ConverseCommand, ConverseStreamCommand, ContentBlock } from "@aws-sdk/client-bedrock-runtime"; import { Logger } from "../utils/logger"; interface BedrockTranslatorAgentOptions extends AgentOptions { sourceLanguage?: string; targetLanguage?: string; modelId?: string; + streaming?: boolean; inferenceConfig?: { maxTokens?: number; temperature?: number; @@ -32,6 +33,7 @@ export class BedrockTranslatorAgent extends Agent { private targetLanguage: string; private modelId: string; private client: BedrockRuntimeClient; + private streaming: boolean; private inferenceConfig: { maxTokens?: number; temperature?: number; @@ -66,18 +68,19 @@ export class BedrockTranslatorAgent extends Agent { this.targetLanguage = options.targetLanguage || 'English'; this.modelId = options.modelId || BEDROCK_MODEL_ID_CLAUDE_3_HAIKU; this.client = new BedrockRuntimeClient({ region: options.region }); + this.streaming = options.streaming ?? false; this.inferenceConfig = options.inferenceConfig || {}; } /** - * Processes a user request by sending it to the Amazon Bedrock agent for processing. - * @param inputText - The user input as a string. - * @param userId - The ID of the user sending the request. - * @param sessionId - The ID of the session associated with the conversation. - * @param chatHistory - An array of Message objects representing the conversation history. - * @param additionalParams - Optional additional parameters as key-value pairs. - * @returns A Promise that resolves to a Message object containing the agent's response. - */ + * Processes a user request by sending it to the Amazon Bedrock agent for processing. + * @param inputText - The user input as a string. + * @param userId - The ID of the user sending the request. + * @param sessionId - The ID of the session associated with the conversation. + * @param chatHistory - An array of Message objects representing the conversation history. + * @param additionalParams - Optional additional parameters as key-value pairs. + * @returns A Promise that resolves to a Message object containing the agent's response. + */ /* eslint-disable @typescript-eslint/no-unused-vars */ async processRequest( inputText: string, @@ -85,86 +88,118 @@ export class BedrockTranslatorAgent extends Agent { sessionId: string, chatHistory: ConversationMessage[], additionalParams?: Record - ): Promise { - // Check if input is a number - if (!isNaN(Number(inputText))) { - return { - role: ParticipantRole.ASSISTANT, - content: [{ text: inputText }], + ): Promise> { + try { + // Check if input is a number + if (!isNaN(Number(inputText))) { + return { + role: ParticipantRole.ASSISTANT, + content: [{ text: inputText }], + }; + } + + const userMessage: ConversationMessage = { + role: ParticipantRole.USER, + content: [{ text: `${inputText}` }], }; + + let systemPrompt = `You are a translator. Translate the text within the tags`; + if (this.sourceLanguage) { + systemPrompt += ` from ${this.sourceLanguage} to ${this.targetLanguage}`; + } else { + systemPrompt += ` to ${this.targetLanguage}`; + } + systemPrompt += `. Only provide the translation using the Translate tool.`; + + const converseCmd = { + modelId: this.modelId, + messages: [userMessage], + system: [{ text: systemPrompt }], + toolConfig: { + tools: this.tools, + toolChoice: { + tool: { + name: "Translate", + }, + }, + }, + inferenceConfiguration: { + maximumTokens: this.inferenceConfig.maxTokens, + temperature: this.inferenceConfig.temperature, + topP: this.inferenceConfig.topP, + stopSequences: this.inferenceConfig.stopSequences, + }, + }; + + if (this.streaming) { + return this.handleStreamingResponse(converseCmd); + } else { + return this.handleSingleResponse(converseCmd); + } + } catch (error) { + Logger.logger.error("Error processing translation request:", error); + return this.createErrorResponse("An error occurred while processing your translation request.", error); } + } - const userMessage: ConversationMessage = { - role: ParticipantRole.USER, - content: [{ text: `${inputText}` }], - }; + private async handleSingleResponse(input: any): Promise { + const command = new ConverseCommand(input); + const response = await this.client.send(command); - let systemPrompt = `You are a translator. Translate the text within the tags`; - if (this.sourceLanguage) { - systemPrompt += ` from ${this.sourceLanguage} to ${this.targetLanguage}`; - } else { - systemPrompt += ` to ${this.targetLanguage}`; + if (!response.output) { + throw new Error("No output received from Bedrock model"); } - systemPrompt += `. Only provide the translation using the Translate tool.`; - - const converseCmd = { - modelId: this.modelId, - messages: [userMessage], - system: [{ text: systemPrompt }], - toolConfig: { - tools: this.tools, - toolChoice: { - tool: { - name: "Translate", - }, - }, - }, - inferenceConfiguration: { - maximumTokens: this.inferenceConfig.maxTokens, - temperature: this.inferenceConfig.temperature, - topP: this.inferenceConfig.topP, - stopSequences: this.inferenceConfig.stopSequences, - }, - }; + if (response.output.message.content) { + const responseContentBlocks = response.output.message.content as ContentBlock[]; + + for (const contentBlock of responseContentBlocks) { + if ("toolUse" in contentBlock) { + const toolUse = contentBlock.toolUse; + if (!toolUse) { + throw new Error("No tool use found in the response"); + } - try { - const command = new ConverseCommand(converseCmd); - const response = await this.client.send(command); + if (!isToolInput(toolUse.input)) { + throw new Error("Tool input does not match expected structure"); + } + + if (typeof toolUse.input.translation !== 'string') { + throw new Error("Translation is not a string"); + } - if (!response.output) { - throw new Error("No output received from Bedrock model"); + return { + role: ParticipantRole.ASSISTANT, + content: [{ text: toolUse.input.translation }], + }; + } } - if (response.output.message.content) { - const responseContentBlocks = response.output.message - .content as ContentBlock[]; - - for (const contentBlock of responseContentBlocks) { - if ("toolUse" in contentBlock) { - const toolUse = contentBlock.toolUse; - if (!toolUse) { - throw new Error("No tool use found in the response"); - } - - if (!isToolInput(toolUse.input)) { - throw new Error("Tool input does not match expected structure"); - } - - if (typeof toolUse.input.translation !== 'string') { - throw new Error("Translation is not a string"); - } - - return { - role: ParticipantRole.ASSISTANT, - content: [{ text: toolUse.input.translation }], - }; + } + + throw new Error("No valid tool use found in the response"); + } + + private async *handleStreamingResponse(input: any): AsyncIterable { + try { + const command = new ConverseStreamCommand(input); + const response = await this.client.send(command); + let translation = ""; + + for await (const chunk of response.stream) { + if (chunk.contentBlockDelta && chunk.contentBlockDelta.delta && chunk.contentBlockDelta.delta.toolUse) { + const toolUse = chunk.contentBlockDelta.delta.toolUse; + if (toolUse.input && isToolInput(toolUse.input)) { + translation = toolUse.input.translation; + yield translation; } } } - throw new Error("No valid tool use found in the response"); + if (!translation) { + throw new Error("No valid translation found in the streaming response"); + } } catch (error) { - Logger.logger.error("Error processing translation request:", error); - throw error; + Logger.logger.error("Error getting stream from Bedrock model:", error); + yield this.createErrorResponse("An error occurred while streaming the translation from the Bedrock model.", error).content[0].text; } } diff --git a/typescript/src/agents/chainAgent.ts b/typescript/src/agents/chainAgent.ts index 49e0533b..1a13ff96 100644 --- a/typescript/src/agents/chainAgent.ts +++ b/typescript/src/agents/chainAgent.ts @@ -21,7 +21,7 @@ export class ChainAgent extends Agent { } } -/** + /** * Processes a user request by sending it to the Amazon Bedrock agent for processing. * @param inputText - The user input as a string. * @param userId - The ID of the user sending the request. @@ -38,65 +38,68 @@ export class ChainAgent extends Agent { chatHistory: ConversationMessage[], additionalParams?: Record ): Promise> { + try { + let currentInput = inputText; + let finalResponse: ConversationMessage | AsyncIterable; - let currentInput = inputText; - let finalResponse: ConversationMessage | AsyncIterable; - - console.log(`Processing chain with ${this.agents.length} agents`); - - for (let i = 0; i < this.agents.length; i++) { - const isLastAgent = i === this.agents.length - 1; - const agent = this.agents[i]; - - try { - console.log(`Input for agent ${i}: ${currentInput}`); - const response = await agent.processRequest( - currentInput, - userId, - sessionId, - chatHistory, - additionalParams - ); - - if (this.isConversationMessage(response)) { - if (response.content.length > 0 && 'text' in response.content[0]) { - currentInput = response.content[0].text; + Logger.logger.info(`Processing chain with ${this.agents.length} agents`); + + for (let i = 0; i < this.agents.length; i++) { + const isLastAgent = i === this.agents.length - 1; + const agent = this.agents[i]; + + try { + Logger.logger.debug(`Input for agent ${i}: ${currentInput}`); + const response = await agent.processRequest( + currentInput, + userId, + sessionId, + chatHistory, + additionalParams + ); + + if (this.isConversationMessage(response)) { + if (response.content.length > 0 && 'text' in response.content[0]) { + currentInput = response.content[0].text; + finalResponse = response; + Logger.logger.debug(`Output from agent ${i}: ${currentInput}`); + } else { + Logger.logger.warn(`Agent ${agent.name} returned no text content.`); + return this.createErrorResponse(`Agent ${agent.name} returned no text content.`); + } + } else if (this.isAsyncIterable(response)) { + if (!isLastAgent) { + Logger.logger.warn(`Intermediate agent ${agent.name} returned a streaming response, which is not allowed.`); + return this.createErrorResponse(`Intermediate agent ${agent.name} returned an unexpected streaming response.`); + } + // It's the last agent and streaming is allowed finalResponse = response; - console.log(`Output from agent ${i}: ${currentInput}`); } else { - Logger.logger.warn(`Agent ${agent.name} returned no text content.`); - return this.createDefaultResponse(); + Logger.logger.warn(`Agent ${agent.name} returned an invalid response type.`); + return this.createErrorResponse(`Agent ${agent.name} returned an invalid response type.`); } - } else if (this.isAsyncIterable(response)) { - if (!isLastAgent) { - Logger.logger.warn(`Intermediate agent ${agent.name} returned a streaming response, which is not allowed.`); - return this.createDefaultResponse(); + + // If it's not the last agent, ensure we have a non-streaming response to pass to the next agent + if (!isLastAgent && !this.isConversationMessage(finalResponse)) { + Logger.logger.error(`Expected non-streaming response from intermediate agent ${agent.name}`); + return this.createErrorResponse(`Unexpected streaming response from intermediate agent ${agent.name}.`); } - // It's the last agent and streaming is allowed - finalResponse = response; - } else { - Logger.logger.warn(`Agent ${agent.name} returned an invalid response type.`); - return this.createDefaultResponse(); - } - - // If it's not the last agent, ensure we have a non-streaming response to pass to the next agent - if (!isLastAgent && !this.isConversationMessage(finalResponse)) { - Logger.logger.error(`Expected non-streaming response from intermediate agent ${agent.name}`); - return this.createDefaultResponse(); + } catch (error) { + Logger.logger.error(`Error processing request with agent ${agent.name}:`, error); + return this.createErrorResponse(`Error processing request with agent ${agent.name}.`, error); } - } catch (error) { - Logger.logger.error(`Error processing request with agent ${agent.name}:`, error); - return this.createDefaultResponse(); } + + return finalResponse; + } catch (error) { + Logger.logger.error("Error in ChainAgent.processRequest:", error); + return this.createErrorResponse("An error occurred while processing the chain of agents.", error); } - - return finalResponse; } private isAsyncIterable(obj: any): obj is AsyncIterable { return obj && typeof obj[Symbol.asyncIterator] === 'function'; } - private isConversationMessage(response: any): response is ConversationMessage { return response && 'role' in response && 'content' in response && Array.isArray(response.content); diff --git a/typescript/src/agents/comprehendFilterAgent.ts b/typescript/src/agents/comprehendFilterAgent.ts index ae776cfc..b6c5095f 100644 --- a/typescript/src/agents/comprehendFilterAgent.ts +++ b/typescript/src/agents/comprehendFilterAgent.ts @@ -38,13 +38,6 @@ export interface ComprehendFilterAgentOptions extends AgentOptions { languageCode?: LanguageCode; } -/** - * ComprehendContentFilterAgent class - * - * This agent uses Amazon Comprehend to analyze and filter content based on - * sentiment, PII, and toxicity. It can be configured to enable/disable specific - * checks and allows for the addition of custom checks. - */ export class ComprehendFilterAgent extends Agent { private comprehendClient: ComprehendClient; private customChecks: CheckFunction[] = []; @@ -57,10 +50,6 @@ export class ComprehendFilterAgent extends Agent { private allowPii: boolean; private languageCode: LanguageCode; - /** - * Constructor for ComprehendContentFilterAgent - * @param options - Configuration options for the agent - */ constructor(options: ComprehendFilterAgentOptions) { super(options); @@ -68,7 +57,6 @@ export class ComprehendFilterAgent extends Agent { ? new ComprehendClient({ region: options.region }) : new ComprehendClient(); - // Set default configuration using fields from options this.enableSentimentCheck = options.enableSentimentCheck ?? true; this.enablePiiCheck = options.enablePiiCheck ?? true; this.enableToxicityCheck = options.enableToxicityCheck ?? true; @@ -77,7 +65,6 @@ export class ComprehendFilterAgent extends Agent { this.allowPii = options.allowPii ?? false; this.languageCode = this.validateLanguageCode(options.languageCode) ?? 'en'; - // Ensure at least one check is enabled if (!this.enableSentimentCheck && !this.enablePiiCheck && !this.enableToxicityCheck) { @@ -105,14 +92,12 @@ export class ComprehendFilterAgent extends Agent { try { const issues: string[] = []; - // Run all checks in parallel const [sentimentResult, piiResult, toxicityResult] = await Promise.all([ this.enableSentimentCheck ? this.detectSentiment(inputText) : null, this.enablePiiCheck ? this.detectPiiEntities(inputText) : null, this.enableToxicityCheck ? this.detectToxicContent(inputText) : null ]); - // Process results if (this.enableSentimentCheck && sentimentResult) { const sentimentIssue = this.checkSentiment(sentimentResult); if (sentimentIssue) issues.push(sentimentIssue); @@ -128,7 +113,6 @@ export class ComprehendFilterAgent extends Agent { if (toxicityIssue) issues.push(toxicityIssue); } - // Run custom checks for (const check of this.customChecks) { const customIssue = await check(inputText); if (customIssue) issues.push(customIssue); @@ -136,10 +120,9 @@ export class ComprehendFilterAgent extends Agent { if (issues.length > 0) { Logger.logger.warn(`Content filter issues detected: ${issues.join('; ')}`); - return null; // Return null to indicate content should not be processed further + return this.createErrorResponse("Content filter issues detected", new Error(issues.join('; '))); } - // If no issues, return the original input as a ConversationMessage return { role: ParticipantRole.ASSISTANT, content: [{ text: inputText }] @@ -147,23 +130,14 @@ export class ComprehendFilterAgent extends Agent { } catch (error) { Logger.logger.error("Error in ComprehendContentFilterAgent:", error); - throw error; + return this.createErrorResponse("An error occurred while processing your request", error); } } - /** - * Add a custom check function to the agent - * @param check - A function that takes a string input and returns a Promise - */ addCustomCheck(check: CheckFunction) { this.customChecks.push(check); } - /** - * Check sentiment of the input text - * @param result - Result from Comprehend's sentiment detection - * @returns A string describing the issue if sentiment is negative, null otherwise - */ private checkSentiment(result: DetectSentimentCommandOutput): string | null { if (result.Sentiment === 'NEGATIVE' && result.SentimentScore?.Negative > this.sentimentThreshold) { @@ -172,11 +146,6 @@ export class ComprehendFilterAgent extends Agent { return null; } - /** - * Check for PII in the input text - * @param result - Result from Comprehend's PII detection - * @returns A string describing the issue if PII is detected, null otherwise - */ private checkPii(result: DetectPiiEntitiesCommandOutput): string | null { if (!this.allowPii && result.Entities && result.Entities.length > 0) { return `PII detected: ${result.Entities.map(e => e.Type).join(', ')}`; @@ -184,11 +153,6 @@ export class ComprehendFilterAgent extends Agent { return null; } - /** - * Check for toxic content in the input text - * @param result - Result from Comprehend's toxic content detection - * @returns A string describing the issue if toxic content is detected, null otherwise - */ private checkToxicity(result: DetectToxicContentCommandOutput): string | null { const toxicLabels = this.getToxicLabels(result); if (toxicLabels.length > 0) { @@ -197,10 +161,6 @@ export class ComprehendFilterAgent extends Agent { return null; } - /** - * Detect sentiment using Amazon Comprehend - * @param text - Input text to analyze - */ private async detectSentiment(text: string) { const command = new DetectSentimentCommand({ Text: text, @@ -209,10 +169,6 @@ export class ComprehendFilterAgent extends Agent { return this.comprehendClient.send(command); } - /** - * Detect PII entities using Amazon Comprehend - * @param text - Input text to analyze - */ private async detectPiiEntities(text: string) { const command = new DetectPiiEntitiesCommand({ Text: text, @@ -221,10 +177,6 @@ export class ComprehendFilterAgent extends Agent { return this.comprehendClient.send(command); } - /** - * Detect toxic content using Amazon Comprehend - * @param text - Input text to analyze - */ private async detectToxicContent(text: string) { const command = new DetectToxicContentCommand({ TextSegments: [{ Text: text }], @@ -233,11 +185,6 @@ export class ComprehendFilterAgent extends Agent { return this.comprehendClient.send(command); } - /** - * Extract toxic labels from the Comprehend response - * @param toxicityResult - Result from Comprehend's toxic content detection - * @returns Array of toxic label names that exceed the threshold - */ private getToxicLabels(toxicityResult: DetectToxicContentCommandOutput): string[] { const toxicLabels: string[] = []; @@ -256,10 +203,6 @@ export class ComprehendFilterAgent extends Agent { return toxicLabels; } - /** - * Set the language code for Comprehend operations - * @param languageCode - The ISO 639-1 language code - */ setLanguageCode(languageCode: LanguageCode): void { const validatedLanguageCode = this.validateLanguageCode(languageCode); if (validatedLanguageCode) { @@ -269,18 +212,13 @@ export class ComprehendFilterAgent extends Agent { } } - /** - * Validate the provided language code - * @param languageCode - The language code to validate - * @returns The validated LanguageCode or undefined if invalid - */ private validateLanguageCode(languageCode: LanguageCode | undefined): LanguageCode | undefined { if (!languageCode) return undefined; - + const validLanguageCodes: LanguageCode[] = [ 'en', 'es', 'fr', 'de', 'it', 'pt', 'ar', 'hi', 'ja', 'ko', 'zh', 'zh-TW' ]; - + return validLanguageCodes.includes(languageCode) ? languageCode : undefined; } } \ No newline at end of file diff --git a/typescript/src/agents/lambdaAgent.ts b/typescript/src/agents/lambdaAgent.ts index dba82459..73260d7d 100644 --- a/typescript/src/agents/lambdaAgent.ts +++ b/typescript/src/agents/lambdaAgent.ts @@ -1,62 +1,85 @@ -import { ConversationMessage, ParticipantRole } from "../types"; +import { ConversationMessage, ParticipantRole } from "../types"; import { Agent, AgentOptions } from "./agent"; import { LambdaClient, InvokeCommand } from "@aws-sdk/client-lambda"; +import { Logger } from "../utils/logger"; export interface LambdaAgentOptions extends AgentOptions { - functionName: string; - functionRegion: string; - inputPayloadEncoder?: (inputText: string, ...additionalParams: any) => any; - outputPayloadDecoder?: (response: any) => ConversationMessage; + functionName: string; + functionRegion: string; + inputPayloadEncoder?: (inputText: string, ...additionalParams: any) => any; + outputPayloadDecoder?: (response: any) => ConversationMessage; } export class LambdaAgent extends Agent { - private options: LambdaAgentOptions; - private lambdaClient: LambdaClient; + private options: LambdaAgentOptions; + private lambdaClient: LambdaClient; - constructor(options: LambdaAgentOptions) { - super(options); - this.options = options; - this.lambdaClient = new LambdaClient({region:this.options.functionRegion}); - } + constructor(options: LambdaAgentOptions) { + super(options); + this.options = options; + this.lambdaClient = new LambdaClient({region:this.options.functionRegion}); + } - private defaultInputPayloadEncoder(inputText: string, chatHistory: ConversationMessage[], userId: string, sessionId:string, additionalParams?: Record):string { - return JSON.stringify({ - query: inputText, - chatHistory: chatHistory, - additionalParams: additionalParams, - userId: userId, - sessionId: sessionId, - }); - } + private defaultInputPayloadEncoder(inputText: string, chatHistory: ConversationMessage[], userId: string, sessionId:string, additionalParams?: Record):string { + return JSON.stringify({ + query: inputText, + chatHistory: chatHistory, + additionalParams: additionalParams, + userId: userId, + sessionId: sessionId, + }); + } - private defaultOutputPayloaderDecoder(response: any): ConversationMessage { - const decodedResponse = JSON.parse(JSON.parse(new TextDecoder("utf-8").decode(response.Payload)).body).response; - const message: ConversationMessage = { - role: ParticipantRole.ASSISTANT, - content: [{ text: `${decodedResponse}` }] - }; - return message; - } + private defaultOutputPayloaderDecoder(response: any): ConversationMessage { + const decodedResponse = JSON.parse(JSON.parse(new TextDecoder("utf-8").decode(response.Payload)).body).response; + const message: ConversationMessage = { + role: ParticipantRole.ASSISTANT, + content: [{ text: `${decodedResponse}` }] + }; + return message; + } + +/** + * Processes a user request by sending it to the Amazon Bedrock agent for processing. + * @param inputText - The user input as a string. + * @param userId - The ID of the user sending the request. + * @param sessionId - The ID of the session associated with the conversation. + * @param chatHistory - An array of Message objects representing the conversation history. + * @param additionalParams - Optional additional parameters as key-value pairs. + * @returns A Promise that resolves to a Message object containing the agent's response. + */ +/* eslint-disable @typescript-eslint/no-unused-vars */ + async processRequest( + inputText: string, + userId: string, + sessionId: string, + chatHistory: ConversationMessage[], + additionalParams?: Record + ): Promise { + try { + const payload = this.options.inputPayloadEncoder + ? this.options.inputPayloadEncoder(inputText, chatHistory, userId, sessionId, additionalParams) + : this.defaultInputPayloadEncoder(inputText, chatHistory, userId, sessionId, additionalParams); + + const invokeParams = { + FunctionName: this.options.functionName, + Payload: payload, + }; - async processRequest( - inputText: string, - userId: string, - sessionId: string, - chatHistory: ConversationMessage[], - additionalParams?: Record - ): Promise{ - - const payload = this.options.inputPayloadEncoder ? this.options.inputPayloadEncoder(inputText, chatHistory, userId, sessionId, additionalParams) : this.defaultInputPayloadEncoder(inputText, chatHistory, userId, sessionId, additionalParams); - const invokeParams = { - FunctionName: this.options.functionName, - Payload: payload, - }; - - const response = await this.lambdaClient.send(new InvokeCommand(invokeParams)); - - return new Promise((resolve) => { - const message = this.options.outputPayloadDecoder ? this.options.outputPayloadDecoder(response) : this.defaultOutputPayloaderDecoder(response); - resolve(message); - }); + const response = await this.lambdaClient.send(new InvokeCommand(invokeParams)); + + if (response.FunctionError) { + throw new Error(`Lambda function returned an error: ${response.FunctionError}`); } -} + + const message = this.options.outputPayloadDecoder + ? this.options.outputPayloadDecoder(response) + : this.defaultOutputPayloaderDecoder(response); + + return message; + } catch (error) { + Logger.logger.error(`Error in LambdaAgent.processRequest for function ${this.options.functionName}:`, error); + return this.createErrorResponse("An error occurred while processing your request with the Lambda function.", error); + } + } +} \ No newline at end of file diff --git a/typescript/src/agents/lexBotAgent.ts b/typescript/src/agents/lexBotAgent.ts index 6b9ac1b3..3a44fa1e 100644 --- a/typescript/src/agents/lexBotAgent.ts +++ b/typescript/src/agents/lexBotAgent.ts @@ -37,7 +37,6 @@ export class LexBotAgent extends Agent { this.botId = options.botId; this.botAliasId = options.botAliasId; this.localeId = options.localeId; - // Validate required fields if (!this.botId || !this.botAliasId || !this.localeId) { throw new Error("botId, botAliasId, and localeId are required for LexBotAgent"); @@ -45,13 +44,13 @@ export class LexBotAgent extends Agent { } /** - * Process a request to the Lex Bot. - * @param inputText - The user's input text - * @param userId - The ID of the user - * @param sessionId - The ID of the current session - * @param chatHistory - The history of the conversation - * @param additionalParams - Any additional parameters to include - * @returns A Promise resolving to a ConversationMessage containing the bot's response + * Processes a user request by sending it to the Amazon Bedrock agent for processing. + * @param inputText - The user input as a string. + * @param userId - The ID of the user sending the request. + * @param sessionId - The ID of the session associated with the conversation. + * @param chatHistory - An array of Message objects representing the conversation history. + * @param additionalParams - Optional additional parameters as key-value pairs. + * @returns A Promise that resolves to a Message object containing the agent's response. */ /* eslint-disable @typescript-eslint/no-unused-vars */ async processRequest( @@ -93,9 +92,9 @@ export class LexBotAgent extends Agent { content: [{ text: concatenatedContent || "No response from Lex bot." }], }; } catch (error) { - // Log the error and re-throw it - Logger.logger.error("Error processing request:", error); - throw error; + // Log the error and return an error response + Logger.logger.error("Error processing request in LexBotAgent:", error); + return this.createErrorResponse("An error occurred while processing your request with the Lex bot.", error); } } } \ No newline at end of file diff --git a/typescript/src/agents/openAIAgent.ts b/typescript/src/agents/openAIAgent.ts index 51785d3c..fa9bb741 100644 --- a/typescript/src/agents/openAIAgent.ts +++ b/typescript/src/agents/openAIAgent.ts @@ -36,47 +36,55 @@ export class OpenAIAgent extends Agent { this.inferenceConfig = { maxTokens: options.inferenceConfig?.maxTokens ?? DEFAULT_MAX_TOKENS, temperature: options.inferenceConfig?.temperature, - topP: options.inferenceConfig?.topP, + topP: options.inferenceConfig?.topP, stopSequences: options.inferenceConfig?.stopSequences, }; } - /* eslint-disable @typescript-eslint/no-unused-vars */ - async processRequest( +/** + * Processes a user request by sending it to the Amazon Bedrock agent for processing. + * @param inputText - The user input as a string. + * @param userId - The ID of the user sending the request. + * @param sessionId - The ID of the session associated with the conversation. + * @param chatHistory - An array of Message objects representing the conversation history. + * @param additionalParams - Optional additional parameters as key-value pairs. + * @returns A Promise that resolves to a Message object containing the agent's response. + */ +/* eslint-disable @typescript-eslint/no-unused-vars */ async processRequest( inputText: string, userId: string, sessionId: string, chatHistory: ConversationMessage[], additionalParams?: Record ): Promise> { - - - const messages = [ - ...chatHistory.map(msg => ({ - role: msg.role.toLowerCase() as OpenAI.Chat.ChatCompletionMessageParam['role'], - content: msg.content[0]?.text || '' - })), - { role: 'user' as const, content: inputText } - ] as OpenAI.Chat.ChatCompletionMessageParam[]; - - const { maxTokens, temperature, topP, stopSequences } = this.inferenceConfig; - - const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = { - model: this.model, - messages: messages, - max_tokens: maxTokens, - stream: this.streaming, - temperature, - top_p: topP, - stop: stopSequences, - }; - - + try { + const messages = [ + ...chatHistory.map(msg => ({ + role: msg.role.toLowerCase() as OpenAI.Chat.ChatCompletionMessageParam['role'], + content: msg.content[0]?.text || '' + })), + { role: 'user' as const, content: inputText } + ] as OpenAI.Chat.ChatCompletionMessageParam[]; + + const { maxTokens, temperature, topP, stopSequences } = this.inferenceConfig; + const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = { + model: this.model, + messages: messages, + max_tokens: maxTokens, + stream: this.streaming, + temperature, + top_p: topP, + stop: stopSequences, + }; - if (this.streaming) { - return this.handleStreamingResponse(requestOptions); - } else { - return this.handleSingleResponse(requestOptions); + if (this.streaming) { + return this.handleStreamingResponse(requestOptions); + } else { + return this.handleSingleResponse(requestOptions); + } + } catch (error) { + Logger.logger.error('Error in OpenAIAgent.processRequest:', error); + return this.createErrorResponse("An error occurred while processing your request.", error); } } @@ -84,41 +92,35 @@ export class OpenAIAgent extends Agent { try { const nonStreamingOptions = { ...input, stream: false }; const chatCompletion = await this.openai.chat.completions.create(nonStreamingOptions); - if (!chatCompletion.choices || chatCompletion.choices.length === 0) { throw new Error('No choices returned from OpenAI API'); } - const assistantMessage = chatCompletion.choices[0]?.message?.content; - if (typeof assistantMessage !== 'string') { throw new Error('Unexpected response format from OpenAI API'); } - return { role: ParticipantRole.ASSISTANT, content: [{ text: assistantMessage }], }; } catch (error) { - Logger.logger.error('Error in OpenAI API call:', error); - return { - role: ParticipantRole.ASSISTANT, - content: [{ text: 'I encountered an error while processing your request.' }], - }; + Logger.logger.error('Error in OpenAIAgent.handleSingleResponse:', error); + return this.createErrorResponse("An error occurred while processing your request with the OpenAI API.", error); } } private async *handleStreamingResponse(options: OpenAI.Chat.ChatCompletionCreateParams): AsyncIterable { - const stream = await this.openai.chat.completions.create({ ...options, stream: true }); - for await (const chunk of stream) { - const content = chunk.choices[0]?.delta?.content; - if (content) { - yield content; + try { + const stream = await this.openai.chat.completions.create({ ...options, stream: true }); + for await (const chunk of stream) { + const content = chunk.choices[0]?.delta?.content; + if (content) { + yield content; + } } + } catch (error) { + Logger.logger.error('Error in OpenAIAgent.handleStreamingResponse:', error); + yield this.createErrorResponse("An error occurred while streaming the response from the OpenAI API.", error).content[0].text; } } - - - - } \ No newline at end of file