diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py
index dad3fc335edf..7cd7fdb92a35 100644
--- a/autogen/agentchat/contrib/capabilities/transforms.py
+++ b/autogen/agentchat/contrib/capabilities/transforms.py
@@ -53,13 +53,16 @@ class MessageHistoryLimiter:
     It trims the conversation history by removing older messages, retaining only the most recent messages.
     """

-    def __init__(self, max_messages: Optional[int] = None):
+    def __init__(self, max_messages: Optional[int] = None, keep_first_message: bool = False):
         """
         Args:
             max_messages Optional[int]: Maximum number of messages to keep in the context. Must be greater than 0 if not None.
+            keep_first_message bool: Whether to keep the original first message in the conversation history.
+                Defaults to False.
         """
         self._validate_max_messages(max_messages)
         self._max_messages = max_messages
+        self._keep_first_message = keep_first_message

     def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         """Truncates the conversation history to the specified maximum number of messages.
@@ -75,10 +78,31 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
             List[Dict]: A new list containing the most recent messages up to the specified maximum.
         """

-        if self._max_messages is None:
+        if self._max_messages is None or len(messages) <= self._max_messages:
             return messages

-        return messages[-self._max_messages :]
+        truncated_messages = []
+        remaining_count = self._max_messages
+
+        # Start with the first message if we need to keep it
+        if self._keep_first_message:
+            truncated_messages = [messages[0]]
+            remaining_count -= 1
+
+        # Loop through messages in reverse, filling the remaining slots
+        for i in range(len(messages) - 1, 0, -1):
+            if remaining_count > 1:
+                truncated_messages.insert(1 if self._keep_first_message else 0, messages[i])
+            if remaining_count == 1:
+                # If there's only 1 slot left and it's a 'tool' message, ignore it.
+                if messages[i].get("role") != "tool":
+                    truncated_messages.insert(1 if self._keep_first_message else 0, messages[i])
+
+            remaining_count -= 1
+            if remaining_count == 0:
+                break
+
+        return truncated_messages

     def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
         pre_transform_messages_len = len(pre_transform_messages)
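The new `keep_first_message` branch changes which end of the history survives truncation. A minimal sketch of the intended behavior follows (the history contents are placeholders, and the output comments describe intent rather than exact reprs):

```python
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter

history = [
    {"role": "user", "content": "first task description"},
    {"role": "assistant", "content": "reply 1"},
    {"role": "user", "content": "follow-up"},
    {"role": "assistant", "content": "reply 2"},
]

# Default behavior: keep only the most recent messages, in order.
print(MessageHistoryLimiter(max_messages=2).apply_transform(history))
# expected: the "follow-up" and "reply 2" messages

# keep_first_message pins history[0], then fills the remaining budget
# with the most recent messages.
print(MessageHistoryLimiter(max_messages=2, keep_first_message=True).apply_transform(history))
# expected: the "first task description" and "reply 2" messages
```

When the last remaining slot would be filled by a `"tool"` message it is dropped instead, presumably so a tool response is never kept without the assistant tool call that precedes it.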
diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index a088c491082e..9254ef57de08 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -377,9 +377,9 @@ def replace_reply_func(self, old_reply_func: Callable, new_reply_func: Callable)
         f["reply_func"] = new_reply_func

     @staticmethod
-    def _summary_from_nested_chats(
+    def _get_chats_to_run(
         chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
-    ) -> Tuple[bool, str]:
+    ) -> List[Dict[str, Any]]:
         """A simple chat reply function.
         This function initiates one or a sequence of chats between the "recipient" and the agents in the
         chat_queue.
@@ -406,22 +406,59 @@ def _summary_from_nested_chats(
         if message:
             current_c["message"] = message
         chat_to_run.append(current_c)
+        return chat_to_run
+
+    @staticmethod
+    def _summary_from_nested_chats(
+        chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
+    ) -> Tuple[bool, Union[str, None]]:
+        """A simple chat reply function.
+        This function initiates one or a sequence of chats between the "recipient" and the agents in the
+        chat_queue.
+
+        It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
+
+        Returns:
+            Tuple[bool, Union[str, None]]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
+        """
+        chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
         if not chat_to_run:
             return True, None
         res = initiate_chats(chat_to_run)
         return True, res[-1].summary

+    @staticmethod
+    async def _a_summary_from_nested_chats(
+        chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
+    ) -> Tuple[bool, Union[str, None]]:
+        """A simple chat reply function.
+        This function initiates one or a sequence of chats between the "recipient" and the agents in the
+        chat_queue.
+
+        It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
+
+        Returns:
+            Tuple[bool, Union[str, None]]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
+        """
+        chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
+        if not chat_to_run:
+            return True, None
+        res = await a_initiate_chats(chat_to_run)
+        index_of_last_chat = chat_to_run[-1]["chat_id"]
+        return True, res[index_of_last_chat].summary
+
     def register_nested_chats(
         self,
         chat_queue: List[Dict[str, Any]],
         trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
         reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats",
         position: int = 2,
+        use_async: Union[bool, None] = None,
         **kwargs,
     ) -> None:
         """Register a nested chat reply function.

         Args:
-            chat_queue (list): a list of chat objects to be initiated.
+            chat_queue (list): a list of chat objects to be initiated. If use_async is True, then all chats in chat_queue must have a chat_id associated with them.
             trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details.
             reply_func_from_nested_chats (Callable, str): the reply function for the nested chat.
                 The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message.
@@ -436,15 +473,33 @@ def reply_func_from_nested_chats(
             ) -> Tuple[bool, Union[str, Dict, None]]:
             ```
             position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply.
+            use_async: Uses a_initiate_chats internally to start nested chats. If the original chat is initiated with a_initiate_chats, you may set this to True so that nested chats do not run synchronously.
             kwargs: Ref to `register_reply` for details.
         """
-        if reply_func_from_nested_chats == "summary_from_nested_chats":
-            reply_func_from_nested_chats = self._summary_from_nested_chats
-        if not callable(reply_func_from_nested_chats):
-            raise ValueError("reply_func_from_nested_chats must be a callable")
+        if use_async:
+            for chat in chat_queue:
+                if chat.get("chat_id") is None:
+                    raise ValueError("chat_id is required for async nested chats")
+
+        if use_async:
+            if reply_func_from_nested_chats == "summary_from_nested_chats":
+                reply_func_from_nested_chats = self._a_summary_from_nested_chats
+            if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction(
+                reply_func_from_nested_chats
+            ):
+                raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine")

-        def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
-            return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
+            async def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
+                return await reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
+
+        else:
+            if reply_func_from_nested_chats == "summary_from_nested_chats":
+                reply_func_from_nested_chats = self._summary_from_nested_chats
+            if not callable(reply_func_from_nested_chats):
+                raise ValueError("reply_func_from_nested_chats must be a callable")
+
+            def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
+                return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)

         functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats)

@@ -454,7 +509,9 @@ def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
             position,
             kwargs.get("config"),
             kwargs.get("reset_config"),
-            ignore_async_in_sync_chat=kwargs.get("ignore_async_in_sync_chat"),
+            ignore_async_in_sync_chat=(
+                not use_async if use_async is not None else kwargs.get("ignore_async_in_sync_chat")
+            ),
         )

     @property
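A minimal sketch of driving the new `use_async` path end to end; the agent names, auto-replies, `chat_id` value, and `max_turns` limits are illustrative placeholders, not part of the API:

```python
import asyncio

from autogen import ConversableAgent


async def main():
    writer = ConversableAgent("writer", default_auto_reply="Here is a draft.", human_input_mode="NEVER")
    reviewer = ConversableAgent("reviewer", default_auto_reply="Reviewing.", human_input_mode="NEVER")
    user = ConversableAgent("user", default_auto_reply="Thanks.", human_input_mode="NEVER")

    # use_async=True requires every chat in the queue to carry a chat_id;
    # the summary of the last chat becomes the reviewer's reply to the user.
    reviewer.register_nested_chats(
        chat_queue=[{"chat_id": 1, "recipient": writer, "summary_method": "last_msg", "max_turns": 1}],
        trigger=user,
        use_async=True,
    )

    await user.a_initiate_chat(reviewer, message="Please review my draft.", max_turns=2)


asyncio.run(main())
```

Note the `ignore_async_in_sync_chat` computation above: with `use_async=True` the async wrapper is registered so that a synchronous outer chat is meant to fail fast rather than silently skip the nested reply.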
""" - if reply_func_from_nested_chats == "summary_from_nested_chats": - reply_func_from_nested_chats = self._summary_from_nested_chats - if not callable(reply_func_from_nested_chats): - raise ValueError("reply_func_from_nested_chats must be a callable") + if use_async: + for chat in chat_queue: + if chat.get("chat_id") is None: + raise ValueError("chat_id is required for async nested chats") + + if use_async: + if reply_func_from_nested_chats == "summary_from_nested_chats": + reply_func_from_nested_chats = self._a_summary_from_nested_chats + if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction( + reply_func_from_nested_chats + ): + raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine") - def wrapped_reply_func(recipient, messages=None, sender=None, config=None): - return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) + async def wrapped_reply_func(recipient, messages=None, sender=None, config=None): + return await reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) + + else: + if reply_func_from_nested_chats == "summary_from_nested_chats": + reply_func_from_nested_chats = self._summary_from_nested_chats + if not callable(reply_func_from_nested_chats): + raise ValueError("reply_func_from_nested_chats must be a callable") + + def wrapped_reply_func(recipient, messages=None, sender=None, config=None): + return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats) @@ -454,7 +509,9 @@ def wrapped_reply_func(recipient, messages=None, sender=None, config=None): position, kwargs.get("config"), kwargs.get("reset_config"), - ignore_async_in_sync_chat=kwargs.get("ignore_async_in_sync_chat"), + ignore_async_in_sync_chat=( + not use_async if use_async is not None else kwargs.get("ignore_async_in_sync_chat") + ), ) @property diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index 41e1f7865224..bde31637688c 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -986,6 +986,7 @@ def __init__( # Store groupchat self._groupchat = groupchat + self._last_speaker = None self._silent = silent # Order of register_reply is important. @@ -1027,6 +1028,52 @@ def _prepare_chat( if (recipient != agent or prepare_recipient) and isinstance(agent, ConversableAgent): agent._prepare_chat(self, clear_history, False, reply_at_receive) + @property + def last_speaker(self) -> Agent: + """Return the agent who sent the last message to group chat manager. + + In a group chat, an agent will always send a message to the group chat manager, and the group chat manager will + send the message to all other agents in the group chat. So, when an agent receives a message, it will always be + from the group chat manager. With this property, the agent receiving the message can know who actually sent the + message. 
+
+        Example:
+        ```python
+        from autogen import ConversableAgent
+        from autogen import GroupChat, GroupChatManager
+
+
+        def print_messages(recipient, messages, sender, config):
+            # Print the message immediately
+            print(
+                f"Sender: {sender.name} | Recipient: {recipient.name} | Message: {messages[-1].get('content')}"
+            )
+            print(f"Real Sender: {sender.last_speaker.name}")
+            assert sender.last_speaker.name in messages[-1].get("content")
+            return False, None  # Required to ensure the agent communication flow continues
+
+
+        agent_a = ConversableAgent("agent A", default_auto_reply="I'm agent A.")
+        agent_b = ConversableAgent("agent B", default_auto_reply="I'm agent B.")
+        agent_c = ConversableAgent("agent C", default_auto_reply="I'm agent C.")
+        for agent in [agent_a, agent_b, agent_c]:
+            agent.register_reply(
+                [ConversableAgent, None], reply_func=print_messages, config=None
+            )
+        group_chat = GroupChat(
+            [agent_a, agent_b, agent_c],
+            messages=[],
+            max_round=6,
+            speaker_selection_method="random",
+            allow_repeat_speaker=True,
+        )
+        chat_manager = GroupChatManager(group_chat)
+        groupchat_result = agent_a.initiate_chat(
+            chat_manager, message="Hi, there, I'm agent A."
+        )
+        ```
+        """
+        return self._last_speaker
+
     def run_chat(
         self,
         messages: Optional[List[Dict]] = None,
@@ -1055,6 +1102,7 @@ def run_chat(
                 a.previous_cache = a.client_cache
                 a.client_cache = self.client_cache
         for i in range(groupchat.max_round):
+            self._last_speaker = speaker
             groupchat.append(message, speaker)
             # broadcast the message to all agents except the speaker
             for agent in groupchat.agents:
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 4cc7c697f738..fb13afdfcc63 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -457,10 +457,13 @@ def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[st
     def _configure_openai_config_for_bedrock(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None:
         """Update openai_config with AWS credentials from config."""
         required_keys = ["aws_access_key", "aws_secret_key", "aws_region"]
-
+        optional_keys = ["aws_session_token"]
         for key in required_keys:
             if key in config:
                 openai_config[key] = config[key]
+        for key in optional_keys:
+            if key in config:
+                openai_config[key] = config[key]

     def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None:
         """Create a client with the given config to override openai_config,
diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln
index 1218cf129821..0fcaf15ceb2a 100644
--- a/dotnet/AutoGen.sln
+++ b/dotnet/AutoGen.sln
@@ -68,6 +68,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.Sample", "sa
 EndProject
 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI.Sample", "sample\AutoGen.WebAPI.Sample\AutoGen.WebAPI.Sample.csproj", "{12079C18-A519-403F-BBFD-200A36A0C083}"
 EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.AzureAIInference", "src\AutoGen.AzureAIInference\AutoGen.AzureAIInference.csproj", "{5C45981D-1319-4C25-935C-83D411CB28DF}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.AzureAIInference.Tests", "test\AutoGen.AzureAIInference.Tests\AutoGen.AzureAIInference.Tests.csproj", "{5970868F-831E-418F-89A9-4EC599563E16}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Tests.Share", "test\AutoGen.Test.Share\AutoGen.Tests.Share.csproj", "{143725E2-206C-4D37-93E4-9EDF699826B2}"
+EndProject
 Global
 	GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU @@ -194,6 +200,18 @@ Global {12079C18-A519-403F-BBFD-200A36A0C083}.Debug|Any CPU.Build.0 = Debug|Any CPU {12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.ActiveCfg = Release|Any CPU {12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.Build.0 = Release|Any CPU + {5C45981D-1319-4C25-935C-83D411CB28DF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {5C45981D-1319-4C25-935C-83D411CB28DF}.Debug|Any CPU.Build.0 = Debug|Any CPU + {5C45981D-1319-4C25-935C-83D411CB28DF}.Release|Any CPU.ActiveCfg = Release|Any CPU + {5C45981D-1319-4C25-935C-83D411CB28DF}.Release|Any CPU.Build.0 = Release|Any CPU + {5970868F-831E-418F-89A9-4EC599563E16}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {5970868F-831E-418F-89A9-4EC599563E16}.Debug|Any CPU.Build.0 = Debug|Any CPU + {5970868F-831E-418F-89A9-4EC599563E16}.Release|Any CPU.ActiveCfg = Release|Any CPU + {5970868F-831E-418F-89A9-4EC599563E16}.Release|Any CPU.Build.0 = Release|Any CPU + {143725E2-206C-4D37-93E4-9EDF699826B2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {143725E2-206C-4D37-93E4-9EDF699826B2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {143725E2-206C-4D37-93E4-9EDF699826B2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {143725E2-206C-4D37-93E4-9EDF699826B2}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -229,6 +247,9 @@ Global {6B82F26D-5040-4453-B21B-C8D1F913CE4C} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {0E635268-351C-4A6B-A28D-593D868C2CA4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9} {12079C18-A519-403F-BBFD-200A36A0C083} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9} + {5C45981D-1319-4C25-935C-83D411CB28DF} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {5970868F-831E-418F-89A9-4EC599563E16} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {143725E2-206C-4D37-93E4-9EDF699826B2} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B} diff --git a/dotnet/eng/Version.props b/dotnet/eng/Version.props index c78ce4b415fc..d90e8bc76c80 100644 --- a/dotnet/eng/Version.props +++ b/dotnet/eng/Version.props @@ -15,6 +15,7 @@ 8.0.4 3.0.0 4.3.0.2 + 1.0.0-beta.1 7.4.4 \ No newline at end of file diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs index a1e110bcc6a5..b087beb993bc 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs @@ -16,7 +16,7 @@ public async Task CodeSnippet1() #region code_snippet_1_1 var kernel = DotnetInteractiveKernelBuilder - .CreateDefaultBuilder() // add C# and F# kernels + .CreateDefaultInProcessKernelBuilder() // add C# and F# kernels .Build(); #endregion code_snippet_1_1 @@ -67,7 +67,7 @@ public async Task CodeSnippet1() #region code_snippet_1_4 var pythonKernel = DotnetInteractiveKernelBuilder - .CreateDefaultBuilder() + .CreateDefaultInProcessKernelBuilder() .AddPythonKernel(venv: "python3") .Build(); diff --git a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs index c9e8a0cab155..08419d436e03 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs @@ -16,7 +16,7 @@ public 
static async Task RunAsync() var instance = new Example04_Dynamic_GroupChat_Coding_Task(); var kernel = DotnetInteractiveKernelBuilder - .CreateDefaultBuilder() + .CreateDefaultInProcessKernelBuilder() .AddPythonKernel("python3") .Build(); diff --git a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs index d78bb7656ae6..cc9b2a80a340 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs @@ -233,7 +233,9 @@ public static async Task CreateReviewerAgentAsync(OpenAIClient openAICli public static async Task RunWorkflowAsync() { long the39thFibonacciNumber = 63245986; - var kernel = DotnetInteractiveKernelBuilder.CreateDefaultBuilder().Build(); + var kernel = DotnetInteractiveKernelBuilder + .CreateDefaultInProcessKernelBuilder() + .Build(); var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo(); var openaiClient = new OpenAIClient(new Uri(config.Endpoint), new Azure.AzureKeyCredential(config.ApiKey)); @@ -344,7 +346,9 @@ public static async Task RunAsync() var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo(); var openaiClient = new OpenAIClient(new Uri(config.Endpoint), new Azure.AzureKeyCredential(config.ApiKey)); - var kernel = DotnetInteractiveKernelBuilder.CreateDefaultBuilder().Build(); + var kernel = DotnetInteractiveKernelBuilder + .CreateDefaultInProcessKernelBuilder() + .Build(); #region create_group_chat var reviewer = await CreateReviewerAgentAsync(openaiClient, config.DeploymentName); var coder = await CreateCoderAgentAsync(openaiClient, config.DeploymentName); diff --git a/dotnet/src/AutoGen.AzureAIInference/Agent/ChatCompletionsClientAgent.cs b/dotnet/src/AutoGen.AzureAIInference/Agent/ChatCompletionsClientAgent.cs new file mode 100644 index 000000000000..452c5b1c3079 --- /dev/null +++ b/dotnet/src/AutoGen.AzureAIInference/Agent/ChatCompletionsClientAgent.cs @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatCompletionsClientAgent.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.AzureAIInference.Extension; +using AutoGen.Core; +using Azure.AI.Inference; + +namespace AutoGen.AzureAIInference; + +/// +/// ChatCompletions client agent. This agent is a thin wrapper around to provide a simple interface for chat completions. +/// supports the following message types: +/// +/// +/// where T is : chat request message. +/// +/// +/// returns the following message types: +/// +/// +/// where T is : chat response message. +/// where T is : streaming chat completions update. +/// +/// +/// +public class ChatCompletionsClientAgent : IStreamingAgent +{ + private readonly ChatCompletionsClient chatCompletionsClient; + private readonly ChatCompletionsOptions options; + private readonly string systemMessage; + + /// + /// Create a new instance of . + /// + /// chat completions client + /// agent name + /// model name. e.g. gpt-turbo-3.5 + /// system message + /// temperature + /// max tokens to generated + /// response format, set it to to enable json mode. 
+ /// seed to use, set it to enable deterministic output + /// functions + public ChatCompletionsClientAgent( + ChatCompletionsClient chatCompletionsClient, + string name, + string modelName, + string systemMessage = "You are a helpful AI assistant", + float temperature = 0.7f, + int maxTokens = 1024, + int? seed = null, + ChatCompletionsResponseFormat? responseFormat = null, + IEnumerable? functions = null) + : this( + chatCompletionsClient: chatCompletionsClient, + name: name, + options: CreateChatCompletionOptions(modelName, temperature, maxTokens, seed, responseFormat, functions), + systemMessage: systemMessage) + { + } + + /// + /// Create a new instance of . + /// + /// chat completions client + /// agent name + /// system message + /// chat completion option. The option can't contain messages + public ChatCompletionsClientAgent( + ChatCompletionsClient chatCompletionsClient, + string name, + ChatCompletionsOptions options, + string systemMessage = "You are a helpful AI assistant") + { + if (options.Messages is { Count: > 0 }) + { + throw new ArgumentException("Messages should not be provided in options"); + } + + this.chatCompletionsClient = chatCompletionsClient; + this.Name = name; + this.options = options; + this.systemMessage = systemMessage; + } + + public string Name { get; } + + public async Task GenerateReplyAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + CancellationToken cancellationToken = default) + { + var settings = this.CreateChatCompletionsOptions(options, messages); + var reply = await this.chatCompletionsClient.CompleteAsync(settings, cancellationToken: cancellationToken); + + return new MessageEnvelope(reply, from: this.Name); + } + + public async IAsyncEnumerable GenerateStreamingReplyAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var settings = this.CreateChatCompletionsOptions(options, messages); + var response = await this.chatCompletionsClient.CompleteStreamingAsync(settings, cancellationToken); + await foreach (var update in response.WithCancellation(cancellationToken)) + { + yield return new MessageEnvelope(update, from: this.Name); + } + } + + private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable messages) + { + var oaiMessages = messages.Select(m => m switch + { + IMessage chatRequestMessage => chatRequestMessage.Content, + _ => throw new ArgumentException("Invalid message type") + }); + + // add system message if there's no system message in messages + if (!oaiMessages.Any(m => m is ChatRequestSystemMessage)) + { + oaiMessages = new[] { new ChatRequestSystemMessage(systemMessage) }.Concat(oaiMessages); + } + + // clone the options by serializing and deserializing + var json = JsonSerializer.Serialize(this.options); + var settings = JsonSerializer.Deserialize(json) ?? throw new InvalidOperationException("Failed to clone options"); + + foreach (var m in oaiMessages) + { + settings.Messages.Add(m); + } + + settings.Temperature = options?.Temperature ?? settings.Temperature; + settings.MaxTokens = options?.MaxToken ?? 
settings.MaxTokens; + + foreach (var functions in this.options.Tools) + { + settings.Tools.Add(functions); + } + + foreach (var stopSequence in this.options.StopSequences) + { + settings.StopSequences.Add(stopSequence); + } + + var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToAzureAIInferenceFunctionDefinition()).ToList(); + if (openAIFunctionDefinitions is { Count: > 0 }) + { + foreach (var f in openAIFunctionDefinitions) + { + settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); + } + } + + if (options?.StopSequence is var sequence && sequence is { Length: > 0 }) + { + foreach (var seq in sequence) + { + settings.StopSequences.Add(seq); + } + } + + return settings; + } + + private static ChatCompletionsOptions CreateChatCompletionOptions( + string modelName, + float temperature = 0.7f, + int maxTokens = 1024, + int? seed = null, + ChatCompletionsResponseFormat? responseFormat = null, + IEnumerable? functions = null) + { + var options = new ChatCompletionsOptions() + { + Model = modelName, + Temperature = temperature, + MaxTokens = maxTokens, + Seed = seed, + ResponseFormat = responseFormat, + }; + + if (functions is not null) + { + foreach (var f in functions) + { + options.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); + } + } + + return options; + } +} diff --git a/dotnet/src/AutoGen.AzureAIInference/AutoGen.AzureAIInference.csproj b/dotnet/src/AutoGen.AzureAIInference/AutoGen.AzureAIInference.csproj new file mode 100644 index 000000000000..e9401bc4bc22 --- /dev/null +++ b/dotnet/src/AutoGen.AzureAIInference/AutoGen.AzureAIInference.csproj @@ -0,0 +1,25 @@ + + + $(PackageTargetFrameworks) + AutoGen.AzureAIInference + + + + + + + AutoGen.AzureAIInference + + Azure AI Inference Intergration for AutoGen. + + + + + + + + + + + + diff --git a/dotnet/src/AutoGen.AzureAIInference/Extension/ChatComptionClientAgentExtension.cs b/dotnet/src/AutoGen.AzureAIInference/Extension/ChatComptionClientAgentExtension.cs new file mode 100644 index 000000000000..8faf29604ed1 --- /dev/null +++ b/dotnet/src/AutoGen.AzureAIInference/Extension/ChatComptionClientAgentExtension.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatComptionClientAgentExtension.cs + +using AutoGen.Core; + +namespace AutoGen.AzureAIInference.Extension; + +public static class ChatComptionClientAgentExtension +{ + /// + /// Register an to the + /// + /// the connector to use. If null, a new instance of will be created. + public static MiddlewareStreamingAgent RegisterMessageConnector( + this ChatCompletionsClientAgent agent, AzureAIInferenceChatRequestMessageConnector? connector = null) + { + if (connector == null) + { + connector = new AzureAIInferenceChatRequestMessageConnector(); + } + + return agent.RegisterStreamingMiddleware(connector); + } + + /// + /// Register an to the where T is + /// + /// the connector to use. If null, a new instance of will be created. + public static MiddlewareStreamingAgent RegisterMessageConnector( + this MiddlewareStreamingAgent agent, AzureAIInferenceChatRequestMessageConnector? 
connector = null) + { + if (connector == null) + { + connector = new AzureAIInferenceChatRequestMessageConnector(); + } + + return agent.RegisterStreamingMiddleware(connector); + } +} diff --git a/dotnet/src/AutoGen.AzureAIInference/Extension/FunctionContractExtension.cs b/dotnet/src/AutoGen.AzureAIInference/Extension/FunctionContractExtension.cs new file mode 100644 index 000000000000..4cd7b3864f95 --- /dev/null +++ b/dotnet/src/AutoGen.AzureAIInference/Extension/FunctionContractExtension.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// FunctionContractExtension.cs + +using System; +using System.Collections.Generic; +using AutoGen.Core; +using Azure.AI.Inference; +using Json.Schema; +using Json.Schema.Generation; + +namespace AutoGen.AzureAIInference.Extension; + +public static class FunctionContractExtension +{ + /// + /// Convert a to a that can be used in gpt funciton call. + /// + /// function contract + /// + public static FunctionDefinition ToAzureAIInferenceFunctionDefinition(this FunctionContract functionContract) + { + var functionDefinition = new FunctionDefinition + { + Name = functionContract.Name, + Description = functionContract.Description, + }; + var requiredParameterNames = new List(); + var propertiesSchemas = new Dictionary(); + var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object); + foreach (var param in functionContract.Parameters ?? []) + { + if (param.Name is null) + { + throw new InvalidOperationException("Parameter name cannot be null"); + } + + var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType))); + if (param.Description != null) + { + schemaBuilder = schemaBuilder.Description(param.Description); + } + + if (param.IsRequired) + { + requiredParameterNames.Add(param.Name); + } + + var schema = schemaBuilder.Build(); + propertiesSchemas[param.Name] = schema; + + } + propertySchemaBuilder = propertySchemaBuilder.Properties(propertiesSchemas); + propertySchemaBuilder = propertySchemaBuilder.Required(requiredParameterNames); + + var option = new System.Text.Json.JsonSerializerOptions() + { + PropertyNamingPolicy = System.Text.Json.JsonNamingPolicy.CamelCase + }; + + functionDefinition.Parameters = BinaryData.FromObjectAsJson(propertySchemaBuilder.Build(), option); + + return functionDefinition; + } +} diff --git a/dotnet/src/AutoGen.AzureAIInference/Middleware/AzureAIInferenceChatRequestMessageConnector.cs b/dotnet/src/AutoGen.AzureAIInference/Middleware/AzureAIInferenceChatRequestMessageConnector.cs new file mode 100644 index 000000000000..9c5d22e2e7e7 --- /dev/null +++ b/dotnet/src/AutoGen.AzureAIInference/Middleware/AzureAIInferenceChatRequestMessageConnector.cs @@ -0,0 +1,302 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AzureAIInferenceChatRequestMessageConnector.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.Core; +using Azure.AI.Inference; + +namespace AutoGen.AzureAIInference; + +/// +/// This middleware converts the incoming to where T is before sending to agent. And converts the output to after receiving from agent. 
+/// Supported are +/// - +/// - +/// - +/// - +/// - +/// - where T is +/// - where TMessage1 is and TMessage2 is +/// +public class AzureAIInferenceChatRequestMessageConnector : IStreamingMiddleware +{ + private bool strictMode = false; + + /// + /// Create a new instance of . + /// + /// If true, will throw an + /// When the message type is not supported. If false, it will ignore the unsupported message type. + public AzureAIInferenceChatRequestMessageConnector(bool strictMode = false) + { + this.strictMode = strictMode; + } + + public string? Name => nameof(AzureAIInferenceChatRequestMessageConnector); + + public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) + { + var chatMessages = ProcessIncomingMessages(agent, context.Messages); + + var reply = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken); + + return PostProcessMessage(reply); + } + + public async IAsyncEnumerable InvokeAsync( + MiddlewareContext context, + IStreamingAgent agent, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var chatMessages = ProcessIncomingMessages(agent, context.Messages); + var streamingReply = agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken); + string? currentToolName = null; + await foreach (var reply in streamingReply) + { + if (reply is IMessage update) + { + if (update.Content.FunctionName is string functionName) + { + currentToolName = functionName; + } + else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate toolCallUpdate && toolCallUpdate.Name is string toolCallName) + { + currentToolName = toolCallName; + } + var postProcessMessage = PostProcessStreamingMessage(update, currentToolName); + if (postProcessMessage != null) + { + yield return postProcessMessage; + } + } + else + { + if (this.strictMode) + { + throw new InvalidOperationException($"Invalid streaming message type {reply.GetType().Name}"); + } + else + { + yield return reply; + } + } + } + } + + public IMessage PostProcessMessage(IMessage message) + { + return message switch + { + IMessage m => PostProcessChatResponseMessage(m.Content, m.From), + IMessage m => PostProcessChatCompletions(m), + _ when strictMode is false => message, + _ => throw new InvalidOperationException($"Invalid return message type {message.GetType().Name}"), + }; + } + + public IMessage? PostProcessStreamingMessage(IMessage update, string? currentToolName) + { + if (update.Content.ContentUpdate is string contentUpdate && string.IsNullOrEmpty(contentUpdate) == false) + { + // text message + return new TextMessageUpdate(Role.Assistant, contentUpdate, from: update.From); + } + else if (update.Content.FunctionName is string functionName) + { + return new ToolCallMessageUpdate(functionName, string.Empty, from: update.From); + } + else if (update.Content.FunctionArgumentsUpdate is string functionArgumentsUpdate && currentToolName is string) + { + return new ToolCallMessageUpdate(currentToolName, functionArgumentsUpdate, from: update.From); + } + else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate tooCallUpdate && currentToolName is string) + { + return new ToolCallMessageUpdate(tooCallUpdate.Name ?? 
currentToolName, tooCallUpdate.ArgumentsUpdate, from: update.From); + } + else + { + return null; + } + } + + private IMessage PostProcessChatCompletions(IMessage message) + { + // throw exception if prompt filter results is not null + if (message.Content.Choices[0].FinishReason == CompletionsFinishReason.ContentFiltered) + { + throw new InvalidOperationException("The content is filtered because its potential risk. Please try another input."); + } + + return PostProcessChatResponseMessage(message.Content.Choices[0].Message, message.From); + } + + private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponseMessage, string? from) + { + var textContent = chatResponseMessage.Content; + if (chatResponseMessage.ToolCalls.Where(tc => tc is ChatCompletionsFunctionToolCall).Any()) + { + var functionToolCalls = chatResponseMessage.ToolCalls + .Where(tc => tc is ChatCompletionsFunctionToolCall) + .Select(tc => (ChatCompletionsFunctionToolCall)tc); + + var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id }); + + return new ToolCallMessage(toolCalls, from) + { + Content = textContent, + }; + } + + if (textContent is string content && !string.IsNullOrEmpty(content)) + { + return new TextMessage(Role.Assistant, content, from); + } + + throw new InvalidOperationException("Invalid ChatResponseMessage"); + } + + public IEnumerable ProcessIncomingMessages(IAgent agent, IEnumerable messages) + { + return messages.SelectMany(m => + { + if (m is IMessage crm) + { + return [crm]; + } + else + { + var chatRequestMessages = m switch + { + TextMessage textMessage => ProcessTextMessage(agent, textMessage), + ImageMessage imageMessage when (imageMessage.From is null || imageMessage.From != agent.Name) => ProcessImageMessage(agent, imageMessage), + MultiModalMessage multiModalMessage when (multiModalMessage.From is null || multiModalMessage.From != agent.Name) => ProcessMultiModalMessage(agent, multiModalMessage), + ToolCallMessage toolCallMessage when (toolCallMessage.From is null || toolCallMessage.From == agent.Name) => ProcessToolCallMessage(agent, toolCallMessage), + ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage), + AggregateMessage aggregateMessage => ProcessFunctionCallMiddlewareMessage(agent, aggregateMessage), + _ when strictMode is false => [], + _ => throw new InvalidOperationException($"Invalid message type: {m.GetType().Name}"), + }; + + if (chatRequestMessages.Any()) + { + return chatRequestMessages.Select(cm => MessageEnvelope.Create(cm, m.From)); + } + else + { + return [m]; + } + } + }); + } + + private IEnumerable ProcessTextMessage(IAgent agent, TextMessage message) + { + if (message.Role == Role.System) + { + return [new ChatRequestSystemMessage(message.Content)]; + } + + if (agent.Name == message.From) + { + return [new ChatRequestAssistantMessage { Content = message.Content }]; + } + else + { + return message.From switch + { + null when message.Role == Role.User => [new ChatRequestUserMessage(message.Content)], + null when message.Role == Role.Assistant => [new ChatRequestAssistantMessage() { Content = message.Content }], + null => throw new InvalidOperationException("Invalid Role"), + _ => [new ChatRequestUserMessage(message.Content)] + }; + } + } + + private IEnumerable ProcessImageMessage(IAgent agent, ImageMessage message) + { + if (agent.Name == message.From) + { + // image message from assistant is not supported + throw new ArgumentException("ImageMessage is not 
supported when message.From is the same with agent"); + } + + var imageContentItem = this.CreateChatMessageImageContentItemFromImageMessage(message); + return [new ChatRequestUserMessage([imageContentItem])]; + } + + private IEnumerable ProcessMultiModalMessage(IAgent agent, MultiModalMessage message) + { + if (agent.Name == message.From) + { + // image message from assistant is not supported + throw new ArgumentException("MultiModalMessage is not supported when message.From is the same with agent"); + } + + IEnumerable items = message.Content.Select(ci => ci switch + { + TextMessage text => new ChatMessageTextContentItem(text.Content), + ImageMessage image => this.CreateChatMessageImageContentItemFromImageMessage(image), + _ => throw new NotImplementedException(), + }); + + return [new ChatRequestUserMessage(items)]; + } + + private ChatMessageImageContentItem CreateChatMessageImageContentItemFromImageMessage(ImageMessage message) + { + return message.Data is null && message.Url is not null + ? new ChatMessageImageContentItem(new Uri(message.Url)) + : new ChatMessageImageContentItem(message.Data, message.Data?.MediaType); + } + + private IEnumerable ProcessToolCallMessage(IAgent agent, ToolCallMessage message) + { + if (message.From is not null && message.From != agent.Name) + { + throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent"); + } + + var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments)); + var textContent = message.GetContent() ?? string.Empty; + var chatRequestMessage = new ChatRequestAssistantMessage() { Content = textContent }; + foreach (var tc in toolCall) + { + chatRequestMessage.ToolCalls.Add(tc); + } + + return [chatRequestMessage]; + } + + private IEnumerable ProcessToolCallResultMessage(ToolCallResultMessage message) + { + return message.ToolCalls + .Where(tc => tc.Result is not null) + .Select((tc, i) => new ChatRequestToolMessage(tc.Result, tc.ToolCallId ?? $"{tc.FunctionName}_{i}")); + } + + private IEnumerable ProcessFunctionCallMiddlewareMessage(IAgent agent, AggregateMessage aggregateMessage) + { + if (aggregateMessage.From is not null && aggregateMessage.From != agent.Name) + { + // convert as user message + var resultMessage = aggregateMessage.Message2; + + return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result)); + } + else + { + var toolCallMessage1 = aggregateMessage.Message1; + var toolCallResultMessage = aggregateMessage.Message2; + + var assistantMessage = this.ProcessToolCallMessage(agent, toolCallMessage1); + var toolCallResults = this.ProcessToolCallResultMessage(toolCallResultMessage); + + return assistantMessage.Concat(toolCallResults); + } + } +} diff --git a/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveKernelBuilder.cs b/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveKernelBuilder.cs index a8f330154922..cc282fbba55c 100644 --- a/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveKernelBuilder.cs +++ b/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveKernelBuilder.cs @@ -1,127 +1,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// DotnetInteractiveKernelBuilder.cs -#if NET8_0_OR_GREATER -using AutoGen.DotnetInteractive.Extension; -using Microsoft.DotNet.Interactive; -using Microsoft.DotNet.Interactive.Commands; -using Microsoft.DotNet.Interactive.CSharp; -using Microsoft.DotNet.Interactive.FSharp; -using Microsoft.DotNet.Interactive.Jupyter; -using Microsoft.DotNet.Interactive.PackageManagement; -using Microsoft.DotNet.Interactive.PowerShell; - namespace AutoGen.DotnetInteractive; -public class DotnetInteractiveKernelBuilder +public static class DotnetInteractiveKernelBuilder { - private readonly CompositeKernel compositeKernel; - private DotnetInteractiveKernelBuilder() +#if NET8_0_OR_GREATER + public static InProccessDotnetInteractiveKernelBuilder CreateEmptyInProcessKernelBuilder() { - this.compositeKernel = new CompositeKernel(); - - // add jupyter connector - this.compositeKernel.AddKernelConnector( - new ConnectJupyterKernelCommand() - .AddConnectionOptions(new JupyterHttpKernelConnectionOptions()) - .AddConnectionOptions(new JupyterLocalKernelConnectionOptions())); + return new InProccessDotnetInteractiveKernelBuilder(); } - /// - /// Create an empty builder. - /// - /// - public static DotnetInteractiveKernelBuilder CreateEmptyBuilder() - { - return new DotnetInteractiveKernelBuilder(); - } - /// - /// Create a default builder with C# and F# kernels. - /// - public static DotnetInteractiveKernelBuilder CreateDefaultBuilder() + public static InProccessDotnetInteractiveKernelBuilder CreateDefaultInProcessKernelBuilder() { - return new DotnetInteractiveKernelBuilder() + return new InProccessDotnetInteractiveKernelBuilder() .AddCSharpKernel() .AddFSharpKernel(); } +#endif - public DotnetInteractiveKernelBuilder AddCSharpKernel(IEnumerable? aliases = null) - { - aliases ??= ["c#", "C#"]; - // create csharp kernel - var csharpKernel = new CSharpKernel() - .UseNugetDirective((k, resolvedPackageReference) => - { - - k.AddAssemblyReferences(resolvedPackageReference - .SelectMany(r => r.AssemblyPaths)); - return Task.CompletedTask; - }) - .UseKernelHelpers() - .UseWho() - .UseMathAndLaTeX() - .UseValueSharing(); - - this.AddKernel(csharpKernel, aliases); - - return this; - } - - public DotnetInteractiveKernelBuilder AddFSharpKernel(IEnumerable? aliases = null) - { - aliases ??= ["f#", "F#"]; - // create fsharp kernel - var fsharpKernel = new FSharpKernel() - .UseDefaultFormatting() - .UseKernelHelpers() - .UseWho() - .UseMathAndLaTeX() - .UseValueSharing(); - - this.AddKernel(fsharpKernel, aliases); - - return this; - } - - public DotnetInteractiveKernelBuilder AddPowershellKernel(IEnumerable? aliases = null) - { - aliases ??= ["pwsh", "powershell"]; - // create powershell kernel - var powershellKernel = new PowerShellKernel() - .UseProfiles() - .UseValueSharing(); - - this.AddKernel(powershellKernel, aliases); - - return this; - } - - public DotnetInteractiveKernelBuilder AddPythonKernel(string venv, string kernelName = "python", IEnumerable? aliases = null) - { - aliases ??= [kernelName]; - // create python kernel - var magicCommand = $"#!connect jupyter --kernel-name {kernelName} --kernel-spec {venv}"; - var connectCommand = new SubmitCode(magicCommand); - var result = this.compositeKernel.SendAsync(connectCommand).Result; - - result.ThrowOnCommandFailed(); - - return this; - } - - public CompositeKernel Build() - { - return this.compositeKernel - .UseDefaultMagicCommands() - .UseImportMagicCommand(); - } - - private DotnetInteractiveKernelBuilder AddKernel(Kernel kernel, IEnumerable? 
aliases = null) + public static DotnetInteractiveStdioKernelConnector CreateKernelBuilder(string workingDirectory, string kernelName = "root-proxy") { - this.compositeKernel.Add(kernel, aliases); - return this; + return new DotnetInteractiveStdioKernelConnector(workingDirectory, kernelName); } } -#endif diff --git a/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveStdioKernelConnector.cs b/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveStdioKernelConnector.cs new file mode 100644 index 000000000000..a3ea80a7b12a --- /dev/null +++ b/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveStdioKernelConnector.cs @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// DotnetInteractiveStdioKernelConnector.cs + +using AutoGen.DotnetInteractive.Extension; +using Microsoft.DotNet.Interactive; +using Microsoft.DotNet.Interactive.Commands; +using Microsoft.DotNet.Interactive.Connection; + +namespace AutoGen.DotnetInteractive; + +public class DotnetInteractiveStdioKernelConnector +{ + private string workingDirectory; + private InteractiveService interactiveService; + private string kernelName; + private List setupCommands = new List(); + + internal DotnetInteractiveStdioKernelConnector(string workingDirectory, string kernelName = "root-proxy") + { + this.workingDirectory = workingDirectory; + this.interactiveService = new InteractiveService(workingDirectory); + this.kernelName = kernelName; + } + + public DotnetInteractiveStdioKernelConnector RestoreDotnetInteractive() + { + if (this.interactiveService.RestoreDotnetInteractive()) + { + return this; + } + else + { + throw new Exception("Failed to restore dotnet interactive tool."); + } + } + + public DotnetInteractiveStdioKernelConnector AddPythonKernel( + string venv, + string kernelName = "python") + { + var magicCommand = $"#!connect jupyter --kernel-name {kernelName} --kernel-spec {venv}"; + var connectCommand = new SubmitCode(magicCommand); + + this.setupCommands.Add(connectCommand); + + return this; + } + + public async Task BuildAsync(CancellationToken ct = default) + { + var compositeKernel = new CompositeKernel(); + var url = KernelHost.CreateHostUri(this.kernelName); + var cmd = new string[] + { + "dotnet", + "tool", + "run", + "dotnet-interactive", + $"[cb-{this.kernelName}]", + "stdio", + //"--default-kernel", + //"csharp", + "--working-dir", + $@"""{workingDirectory}""", + }; + + var connector = new StdIoKernelConnector( + cmd, + this.kernelName, + url, + new DirectoryInfo(this.workingDirectory)); + + var rootProxyKernel = await connector.CreateRootProxyKernelAsync(); + + rootProxyKernel.KernelInfo.SupportedKernelCommands.Add(new(nameof(SubmitCode))); + + var dotnetKernel = await connector.CreateProxyKernelAsync(".NET"); + foreach (var setupCommand in this.setupCommands) + { + var setupCommandResult = await rootProxyKernel.SendAsync(setupCommand, ct); + setupCommandResult.ThrowOnCommandFailed(); + } + + return rootProxyKernel; + } +} diff --git a/dotnet/src/AutoGen.DotnetInteractive/InProccessDotnetInteractiveKernelBuilder.cs b/dotnet/src/AutoGen.DotnetInteractive/InProccessDotnetInteractiveKernelBuilder.cs new file mode 100644 index 000000000000..6ddd3d6b4178 --- /dev/null +++ b/dotnet/src/AutoGen.DotnetInteractive/InProccessDotnetInteractiveKernelBuilder.cs @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// InProccessDotnetInteractiveKernelBuilder.cs + +#if NET8_0_OR_GREATER +using AutoGen.DotnetInteractive.Extension; +using Microsoft.DotNet.Interactive; +using Microsoft.DotNet.Interactive.Commands; +using Microsoft.DotNet.Interactive.CSharp; +using Microsoft.DotNet.Interactive.FSharp; +using Microsoft.DotNet.Interactive.Jupyter; +using Microsoft.DotNet.Interactive.PackageManagement; +using Microsoft.DotNet.Interactive.PowerShell; + +namespace AutoGen.DotnetInteractive; + +/// +/// Build an in-proc dotnet interactive kernel. +/// +public class InProccessDotnetInteractiveKernelBuilder +{ + private readonly CompositeKernel compositeKernel; + + internal InProccessDotnetInteractiveKernelBuilder() + { + this.compositeKernel = new CompositeKernel(); + + // add jupyter connector + this.compositeKernel.AddKernelConnector( + new ConnectJupyterKernelCommand() + .AddConnectionOptions(new JupyterHttpKernelConnectionOptions()) + .AddConnectionOptions(new JupyterLocalKernelConnectionOptions())); + } + + public InProccessDotnetInteractiveKernelBuilder AddCSharpKernel(IEnumerable? aliases = null) + { + aliases ??= ["c#", "C#", "csharp"]; + // create csharp kernel + var csharpKernel = new CSharpKernel() + .UseNugetDirective((k, resolvedPackageReference) => + { + + k.AddAssemblyReferences(resolvedPackageReference + .SelectMany(r => r.AssemblyPaths)); + return Task.CompletedTask; + }) + .UseKernelHelpers() + .UseWho() + .UseMathAndLaTeX() + .UseValueSharing(); + + this.AddKernel(csharpKernel, aliases); + + return this; + } + + public InProccessDotnetInteractiveKernelBuilder AddFSharpKernel(IEnumerable? aliases = null) + { + aliases ??= ["f#", "F#", "fsharp"]; + // create fsharp kernel + var fsharpKernel = new FSharpKernel() + .UseDefaultFormatting() + .UseKernelHelpers() + .UseWho() + .UseMathAndLaTeX() + .UseValueSharing(); + + this.AddKernel(fsharpKernel, aliases); + + return this; + } + + public InProccessDotnetInteractiveKernelBuilder AddPowershellKernel(IEnumerable? aliases = null) + { + aliases ??= ["pwsh", "powershell"]; + // create powershell kernel + var powershellKernel = new PowerShellKernel() + .UseProfiles() + .UseValueSharing(); + + this.AddKernel(powershellKernel, aliases); + + return this; + } + + public InProccessDotnetInteractiveKernelBuilder AddPythonKernel(string venv, string kernelName = "python") + { + // create python kernel + var magicCommand = $"#!connect jupyter --kernel-name {kernelName} --kernel-spec {venv}"; + var connectCommand = new SubmitCode(magicCommand); + var result = this.compositeKernel.SendAsync(connectCommand).Result; + + result.ThrowOnCommandFailed(); + + return this; + } + + public CompositeKernel Build() + { + return this.compositeKernel + .UseDefaultMagicCommands() + .UseImportMagicCommand(); + } + + private InProccessDotnetInteractiveKernelBuilder AddKernel(Kernel kernel, IEnumerable? 
aliases = null) + { + this.compositeKernel.Add(kernel, aliases); + return this; + } +} +#endif diff --git a/dotnet/src/AutoGen/AutoGen.csproj b/dotnet/src/AutoGen/AutoGen.csproj index 3cb5a23da14c..88d9fca19ca2 100644 --- a/dotnet/src/AutoGen/AutoGen.csproj +++ b/dotnet/src/AutoGen/AutoGen.csproj @@ -15,6 +15,8 @@ + + diff --git a/dotnet/test/AutoGen.AzureAIInference.Tests/AutoGen.AzureAIInference.Tests.csproj b/dotnet/test/AutoGen.AzureAIInference.Tests/AutoGen.AzureAIInference.Tests.csproj new file mode 100644 index 000000000000..0eaebd1da0cb --- /dev/null +++ b/dotnet/test/AutoGen.AzureAIInference.Tests/AutoGen.AzureAIInference.Tests.csproj @@ -0,0 +1,16 @@ + + + + $(TestTargetFrameworks) + false + True + True + + + + + + + + + diff --git a/dotnet/test/AutoGen.AzureAIInference.Tests/ChatCompletionClientAgentTests.cs b/dotnet/test/AutoGen.AzureAIInference.Tests/ChatCompletionClientAgentTests.cs new file mode 100644 index 000000000000..d81b8881ac55 --- /dev/null +++ b/dotnet/test/AutoGen.AzureAIInference.Tests/ChatCompletionClientAgentTests.cs @@ -0,0 +1,533 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatCompletionClientAgentTests.cs + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using AutoGen.AzureAIInference.Extension; +using AutoGen.Core; +using AutoGen.Tests; +using Azure.AI.Inference; +using FluentAssertions; +using Xunit; + +namespace AutoGen.AzureAIInference.Tests; + +public partial class ChatCompletionClientAgentTests +{ + /// + /// Get the weather for a location. + /// + /// location + /// + [Function] + public async Task GetWeatherAsync(string location) + { + return $"The weather in {location} is sunny."; + } + + [ApiKeyFact("GH_API_KEY")] + public async Task ChatCompletionAgent_LLaMA3_1() + { + var client = CreateChatCompletionClient(); + var model = "meta-llama-3-8b-instruct"; + + var agent = new ChatCompletionsClientAgent(client, "assistant", model) + .RegisterMessageConnector(); + + var reply = await this.BasicChatAsync(agent); + reply.Should().BeOfType(); + + reply = await this.BasicChatWithContinuousMessageFromSameSenderAsync(agent); + reply.Should().BeOfType(); + } + + [ApiKeyFact("GH_API_KEY")] + public async Task BasicConversation_Mistra_Small() + { + var deployName = "Mistral-small"; + var client = CreateChatCompletionClient(); + var openAIChatAgent = new ChatCompletionsClientAgent( + chatCompletionsClient: client, + name: "assistant", + modelName: deployName); + + // By default, ChatCompletionClientAgent supports the following message types + // - IMessage + var chatMessageContent = MessageEnvelope.Create(new ChatRequestUserMessage("Hello")); + var reply = await openAIChatAgent.SendAsync(chatMessageContent); + + reply.Should().BeOfType>(); + reply.As>().From.Should().Be("assistant"); + reply.As>().Content.Choices.First().Message.Role.Should().Be(ChatRole.Assistant); + reply.As>().Content.Usage.TotalTokens.Should().BeGreaterThan(0); + + // test streaming + var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); + + await foreach (var streamingMessage in streamingReply) + { + streamingMessage.Should().BeOfType>(); + streamingMessage.As>().From.Should().Be("assistant"); + } + } + + [ApiKeyFact("GH_API_KEY")] + public async Task ChatCompletionsMessageContentConnector_Phi3_Mini() + { + var deployName = "Phi-3-mini-4k-instruct"; + var openaiClient = CreateChatCompletionClient(); + var chatCompletionAgent = new 
ChatCompletionsClientAgent(
+            chatCompletionsClient: openaiClient,
+            name: "assistant",
+            modelName: deployName);
+
+        MiddlewareStreamingAgent<ChatCompletionsClientAgent> assistant = chatCompletionAgent
+            .RegisterMessageConnector();
+
+        var messages = new IMessage[]
+        {
+            MessageEnvelope.Create(new ChatRequestUserMessage("Hello")),
+            new TextMessage(Role.Assistant, "Hello", from: "user"),
+            new MultiModalMessage(Role.Assistant,
+                [
+                    new TextMessage(Role.Assistant, "Hello", from: "user"),
+                ],
+                from: "user"),
+        };
+
+        foreach (var message in messages)
+        {
+            var reply = await assistant.SendAsync(message);
+
+            reply.Should().BeOfType<TextMessage>();
+            reply.As<TextMessage>().From.Should().Be("assistant");
+        }
+
+        // test streaming
+        foreach (var message in messages)
+        {
+            var reply = assistant.GenerateStreamingReplyAsync([message]);
+
+            await foreach (var streamingMessage in reply)
+            {
+                streamingMessage.Should().BeOfType<TextMessageUpdate>();
+                streamingMessage.As<TextMessageUpdate>().From.Should().Be("assistant");
+            }
+        }
+    }
+
+    [ApiKeyFact("GH_API_KEY")]
+    public async Task ChatCompletionClientAgentToolCall_Mistral_Nemo()
+    {
+        var deployName = "Mistral-nemo";
+        var chatCompletionClient = CreateChatCompletionClient();
+        var agent = new ChatCompletionsClientAgent(
+            chatCompletionsClient: chatCompletionClient,
+            name: "assistant",
+            modelName: deployName);
+
+        var functionCallMiddleware = new FunctionCallMiddleware(
+            functions: [this.GetWeatherAsyncFunctionContract]);
+        MiddlewareStreamingAgent<ChatCompletionsClientAgent> assistant = agent
+            .RegisterMessageConnector();
+
+        assistant.StreamingMiddlewares.Count().Should().Be(1);
+        var functionCallAgent = assistant
+            .RegisterStreamingMiddleware(functionCallMiddleware);
+
+        var question = "What's the weather in Seattle";
+        var messages = new IMessage[]
+        {
+            MessageEnvelope.Create(new ChatRequestUserMessage(question)),
+            new TextMessage(Role.Assistant, question, from: "user"),
+            new MultiModalMessage(Role.Assistant,
+                [
+                    new TextMessage(Role.Assistant, question, from: "user"),
+                ],
+                from: "user"),
+        };
+
+        foreach (var message in messages)
+        {
+            var reply = await functionCallAgent.SendAsync(message);
+
+            reply.Should().BeOfType<ToolCallMessage>();
+            reply.As<ToolCallMessage>().From.Should().Be("assistant");
+            reply.As<ToolCallMessage>().ToolCalls.Count().Should().Be(1);
+            reply.As<ToolCallMessage>().ToolCalls.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name);
+        }
+
+        // test streaming
+        foreach (var message in messages)
+        {
+            var reply = functionCallAgent.GenerateStreamingReplyAsync([message]);
+            ToolCallMessage? toolCallMessage = null;
+            await foreach (var streamingMessage in reply)
+            {
+                streamingMessage.Should().BeOfType<ToolCallMessageUpdate>();
+                streamingMessage.As<ToolCallMessageUpdate>().From.Should().Be("assistant");
+                if (toolCallMessage is null)
+                {
+                    toolCallMessage = new ToolCallMessage(streamingMessage.As<ToolCallMessageUpdate>());
+                }
+                else
+                {
+                    toolCallMessage.Update(streamingMessage.As<ToolCallMessageUpdate>());
+                }
+            }
+
+            toolCallMessage.Should().NotBeNull();
+            toolCallMessage!.From.Should().Be("assistant");
+            toolCallMessage.ToolCalls.Count().Should().Be(1);
+            toolCallMessage.ToolCalls.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name);
+        }
+    }
+
+    [ApiKeyFact("GH_API_KEY")]
+    public async Task ChatCompletionClientAgentToolCallInvoking_gpt_4o_mini()
+    {
+        var deployName = "gpt-4o-mini";
+        var client = CreateChatCompletionClient();
+        var agent = new ChatCompletionsClientAgent(
+            chatCompletionsClient: client,
+            name: "assistant",
+            modelName: deployName);
+
+        var functionCallMiddleware = new FunctionCallMiddleware(
+            functions: [this.GetWeatherAsyncFunctionContract],
+            functionMap: new Dictionary<string, Func<string, Task<string>>> { { this.GetWeatherAsyncFunctionContract.Name!, this.GetWeatherAsyncWrapper } });
+        MiddlewareStreamingAgent<ChatCompletionsClientAgent> assistant = agent
+            .RegisterMessageConnector();
+
+        var functionCallAgent = assistant
+            .RegisterStreamingMiddleware(functionCallMiddleware);
+
+        var question = "What's the weather in Seattle";
+        var messages = new IMessage[]
+        {
+            MessageEnvelope.Create(new ChatRequestUserMessage(question)),
+            new TextMessage(Role.Assistant, question, from: "user"),
+            new MultiModalMessage(Role.Assistant,
+                [
+                    new TextMessage(Role.Assistant, question, from: "user"),
+                ],
+                from: "user"),
+        };
+
+        foreach (var message in messages)
+        {
+            var reply = await functionCallAgent.SendAsync(message);
+
+            reply.Should().BeOfType<ToolCallAggregateMessage>();
+            reply.From.Should().Be("assistant");
+            reply.GetToolCalls()!.Count().Should().Be(1);
+            reply.GetToolCalls()!.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name);
+            reply.GetContent()!.ToLower().Should().Contain("seattle");
+        }
+
+        // test streaming
+        foreach (var message in messages)
+        {
+            var reply = functionCallAgent.GenerateStreamingReplyAsync([message]);
+            await foreach (var streamingMessage in reply)
+            {
+                if (streamingMessage is not IMessage)
+                {
+                    streamingMessage.Should().BeOfType<ToolCallMessageUpdate>();
+                    streamingMessage.As<ToolCallMessageUpdate>().From.Should().Be("assistant");
+                }
+                else
+                {
+                    streamingMessage.Should().BeOfType<ToolCallAggregateMessage>();
+                    streamingMessage.As<IMessage>().GetContent()!.ToLower().Should().Contain("seattle");
+                }
+            }
+        }
+    }
+
+    [ApiKeyFact("GH_API_KEY")]
+    public async Task ItCreateChatCompletionClientAgentWithChatCompletionOption_AI21_Jamba_Instruct()
+    {
+        var deployName = "AI21-Jamba-Instruct";
+        var chatCompletionsClient = CreateChatCompletionClient();
+        var options = new ChatCompletionsOptions()
+        {
+            Model = deployName,
+            Temperature = 0.7f,
+            MaxTokens = 1,
+        };
+
+        var openAIChatAgent = new ChatCompletionsClientAgent(
+            chatCompletionsClient: chatCompletionsClient,
+            name: "assistant",
+            options: options)
+            .RegisterMessageConnector();
+
+        var respond = await openAIChatAgent.SendAsync("hello");
+        respond.GetContent()?.Should().NotBeNullOrEmpty();
+    }
+
+    [Fact]
+    public async Task ItThrowExceptionWhenChatCompletionOptionContainsMessages()
+    {
+        var client = new ChatCompletionsClient(new Uri("https://dummy.com"), new Azure.AzureKeyCredential("dummy"));
+        var options = new ChatCompletionsOptions([new ChatRequestUserMessage("hi")])
+        {
+            Model = "dummy",
+            Temperature = 0.7f,
+            MaxTokens = 1,
+        };
+
+        var action = () => new ChatCompletionsClientAgent(
+            chatCompletionsClient: client,
+            name: "assistant",
+            options: options)
+            .RegisterMessageConnector();
+
+        action.Should().ThrowExactly<ArgumentException>().WithMessage("Messages should not be provided in options");
+    }
+
+    private ChatCompletionsClient CreateChatCompletionClient()
+    {
+        var apiKey = Environment.GetEnvironmentVariable("GH_API_KEY") ?? throw new Exception("Please set GH_API_KEY environment variable.");
+        var endpoint = "https://models.inference.ai.azure.com";
+        return new ChatCompletionsClient(new Uri(endpoint), new Azure.AzureKeyCredential(apiKey));
+    }
+
+    /// <summary>
+    /// The agent should return a text message based on the chat history.
+    /// </summary>
+    /// <param name="agent"></param>
+    /// <returns></returns>
+    private async Task<IMessage> BasicChatEndWithSelfMessageAsync(IAgent agent)
+    {
+        IMessage[] chatHistory = [
+            new TextMessage(Role.Assistant, "Hello", from: "user"),
+            new TextMessage(Role.Assistant, "Hello", from: "user2"),
+            new TextMessage(Role.Assistant, "Hello", from: "user3"),
+            new TextMessage(Role.Assistant, "Hello", from: agent.Name),
+        ];
+
+        return await agent.GenerateReplyAsync(chatHistory);
+    }
+
+    /// <summary>
+    /// The agent should return a text message based on the chat history.
+    /// </summary>
+    /// <param name="agent"></param>
+    /// <returns></returns>
+    private async Task<IMessage> BasicChatAsync(IAgent agent)
+    {
+        IMessage[] chatHistory = [
+            new TextMessage(Role.Assistant, "Hello", from: agent.Name),
+            new TextMessage(Role.Assistant, "Hello", from: "user"),
+            new TextMessage(Role.Assistant, "Hello", from: "user1"),
+        ];
+
+        return await agent.GenerateReplyAsync(chatHistory);
+    }
+
+    /// <summary>
+    /// The agent should return a text message based on the chat history. This tests generating a reply from continuous messages sent by the same sender.
+    /// </summary>
+    private async Task<IMessage> BasicChatWithContinuousMessageFromSameSenderAsync(IAgent agent)
+    {
+        IMessage[] chatHistory = [
+            new TextMessage(Role.Assistant, "Hello", from: "user"),
+            new TextMessage(Role.Assistant, "Hello", from: "user"),
+            new TextMessage(Role.Assistant, "Hello", from: agent.Name),
+            new TextMessage(Role.Assistant, "Hello", from: agent.Name),
+        ];
+
+        return await agent.GenerateReplyAsync(chatHistory);
+    }
+
+    /// <summary>
+    /// The agent should return a text message based on the chat history.
+    /// </summary>
+    /// <param name="agent"></param>
+    /// <returns></returns>
+    private async Task<IMessage> ImageChatAsync(IAgent agent)
+    {
+        var image = Path.Join("testData", "images", "square.png");
+        var binaryData = File.ReadAllBytes(image);
+        var imageMessage = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData, "image/png"), from: "user");
+
+        IMessage[] chatHistory = [
+            imageMessage,
+            new TextMessage(Role.Assistant, "What's in the picture", from: "user"),
+        ];
+
+        return await agent.GenerateReplyAsync(chatHistory);
+    }
+
+    /// <summary>
+    /// The agent should return a text message based on the chat history. This tests generating a reply from continuous image messages.
+    /// </summary>
+    /// <param name="agent"></param>
+    /// <returns></returns>
+    private async Task<IMessage> MultipleImageChatAsync(IAgent agent)
+    {
+        var image1 = Path.Join("testData", "images", "square.png");
+        var image2 = Path.Join("testData", "images", "background.png");
+        var binaryData1 = File.ReadAllBytes(image1);
+        var binaryData2 = File.ReadAllBytes(image2);
+        var imageMessage1 = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData1, "image/png"), from: "user");
+        var imageMessage2 = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData2, "image/png"), from: "user");
+
+        IMessage[] chatHistory = [
+            imageMessage1,
+            imageMessage2,
+            new TextMessage(Role.Assistant, "What's in the picture", from: "user"),
+        ];
+
+        return await agent.GenerateReplyAsync(chatHistory);
+    }
+
+    /// <summary>
+    /// The agent should return a text message based on the chat history.
+    /// </summary>
+    /// <param name="agent"></param>
+    /// <returns></returns>
+    private async Task<IMessage> MultiModalChatAsync(IAgent agent)
+    {
+        var image = Path.Join("testData", "images", "square.png");
+        var binaryData = File.ReadAllBytes(image);
+        var question = "What's in the picture";
+        var imageMessage = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData, "image/png"), from: "user");
+        var textMessage = new TextMessage(Role.Assistant, question, from: "user");
+
+        IMessage[] chatHistory = [
+            new MultiModalMessage(Role.Assistant, [imageMessage, textMessage], from: "user"),
+        ];
+
+        return await agent.GenerateReplyAsync(chatHistory);
+    }
+
+    /// <summary>
+    /// The agent should return a tool call message based on the chat history.
+    /// </summary>
+    /// <param name="agent"></param>
+    /// <returns></returns>
+    private async Task<IMessage> ToolCallChatAsync(IAgent agent)
+    {
+        var question = "What's the weather in Seattle";
+        var messages = new IMessage[]
+        {
+            new TextMessage(Role.Assistant, question, from: "user"),
+        };
+
+        return await agent.GenerateReplyAsync(messages);
+    }
+
+    /// <summary>
+    /// The agent should throw an exception because the tool call result is not available.
+    /// </summary>
+    private async Task<IMessage> ToolCallFromSelfChatAsync(IAgent agent)
+    {
+        var question = "What's the weather in Seattle";
+        var messages = new IMessage[]
+        {
+            new TextMessage(Role.Assistant, question, from: "user"),
+            new ToolCallMessage("GetWeatherAsync", "Seattle", from: agent.Name),
+        };
+
+        return await agent.GenerateReplyAsync(messages);
+    }
+
+    /// <summary>
+    /// Mimics the follow-up chat after a tool call. The agent should return a text message based on the tool call result.
+    /// </summary>
+    private async Task<IMessage> ToolCallWithResultChatAsync(IAgent agent)
+    {
+        var question = "What's the weather in Seattle";
+        var messages = new IMessage[]
+        {
+            new TextMessage(Role.Assistant, question, from: "user"),
+            new ToolCallMessage("GetWeatherAsync", "Seattle", from: "user"),
+            new ToolCallResultMessage("sunny", "GetWeatherAsync", "Seattle", from: agent.Name),
+        };
+
+        return await agent.GenerateReplyAsync(messages);
+    }
+
+    /// <summary>
+    /// The agent should return a text message based on the tool call result.
+    /// </summary>
+    /// <param name="agent"></param>
+    /// <returns></returns>
+    private async Task<IMessage> AggregateToolCallFromSelfChatAsync(IAgent agent)
+    {
+        var textMessage = new TextMessage(Role.Assistant, "What's the weather in Seattle", from: "user");
+        var toolCallMessage = new ToolCallMessage("GetWeatherAsync", "Seattle", from: agent.Name);
+        var toolCallResultMessage = new ToolCallResultMessage("sunny", "GetWeatherAsync", "Seattle", from: agent.Name);
+        var aggregateToolCallMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, from: agent.Name);
+
+        var messages = new IMessage[]
+        {
+            textMessage,
+            aggregateToolCallMessage,
+        };
+
+        return await agent.GenerateReplyAsync(messages);
+    }
+
+    /// <summary>
+    /// The agent should return a text message based on the tool call result. Because the aggregate tool call message comes from another agent, it is treated as an ordinary text message.
+    /// </summary>
+    private async Task<IMessage> AggregateToolCallFromOtherChatWithContinuousMessageAsync(IAgent agent)
+    {
+        var textMessage = new TextMessage(Role.Assistant, "What's the weather in Seattle", from: "user");
+        var toolCallMessage = new ToolCallMessage("GetWeatherAsync", "Seattle", from: "other");
+        var toolCallResultMessage = new ToolCallResultMessage("sunny", "GetWeatherAsync", "Seattle", from: "other");
+        var aggregateToolCallMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, "other");
+
+        var messages = new IMessage[]
+        {
+            textMessage,
+            aggregateToolCallMessage,
+        };
+
+        return await agent.GenerateReplyAsync(messages);
+    }
+
+    /// <summary>
+    /// The agent should throw an exception because a tool call message from another agent is not allowed.
+    /// </summary>
+    private async Task<IMessage> ToolCallMessaageFromOtherChatAsync(IAgent agent)
+    {
+        var textMessage = new TextMessage(Role.Assistant, "What's the weather in Seattle", from: "user");
+        var toolCallMessage = new ToolCallMessage("GetWeatherAsync", "Seattle", from: "other");
+
+        var messages = new IMessage[]
+        {
+            textMessage,
+            toolCallMessage,
+        };
+
+        return await agent.GenerateReplyAsync(messages);
+    }
+
+    /// <summary>
+    /// The agent should throw an exception because a multi-modal message from self is not allowed.
+    /// </summary>
+    /// <param name="agent"></param>
+    /// <returns></returns>
+    private async Task<IMessage> MultiModalMessageFromSelfChatAsync(IAgent agent)
+    {
+        var image = Path.Join("testData", "images", "square.png");
+        var binaryData = File.ReadAllBytes(image);
+        var question = "What's in the picture";
+        var imageMessage = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData, "image/png"), from: agent.Name);
+        var textMessage = new TextMessage(Role.Assistant, question, from: agent.Name);
+
+        IMessage[] chatHistory = [
+            new MultiModalMessage(Role.Assistant, [imageMessage, textMessage], from: agent.Name),
+        ];
+
+        return await agent.GenerateReplyAsync(chatHistory);
+    }
+}
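
The tests above pin down the end-to-end behavior of the new `ChatCompletionsClientAgent`. As a reader's aid, here is a minimal, illustrative sketch (not part of the patch) of the same wiring the tests use; it assumes the `AutoGen.AzureAIInference` package and the GitHub Models endpoint used by `CreateChatCompletionClient` above.

```csharp
// Illustrative sketch, not part of the patch: the agent wiring exercised above.
using AutoGen.AzureAIInference;
using AutoGen.AzureAIInference.Extension;
using AutoGen.Core;
using Azure.AI.Inference;

var apiKey = Environment.GetEnvironmentVariable("GH_API_KEY")
    ?? throw new Exception("Please set GH_API_KEY environment variable.");
var client = new ChatCompletionsClient(
    new Uri("https://models.inference.ai.azure.com"),
    new Azure.AzureKeyCredential(apiKey));

// RegisterMessageConnector translates AutoGen message types (TextMessage,
// MultiModalMessage, ...) into Azure.AI.Inference ChatRequestMessage objects.
var assistant = new ChatCompletionsClientAgent(
        chatCompletionsClient: client,
        name: "assistant",
        modelName: "gpt-4o-mini")
    .RegisterMessageConnector();

var reply = await assistant.SendAsync("Hello");
Console.WriteLine(reply.GetContent());
```
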
diff --git a/dotnet/test/AutoGen.AzureAIInference.Tests/ChatRequestMessageTests.cs b/dotnet/test/AutoGen.AzureAIInference.Tests/ChatRequestMessageTests.cs
new file mode 100644
index 000000000000..d6e5c5283932
--- /dev/null
+++ b/dotnet/test/AutoGen.AzureAIInference.Tests/ChatRequestMessageTests.cs
@@ -0,0 +1,568 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ChatRequestMessageTests.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
+using System.Text.Json;
+using System.Threading.Tasks;
+using AutoGen.Core;
+using AutoGen.Tests;
+using Azure.AI.Inference;
+using FluentAssertions;
+using Xunit;
+
+namespace AutoGen.AzureAIInference.Tests;
+
+public class ChatRequestMessageTests
+{
+    private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
+    {
+        WriteIndented = true,
+        IgnoreReadOnlyProperties = false,
+    };
+
+    [Fact]
+    public async Task ItProcessUserTextMessageAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                var innerMessage = msgs.Last();
+                innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+                chatRequestMessage.Content.Should().Be("Hello");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // user message
+        IMessage message = new TextMessage(Role.User, "Hello", "user");
+        await agent.GenerateReplyAsync([message]);
+    }
+
+    [Fact]
+    public async Task ItShortcutChatRequestMessageAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                var innerMessage = msgs.Last();
+                innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestUserMessage>>();
+
+                var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestUserMessage>)innerMessage!).Content;
+                chatRequestMessage.Content.Should().Be("hello");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // user message
+        var userMessage = new ChatRequestUserMessage("hello");
+        var chatRequestMessage = MessageEnvelope.Create(userMessage);
+        await agent.GenerateReplyAsync([chatRequestMessage]);
+    }
+
+    [Fact]
+    public async Task ItShortcutMessageWhenStrictModelIsFalseAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                var innerMessage = msgs.Last();
+                innerMessage!.Should().BeOfType<MessageEnvelope<string>>();
+
+                var chatRequestMessage = ((MessageEnvelope<string>)innerMessage!).Content;
+                chatRequestMessage.Should().Be("hello");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // user message
+        var userMessage = "hello";
+        var chatRequestMessage = MessageEnvelope.Create(userMessage);
+        await agent.GenerateReplyAsync([chatRequestMessage]);
+    }
+
+    [Fact]
+    public async Task ItThrowExceptionWhenStrictModeIsTrueAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector(true);
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(middleware);
+
+        // user message
+        var userMessage = "hello";
+        var chatRequestMessage = MessageEnvelope.Create(userMessage);
+        Func<Task> action = async () => await agent.GenerateReplyAsync([chatRequestMessage]);
+
+        await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid message type: MessageEnvelope`1");
+    }
+
+    [Fact]
+    public async Task ItProcessAssistantTextMessageAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                var innerMessage = msgs.Last();
+                innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+                chatRequestMessage.Content.Should().Be("How can I help you?");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // assistant message
+        IMessage message = new TextMessage(Role.Assistant, "How can I help you?", "assistant");
+        await agent.GenerateReplyAsync([message]);
+    }
+
+    [Fact]
+    public async Task ItProcessSystemTextMessageAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                var innerMessage = msgs.Last();
+                innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                var chatRequestMessage = (ChatRequestSystemMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+                chatRequestMessage.Content.Should().Be("You are a helpful AI assistant");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // system message
+        IMessage message = new TextMessage(Role.System, "You are a helpful AI assistant");
+        await agent.GenerateReplyAsync([message]);
+    }
+
+    [Fact]
+    public async Task ItProcessImageMessageAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                var innerMessage = msgs.Last();
+                innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+                chatRequestMessage.Content.Should().BeNullOrEmpty();
+                chatRequestMessage.MultimodalContentItems.Count().Should().Be(1);
+                chatRequestMessage.MultimodalContentItems.First().Should().BeOfType<ChatMessageImageContentItem>();
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // user message
+        IMessage message = new ImageMessage(Role.User, "https://example.com/image.png", "user");
+        await agent.GenerateReplyAsync([message]);
+    }
+
+    [Fact]
+    public async Task ItThrowExceptionWhenProcessingImageMessageFromSelfAndStrictModeIsTrueAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector(true);
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(middleware);
+
+        var imageMessage = new ImageMessage(Role.Assistant, "https://example.com/image.png", "assistant");
+        Func<Task> action = async () => await agent.GenerateReplyAsync([imageMessage]);
+
+        await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid message type: ImageMessage");
+    }
+
+    [Fact]
+    public async Task ItProcessMultiModalMessageAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                var innerMessage = msgs.Last();
+                innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+                chatRequestMessage.Content.Should().BeNullOrEmpty();
+                chatRequestMessage.MultimodalContentItems.Count().Should().Be(2);
+                chatRequestMessage.MultimodalContentItems.First().Should().BeOfType<ChatMessageTextContentItem>();
+                chatRequestMessage.MultimodalContentItems.Last().Should().BeOfType<ChatMessageImageContentItem>();
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // user message
+        IMessage message = new MultiModalMessage(
+            Role.User,
+            [
+                new TextMessage(Role.User, "Hello", "user"),
+                new ImageMessage(Role.User, "https://example.com/image.png", "user"),
+            ], "user");
+        await agent.GenerateReplyAsync([message]);
+    }
+
+    [Fact]
+    public async Task ItThrowExceptionWhenProcessingMultiModalMessageFromSelfAndStrictModeIsTrueAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector(true);
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(middleware);
+
+        var multiModalMessage = new MultiModalMessage(
+            Role.Assistant,
+            [
+                new TextMessage(Role.User, "Hello", "assistant"),
+                new ImageMessage(Role.User, "https://example.com/image.png", "assistant"),
+            ], "assistant");
+
+        Func<Task> action = async () => await agent.GenerateReplyAsync([multiModalMessage]);
+
+        await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid message type: MultiModalMessage");
+    }
+
+    [Fact]
+    public async Task ItProcessToolCallMessageAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                var innerMessage = msgs.Last();
+                innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+                chatRequestMessage.ToolCalls.Count().Should().Be(1);
+                chatRequestMessage.Content.Should().Be("textContent");
+                chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
+                var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First();
+                functionToolCall.Name.Should().Be("test");
+                functionToolCall.Id.Should().Be("test");
+                functionToolCall.Arguments.Should().Be("test");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // user message
+        IMessage message = new ToolCallMessage("test", "test", "assistant")
+        {
+            Content = "textContent",
+        };
+        await agent.GenerateReplyAsync([message]);
+    }
+
+    [Fact]
+    public async Task ItProcessParallelToolCallMessageAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                var innerMessage = msgs.Last();
+                innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+                chatRequestMessage.Content.Should().BeNullOrEmpty();
+                chatRequestMessage.ToolCalls.Count().Should().Be(2);
+                for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
+                {
+                    chatRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
+                    var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.ElementAt(i);
+                    functionToolCall.Name.Should().Be("test");
+                    functionToolCall.Id.Should().Be($"test_{i}");
+                    functionToolCall.Arguments.Should().Be("test");
+                }
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // user message
+        var toolCalls = new[]
+        {
+            new ToolCall("test", "test"),
+            new ToolCall("test", "test"),
+        };
+        IMessage message = new ToolCallMessage(toolCalls, "assistant");
+        await agent.GenerateReplyAsync([message]);
+    }
+
+    [Fact]
+    public async Task ItThrowExceptionWhenProcessingToolCallMessageFromUserAndStrictModeIsTrueAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector(strictMode: true);
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(middleware);
+
+        var toolCallMessage = new ToolCallMessage("test", "test", "user");
+        Func<Task> action = async () => await agent.GenerateReplyAsync([toolCallMessage]);
+        await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid message type: ToolCallMessage");
+    }
+
+    [Fact]
+    public async Task ItProcessToolCallResultMessageAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                var innerMessage = msgs.Last();
+                innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+                chatRequestMessage.Content.Should().Be("result");
+                chatRequestMessage.ToolCallId.Should().Be("test");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // user message
+        IMessage message = new ToolCallResultMessage("result", "test", "test", "user");
+        await agent.GenerateReplyAsync([message]);
+    }
+
+    [Fact]
+    public async Task ItProcessParallelToolCallResultMessageAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                msgs.Count().Should().Be(2);
+
+                for (int i = 0; i < msgs.Count(); i++)
+                {
+                    var innerMessage = msgs.ElementAt(i);
+                    innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                    var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+                    chatRequestMessage.Content.Should().Be("result");
+                    chatRequestMessage.ToolCallId.Should().Be($"test_{i}");
+                }
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // user message
+        var toolCalls = new[]
+        {
+            new ToolCall("test", "test", "result"),
+            new ToolCall("test", "test", "result"),
+        };
+        IMessage message = new ToolCallResultMessage(toolCalls, "user");
+        await agent.GenerateReplyAsync([message]);
+    }
+
+    [Fact]
+    public async Task ItProcessFunctionCallMiddlewareMessageFromUserAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                msgs.Count().Should().Be(1);
+                var innerMessage = msgs.Last();
+                innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+                chatRequestMessage.Content.Should().Be("result");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // user message
+        var toolCallMessage = new ToolCallMessage("test", "test", "user");
+        var toolCallResultMessage = new ToolCallResultMessage("result", "test", "test", "user");
+        var aggregateMessage = new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(toolCallMessage, toolCallResultMessage, "user");
+        await agent.GenerateReplyAsync([aggregateMessage]);
+    }
+
+    [Fact]
+    public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                msgs.Count().Should().Be(2);
+                var innerMessage = msgs.Last();
+                innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+                chatRequestMessage.Content.Should().Be("result");
+                chatRequestMessage.ToolCallId.Should().Be("test");
+
+                var toolCallMessage = msgs.First();
+                toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)toolCallMessage!).Content;
+                toolCallRequestMessage.Content.Should().BeNullOrEmpty();
+                toolCallRequestMessage.ToolCalls.Count().Should().Be(1);
+                toolCallRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
+                var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.First();
+                functionToolCall.Name.Should().Be("test");
+                functionToolCall.Id.Should().Be("test");
+                functionToolCall.Arguments.Should().Be("test");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // user message
+        var toolCallMessage = new ToolCallMessage("test", "test", "assistant");
+        var toolCallResultMessage = new ToolCallResultMessage("result", "test", "test", "assistant");
+        var aggregateMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, "assistant");
+        await agent.GenerateReplyAsync([aggregateMessage]);
+    }
+
+    [Fact]
+    public async Task ItProcessParallelFunctionCallMiddlewareMessageFromAssistantAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+            {
+                msgs.Count().Should().Be(3);
+                var toolCallMessage = msgs.First();
+                toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)toolCallMessage!).Content;
+                toolCallRequestMessage.Content.Should().BeNullOrEmpty();
+                toolCallRequestMessage.ToolCalls.Count().Should().Be(2);
+
+                for (int i = 0; i < toolCallRequestMessage.ToolCalls.Count(); i++)
+                {
+                    toolCallRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
+                    var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.ElementAt(i);
+                    functionToolCall.Name.Should().Be("test");
+                    functionToolCall.Id.Should().Be($"test_{i}");
+                    functionToolCall.Arguments.Should().Be("test");
+                }
+
+                for (int i = 1; i < msgs.Count(); i++)
+                {
+                    var toolCallResultMessage = msgs.ElementAt(i);
+                    toolCallResultMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+                    var toolCallResultRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)toolCallResultMessage!).Content;
+                    toolCallResultRequestMessage.Content.Should().Be("result");
+                    toolCallResultRequestMessage.ToolCallId.Should().Be($"test_{i - 1}");
+                }
+
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(middleware);
+
+        // user message
+        var toolCalls = new[]
+        {
+            new ToolCall("test", "test", "result"),
+            new ToolCall("test", "test", "result"),
+        };
+        var toolCallMessage = new ToolCallMessage(toolCalls, "assistant");
+        var toolCallResultMessage = new ToolCallResultMessage(toolCalls, "assistant");
+        var aggregateMessage = new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(toolCallMessage, toolCallResultMessage, "assistant");
+        await agent.GenerateReplyAsync([aggregateMessage]);
+    }
+
+    [Fact]
+    public async Task ItConvertChatResponseMessageToTextMessageAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(middleware);
+
+        // text message
+        var textMessage = CreateInstance<ChatResponseMessage>(ChatRole.Assistant, "hello");
+        var chatRequestMessage = MessageEnvelope.Create(textMessage);
+
+        var message = await agent.GenerateReplyAsync([chatRequestMessage]);
+        message.Should().BeOfType<TextMessage>();
+        message.GetContent().Should().Be("hello");
+        message.GetRole().Should().Be(Role.Assistant);
+    }
+
+    [Fact]
+    public async Task ItConvertChatResponseMessageToToolCallMessageAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(middleware);
+
+        // tool call message
+        var toolCallMessage = CreateInstance<ChatResponseMessage>(ChatRole.Assistant, "textContent", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new Dictionary<string, BinaryData>());
+        var chatRequestMessage = MessageEnvelope.Create(toolCallMessage);
+        var message = await agent.GenerateReplyAsync([chatRequestMessage]);
+        message.Should().BeOfType<ToolCallMessage>();
+        message.GetToolCalls()!.Count().Should().Be(1);
+        message.GetToolCalls()!.First().FunctionName.Should().Be("test");
+        message.GetToolCalls()!.First().FunctionArguments.Should().Be("test");
+        message.GetContent().Should().Be("textContent");
+    }
+
+    [Fact]
+    public async Task ItReturnOriginalMessageWhenStrictModeIsFalseAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(middleware);
+
+        // text message
+        var textMessage = "hello";
+        var messageToSend = MessageEnvelope.Create(textMessage);
+
+        var message = await agent.GenerateReplyAsync([messageToSend]);
+        message.Should().BeOfType<MessageEnvelope<string>>();
+    }
+
+    [Fact]
+    public async Task ItThrowInvalidOperationExceptionWhenStrictModeIsTrueAsync()
+    {
+        var middleware = new AzureAIInferenceChatRequestMessageConnector(true);
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(middleware);
+
+        // text message
+        var textMessage = new ChatRequestUserMessage("hello");
+        var messageToSend = MessageEnvelope.Create(textMessage);
+        Func<Task> action = async () => await agent.GenerateReplyAsync([messageToSend]);
+
+        await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid return message type MessageEnvelope`1");
+    }
+
+    [Fact]
+    public void ToOpenAIChatRequestMessageShortCircuitTest()
+    {
+        var agent = new EchoAgent("assistant");
+        var middleware = new AzureAIInferenceChatRequestMessageConnector();
+        ChatRequestMessage[] messages =
+        [
+            new ChatRequestUserMessage("Hello"),
+            new ChatRequestAssistantMessage()
+            {
+                Content = "How can I help you?",
+            },
+            new ChatRequestSystemMessage("You are a helpful AI assistant"),
+            new ChatRequestToolMessage("test", "test"),
+        ];
+
+        foreach (var oaiMessage in messages)
+        {
+            IMessage message = new MessageEnvelope<ChatRequestMessage>(oaiMessage);
+            var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]);
+            oaiMessages.Count().Should().Be(1);
+            //oaiMessages.First().Should().BeOfType<IMessage<ChatRequestMessage>>();
+            if (oaiMessages.First() is IMessage<ChatRequestMessage> chatRequestMessage)
+            {
+                chatRequestMessage.Content.Should().Be(oaiMessage);
+            }
+            else
+            {
+                // fail the test
+                Assert.True(false);
+            }
+        }
+    }
+
+    private static T CreateInstance<T>(params object[] args)
+    {
+        var type = typeof(T);
+        var instance = type.Assembly.CreateInstance(
+            type.FullName!, false,
+            BindingFlags.Instance | BindingFlags.NonPublic,
+            null, args, null, null);
+        return (T)instance!;
+    }
+}
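
The connector middleware covered by these tests can also be composed by hand. Here is a minimal, illustrative sketch (not part of the patch) mirroring the test setup, using the shared `EchoAgent` test double from `AutoGen.Tests.Share`; the `strictMode` parameter name is taken from the tests above.

```csharp
// Illustrative sketch, not part of the patch: composing the connector manually.
using AutoGen.AzureAIInference;
using AutoGen.Core;
using AutoGen.Tests;

// In strict mode the connector throws for message types it cannot convert,
// instead of silently passing them through to the inner agent.
var middleware = new AzureAIInferenceChatRequestMessageConnector(strictMode: true);
var agent = new EchoAgent("assistant")
    .RegisterMiddleware(middleware);

var reply = await agent.GenerateReplyAsync([new TextMessage(Role.User, "Hello", from: "user")]);
```
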
diff --git a/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveServiceTest.cs b/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveServiceTest.cs
index 0e36053c45e1..2e215a65332f 100644
--- a/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveServiceTest.cs
+++ b/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveServiceTest.cs
@@ -1,82 +1,82 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // DotnetInteractiveServiceTest.cs
-using FluentAssertions;
-using Xunit;
-using Xunit.Abstractions;
+//using FluentAssertions;
+//using Xunit;
+//using Xunit.Abstractions;
 
-namespace AutoGen.DotnetInteractive.Tests;
+//namespace AutoGen.DotnetInteractive.Tests;
 
-public class DotnetInteractiveServiceTest : IDisposable
-{
-    private ITestOutputHelper _output;
-    private InteractiveService _interactiveService;
-    private string _workingDir;
+//public class DotnetInteractiveServiceTest : IDisposable
+//{
+//    private ITestOutputHelper _output;
+//    private InteractiveService _interactiveService;
+//    private string _workingDir;
 
-    public DotnetInteractiveServiceTest(ITestOutputHelper output)
-    {
-        _output = output;
-        _workingDir = Path.Combine(Path.GetTempPath(), "test", Path.GetRandomFileName());
-        if (!Directory.Exists(_workingDir))
-        {
-            Directory.CreateDirectory(_workingDir);
-        }
+//    public DotnetInteractiveServiceTest(ITestOutputHelper output)
+//    {
+//        _output = output;
+//        _workingDir = Path.Combine(Path.GetTempPath(), "test", Path.GetRandomFileName());
+//        if (!Directory.Exists(_workingDir))
+//        {
+//            Directory.CreateDirectory(_workingDir);
+//        }
 
-        _interactiveService = new InteractiveService(_workingDir);
-        _interactiveService.StartAsync(_workingDir, default).Wait();
-    }
+//        _interactiveService = new InteractiveService(_workingDir);
+//        _interactiveService.StartAsync(_workingDir, default).Wait();
+//    }
 
-    public void Dispose()
-    {
-        _interactiveService.Dispose();
-    }
+//    public void Dispose()
+//    {
+//        _interactiveService.Dispose();
+//    }
 
-    [Fact]
-    public async Task ItRunCSharpCodeSnippetTestsAsync()
-    {
-        var cts = new CancellationTokenSource();
-        var isRunning = await _interactiveService.StartAsync(_workingDir, cts.Token);
+//    [Fact]
+//    public async Task ItRunCSharpCodeSnippetTestsAsync()
+//    {
+//        var cts = new CancellationTokenSource();
+//        var isRunning = await _interactiveService.StartAsync(_workingDir, cts.Token);
 
-        isRunning.Should().BeTrue();
+//        isRunning.Should().BeTrue();
 
-        _interactiveService.IsRunning().Should().BeTrue();
+//        _interactiveService.IsRunning().Should().BeTrue();
 
-        // test code snippet
-        var hello_world = @"
-Console.WriteLine(""hello world"");
-";
+//        // test code snippet
+//        var hello_world = @"
+//Console.WriteLine(""hello world"");
+//";
 
-        await this.TestCSharpCodeSnippet(_interactiveService, hello_world, "hello world");
-        await this.TestCSharpCodeSnippet(
-            _interactiveService,
-            code: @"
-Console.WriteLine(""hello world""
-",
-            expectedOutput: "Error: (2,32): error CS1026: ) expected");
+//        await this.TestCSharpCodeSnippet(_interactiveService, hello_world, "hello world");
+//        await this.TestCSharpCodeSnippet(
+//            _interactiveService,
+//            code: @"
+//Console.WriteLine(""hello world""
+//",
+//            expectedOutput: "Error: (2,32): error CS1026: ) expected");
 
-        await this.TestCSharpCodeSnippet(
-            service: _interactiveService,
-            code: "throw new Exception();",
-            expectedOutput: "Error: System.Exception: Exception of type 'System.Exception' was thrown");
-    }
+//        await this.TestCSharpCodeSnippet(
+//            service: _interactiveService,
+//            code: "throw new Exception();",
+//            expectedOutput: "Error: System.Exception: Exception of type 'System.Exception' was thrown");
+//    }
 
-    [Fact]
-    public async Task ItRunPowershellScriptTestsAsync()
-    {
-        // test power shell
-        var ps = @"Write-Output ""hello world""";
-        await this.TestPowershellCodeSnippet(_interactiveService, ps, "hello world");
-    }
+//    [Fact]
+//    public async Task ItRunPowershellScriptTestsAsync()
+//    {
+//        // test power shell
+//        var ps = @"Write-Output ""hello world""";
+//        await this.TestPowershellCodeSnippet(_interactiveService, ps, "hello world");
+//    }
 
-    private async Task TestPowershellCodeSnippet(InteractiveService service, string code, string expectedOutput)
-    {
-        var result = await service.SubmitPowershellCodeAsync(code, CancellationToken.None);
-        result.Should().StartWith(expectedOutput);
-    }
+//    private async Task TestPowershellCodeSnippet(InteractiveService service, string code, string expectedOutput)
+//    {
+//        var result = await service.SubmitPowershellCodeAsync(code, CancellationToken.None);
+//        result.Should().StartWith(expectedOutput);
+//    }
 
-    private async Task TestCSharpCodeSnippet(InteractiveService service, string code, string expectedOutput)
-    {
-        var result = await service.SubmitCSharpCodeAsync(code, CancellationToken.None);
-        result.Should().StartWith(expectedOutput);
-    }
-}
+//    private async Task TestCSharpCodeSnippet(InteractiveService service, string code, string expectedOutput)
+//    {
+//        var result = await service.SubmitCSharpCodeAsync(code, CancellationToken.None);
+//        result.Should().StartWith(expectedOutput);
+//    }
+//}
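
The new connector tests that follow exercise a stdio-connected dotnet-interactive kernel. A minimal, illustrative sketch (not part of the patch) of the same builder calls the fixture uses:

```csharp
// Illustrative sketch, not part of the patch: building a stdio-connected
// dotnet-interactive kernel, as the new test fixture below does.
using System;
using System.IO;
using AutoGen.DotnetInteractive;
using AutoGen.DotnetInteractive.Extension;

var workingDir = Path.Combine(Path.GetTempPath(), "kernel-demo");
Directory.CreateDirectory(workingDir);

var kernel = await DotnetInteractiveKernelBuilder
    .CreateKernelBuilder(workingDir)
    .RestoreDotnetInteractive()
    .AddPythonKernel("python3")
    .BuildAsync();

// Kernel names are lower-case identifiers: "csharp", "fsharp", "pwsh", "python".
var result = await kernel.RunSubmitCodeCommandAsync("Console.WriteLine(\"hi\");", "csharp");
Console.WriteLine(result);
```
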
diff --git a/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveStdioKernelConnectorTests.cs b/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveStdioKernelConnectorTests.cs
new file mode 100644
index 000000000000..6bc361c72513
--- /dev/null
+++ b/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveStdioKernelConnectorTests.cs
@@ -0,0 +1,80 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// DotnetInteractiveStdioKernelConnectorTests.cs
+
+using AutoGen.DotnetInteractive.Extension;
+using FluentAssertions;
+using Microsoft.DotNet.Interactive;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace AutoGen.DotnetInteractive.Tests;
+
+[Collection("Sequential")]
+public class DotnetInteractiveStdioKernelConnectorTests
+{
+    private string _workingDir;
+    private Kernel kernel;
+    public DotnetInteractiveStdioKernelConnectorTests(ITestOutputHelper output)
+    {
+        _workingDir = Path.Combine(Path.GetTempPath(), "test", Path.GetRandomFileName());
+        if (!Directory.Exists(_workingDir))
+        {
+            Directory.CreateDirectory(_workingDir);
+        }
+
+        kernel = DotnetInteractiveKernelBuilder
+            .CreateKernelBuilder(_workingDir)
+            .RestoreDotnetInteractive()
+            .AddPythonKernel("python3")
+            .BuildAsync().Result;
+    }
+
+
+    [Fact]
+    public async Task ItAddCSharpKernelTestAsync()
+    {
+        var csharpCode = """
+            #r "nuget:Microsoft.ML, 1.5.2"
+            var str = "Hello" + ", World!";
+            Console.WriteLine(str);
+            """;
+
+        var result = await this.kernel.RunSubmitCodeCommandAsync(csharpCode, "csharp");
+        result.Should().Contain("Hello, World!");
+    }
+
+    [Fact]
+    public async Task ItAddPowershellKernelTestAsync()
+    {
+        var powershellCode = @"
+            Write-Host 'Hello, World!'
+            ";
+
+        var result = await this.kernel.RunSubmitCodeCommandAsync(powershellCode, "pwsh");
+        result.Should().Contain("Hello, World!");
+    }
+
+    [Fact]
+    public async Task ItAddFSharpKernelTestAsync()
+    {
+        var fsharpCode = """
+            printfn "Hello, World!"
+            """;
+
+        var result = await this.kernel.RunSubmitCodeCommandAsync(fsharpCode, "fsharp");
+        result.Should().Contain("Hello, World!");
+    }
+
+    [Fact]
+    public async Task ItAddPythonKernelTestAsync()
+    {
+        var pythonCode = """
+            %pip install numpy
+            str = 'Hello' + ', World!'
+            print(str)
+            """;
+
+        var result = await this.kernel.RunSubmitCodeCommandAsync(pythonCode, "python");
+        result.Should().Contain("Hello, World!");
+    }
+}
diff --git a/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveKernelBuilderTest.cs b/dotnet/test/AutoGen.DotnetInteractive.Tests/InProcessDotnetInteractiveKernelBuilderTest.cs
similarity index 84%
rename from dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveKernelBuilderTest.cs
rename to dotnet/test/AutoGen.DotnetInteractive.Tests/InProcessDotnetInteractiveKernelBuilderTest.cs
index 9565f120342c..517ee499efc1 100644
--- a/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveKernelBuilderTest.cs
+++ b/dotnet/test/AutoGen.DotnetInteractive.Tests/InProcessDotnetInteractiveKernelBuilderTest.cs
@@ -1,5 +1,5 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
-// DotnetInteractiveKernelBuilderTest.cs
+// InProcessDotnetInteractiveKernelBuilderTest.cs
 
 using AutoGen.DotnetInteractive.Extension;
 using FluentAssertions;
@@ -7,13 +7,13 @@
 
 namespace AutoGen.DotnetInteractive.Tests;
 
-public class DotnetInteractiveKernelBuilderTest
+public class InProcessDotnetInteractiveKernelBuilderTest
 {
     [Fact]
     public async Task ItAddCSharpKernelTestAsync()
     {
         var kernel = DotnetInteractiveKernelBuilder
-            .CreateEmptyBuilder()
+            .CreateEmptyInProcessKernelBuilder()
             .AddCSharpKernel()
             .Build();
 
@@ -22,7 +22,7 @@ public async Task ItAddCSharpKernelTestAsync()
             Console.WriteLine("Hello, World!");
             """;
 
-        var result = await kernel.RunSubmitCodeCommandAsync(csharpCode, "C#");
+        var result = await kernel.RunSubmitCodeCommandAsync(csharpCode, "csharp");
         result.Should().Contain("Hello, World!");
     }
 
@@ -30,7 +30,7 @@ public async Task ItAddCSharpKernelTestAsync()
     public async Task ItAddPowershellKernelTestAsync()
     {
         var kernel = DotnetInteractiveKernelBuilder
-            .CreateEmptyBuilder()
+            .CreateEmptyInProcessKernelBuilder()
             .AddPowershellKernel()
             .Build();
 
@@ -46,7 +46,7 @@ public async Task ItAddPowershellKernelTestAsync()
     public async Task ItAddFSharpKernelTestAsync()
     {
         var kernel = DotnetInteractiveKernelBuilder
-            .CreateEmptyBuilder()
+            .CreateEmptyInProcessKernelBuilder()
             .AddFSharpKernel()
             .Build();
 
@@ -55,7 +55,7 @@ public async Task ItAddFSharpKernelTestAsync()
             printfn "Hello, World!"
             """;
 
-        var result = await kernel.RunSubmitCodeCommandAsync(fsharpCode, "F#");
+        var result = await kernel.RunSubmitCodeCommandAsync(fsharpCode, "fsharp");
         result.Should().Contain("Hello, World!");
     }
 
@@ -63,7 +63,7 @@ public async Task ItAddFSharpKernelTestAsync()
     public async Task ItAddPythonKernelTestAsync()
     {
         var kernel = DotnetInteractiveKernelBuilder
-            .CreateEmptyBuilder()
+            .CreateEmptyInProcessKernelBuilder()
             .AddPythonKernel("python3")
             .Build();
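
The rename above also changes the builder entry point and the kernel identifiers. A minimal, illustrative sketch (not part of the patch) of the renamed in-process API, grounded in the diff itself:

```csharp
// Illustrative sketch, not part of the patch: the renamed in-process builder.
using System;
using AutoGen.DotnetInteractive;
using AutoGen.DotnetInteractive.Extension;

var kernel = DotnetInteractiveKernelBuilder
    .CreateEmptyInProcessKernelBuilder()   // was CreateEmptyBuilder()
    .AddCSharpKernel()
    .AddPowershellKernel()
    .Build();

// Kernel identifiers are now lower-case names ("csharp", "fsharp", "pwsh")
// instead of the old "C#"/"F#" spellings.
var result = await kernel.RunSubmitCodeCommandAsync("Console.WriteLine(\"hi\");", "csharp");
Console.WriteLine(result);
```
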
diff --git a/dotnet/test/AutoGen.Test.Share/Attribute/EnvironmentSpecificFactAttribute.cs b/dotnet/test/AutoGen.Test.Share/Attribute/EnvironmentSpecificFactAttribute.cs
new file mode 100644
index 000000000000..1361531cc9ed
--- /dev/null
+++ b/dotnet/test/AutoGen.Test.Share/Attribute/EnvironmentSpecificFactAttribute.cs
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// EnvironmentSpecificFactAttribute.cs
+
+using Xunit;
+
+namespace AutoGen.Tests;
+
+/// <summary>
+/// A base class for environment-specific fact attributes.
+/// </summary>
+[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)]
+public abstract class EnvironmentSpecificFactAttribute : FactAttribute
+{
+    private readonly string _skipMessage;
+
+    /// <summary>
+    /// Creates a new instance of the <see cref="EnvironmentSpecificFactAttribute" /> class.
+    /// </summary>
+    /// <param name="skipMessage">The message to be used when skipping the test marked with this attribute.</param>
+    protected EnvironmentSpecificFactAttribute(string skipMessage)
+    {
+        _skipMessage = skipMessage ?? throw new ArgumentNullException(nameof(skipMessage));
+    }
+
+    public sealed override string Skip => IsEnvironmentSupported() ? string.Empty : _skipMessage;
+
+    /// <summary>
+    /// A method used to evaluate whether to skip a test marked with this attribute. Skips iff this method evaluates to false.
+    /// </summary>
+    protected abstract bool IsEnvironmentSupported();
+}
diff --git a/dotnet/test/AutoGen.Test.Share/Attribute/OpenAIFact.cs b/dotnet/test/AutoGen.Test.Share/Attribute/OpenAIFact.cs
new file mode 100644
index 000000000000..54d72cd61ab7
--- /dev/null
+++ b/dotnet/test/AutoGen.Test.Share/Attribute/OpenAIFact.cs
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIFact.cs
+
+namespace AutoGen.Tests;
+
+/// <summary>
+/// A fact for tests requiring OPENAI_API_KEY env.
+/// </summary>
+public sealed class ApiKeyFactAttribute : EnvironmentSpecificFactAttribute
+{
+    private readonly string[] _envVariableNames;
+    public ApiKeyFactAttribute(params string[] envVariableNames) : base($"{envVariableNames} is not found in env")
+    {
+        _envVariableNames = envVariableNames;
+    }
+
+    /// <inheritdoc />
+    protected override bool IsEnvironmentSupported()
+    {
+        return _envVariableNames.All(Environment.GetEnvironmentVariables().Contains);
+    }
+}
diff --git a/dotnet/test/AutoGen.Test.Share/AutoGen.Tests.Share.csproj b/dotnet/test/AutoGen.Test.Share/AutoGen.Tests.Share.csproj
new file mode 100644
index 000000000000..21c71896ddc7
--- /dev/null
+++ b/dotnet/test/AutoGen.Test.Share/AutoGen.Tests.Share.csproj
@@ -0,0 +1,15 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
+    <ImplicitUsings>enable</ImplicitUsings>
+    <IsPackable>false</IsPackable>
+    <IsTestProject>True</IsTestProject>
+    <Nullable>enable</Nullable>
+  </PropertyGroup>
+
+  <ItemGroup>
+
+  </ItemGroup>
+
+</Project>
diff --git a/dotnet/test/AutoGen.Test.Share/EchoAgent.cs b/dotnet/test/AutoGen.Test.Share/EchoAgent.cs
new file mode 100644
index 000000000000..010b72d2add0
--- /dev/null
+++ b/dotnet/test/AutoGen.Test.Share/EchoAgent.cs
@@ -0,0 +1,37 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// EchoAgent.cs
+
+using System.Runtime.CompilerServices;
+using AutoGen.Core;
+
+namespace AutoGen.Tests;
+
+public class EchoAgent : IStreamingAgent
+{
+    public EchoAgent(string name)
+    {
+        Name = name;
+    }
+    public string Name { get; }
+
+    public Task<IMessage> GenerateReplyAsync(
+        IEnumerable<IMessage> conversation,
+        GenerateReplyOptions? options = null,
+        CancellationToken ct = default)
+    {
+        // return the most recent message
+        var lastMessage = conversation.Last();
+        lastMessage.From = this.Name;
+
+        return Task.FromResult(lastMessage);
+    }
+
+    public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        foreach (var message in messages)
+        {
+            message.From = this.Name;
+            yield return message;
+        }
+    }
+}
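
The shared `EnvironmentSpecificFactAttribute` base class is designed to be subclassed per environment condition. A hypothetical, illustrative sketch (not part of the patch; the `WindowsOnlyFactAttribute` name is invented for this example):

```csharp
// Illustrative sketch, not part of the patch: a custom environment-gated fact.
using System;
using AutoGen.Tests;

public sealed class WindowsOnlyFactAttribute : EnvironmentSpecificFactAttribute
{
    public WindowsOnlyFactAttribute()
        : base("This test runs on Windows only.") { }

    // The test is skipped whenever this returns false.
    protected override bool IsEnvironmentSupported()
        => OperatingSystem.IsWindows();
}
```
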
diff --git a/dotnet/test/AutoGen.Tests/Attribute/EnvironmentSpecificFactAttribute.cs b/dotnet/test/AutoGen.Tests/Attribute/EnvironmentSpecificFactAttribute.cs
deleted file mode 100644
index 1042dec6f271..000000000000
--- a/dotnet/test/AutoGen.Tests/Attribute/EnvironmentSpecificFactAttribute.cs
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// EnvironmentSpecificFactAttribute.cs
-
-using System;
-using Xunit;
-
-namespace AutoGen.Tests
-{
-    /// <summary>
-    /// A base class for environment-specific fact attributes.
-    /// </summary>
-    [AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)]
-    public abstract class EnvironmentSpecificFactAttribute : FactAttribute
-    {
-        private readonly string _skipMessage;
-
-        /// <summary>
-        /// Creates a new instance of the <see cref="EnvironmentSpecificFactAttribute" /> class.
-        /// </summary>
-        /// <param name="skipMessage">The message to be used when skipping the test marked with this attribute.</param>
-        protected EnvironmentSpecificFactAttribute(string skipMessage)
-        {
-            _skipMessage = skipMessage ?? throw new ArgumentNullException(nameof(skipMessage));
-        }
-
-        public sealed override string Skip => IsEnvironmentSupported() ? string.Empty : _skipMessage;
-
-        /// <summary>
-        /// A method used to evaluate whether to skip a test marked with this attribute. Skips iff this method evaluates to false.
-        /// </summary>
-        protected abstract bool IsEnvironmentSupported();
-    }
-}
diff --git a/dotnet/test/AutoGen.Tests/Attribute/OpenAIFact.cs b/dotnet/test/AutoGen.Tests/Attribute/OpenAIFact.cs
deleted file mode 100644
index 44457d8f571c..000000000000
--- a/dotnet/test/AutoGen.Tests/Attribute/OpenAIFact.cs
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// OpenAIFact.cs
-
-using System;
-using System.Linq;
-
-namespace AutoGen.Tests
-{
-    /// <summary>
-    /// A fact for tests requiring OPENAI_API_KEY env.
-    /// </summary>
-    public sealed class ApiKeyFactAttribute : EnvironmentSpecificFactAttribute
-    {
-        private readonly string[] _envVariableNames;
-        public ApiKeyFactAttribute(params string[] envVariableNames) : base($"{envVariableNames} is not found in env")
-        {
-            _envVariableNames = envVariableNames;
-        }
-
-        /// <inheritdoc />
-        protected override bool IsEnvironmentSupported()
-        {
-            return _envVariableNames.All(Environment.GetEnvironmentVariables().Contains);
-        }
-    }
-}
diff --git a/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj b/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj
index ce968b91f556..a0c3b815f22b 100644
--- a/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj
+++ b/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj
@@ -12,6 +12,7 @@
+    <ProjectReference Include="..\AutoGen.Test.Share\AutoGen.Tests.Share.csproj" />
diff --git a/dotnet/test/AutoGen.Tests/EchoAgent.cs b/dotnet/test/AutoGen.Tests/EchoAgent.cs
deleted file mode 100644
index af5490218e8d..000000000000
--- a/dotnet/test/AutoGen.Tests/EchoAgent.cs
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// EchoAgent.cs
-
-using System.Collections.Generic;
-using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Threading;
-using System.Threading.Tasks;
-
-namespace AutoGen.Tests
-{
-    public class EchoAgent : IStreamingAgent
-    {
-        public EchoAgent(string name)
-        {
-            Name = name;
-        }
-        public string Name { get; }
-
-        public Task<IMessage> GenerateReplyAsync(
-            IEnumerable<IMessage> conversation,
-            GenerateReplyOptions? options = null,
-            CancellationToken ct = default)
-        {
-            // return the most recent message
-            var lastMessage = conversation.Last();
-            lastMessage.From = this.Name;
-
-            return Task.FromResult(lastMessage);
-        }
-
-        public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var message in messages)
-            {
-                message.From = this.Name;
-                yield return message;
-            }
-        }
-    }
-}
diff --git a/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs b/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs
index 5a2cebb66cff..f9ab09716b94 100644
--- a/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs
+++ b/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs
@@ -10,11 +10,14 @@
 using AutoGen.Anthropic;
 using AutoGen.Anthropic.Extensions;
 using AutoGen.Anthropic.Utils;
+using AutoGen.AzureAIInference;
+using AutoGen.AzureAIInference.Extension;
 using AutoGen.Gemini;
 using AutoGen.Mistral;
 using AutoGen.Mistral.Extension;
 using AutoGen.OpenAI;
 using AutoGen.OpenAI.Extension;
+using Azure.AI.Inference;
 using Azure.AI.OpenAI;
 using FluentAssertions;
 using Moq;
@@ -304,6 +307,22 @@ public async Task Mistra_7b_CoderReviewerRunnerTestAsync()
         await CoderReviewerRunnerTestAsync(agent);
     }
 
+    [ApiKeyFact("GH_API_KEY")]
+    public async Task LLaMA_3_1_CoderReviewerRunnerTestAsync()
+    {
+        var apiKey = Environment.GetEnvironmentVariable("GH_API_KEY") ?? throw new InvalidOperationException("GH_API_KEY is not set.");
+        var endPoint = "https://models.inference.ai.azure.com";
+
+        var chatCompletionClient = new ChatCompletionsClient(new Uri(endPoint), new Azure.AzureKeyCredential(apiKey));
+        var agent = new ChatCompletionsClientAgent(
+            chatCompletionsClient: chatCompletionClient,
+            name: "assistant",
+            modelName: "Meta-Llama-3.1-70B-Instruct")
+            .RegisterMessageConnector();
+
+        await CoderReviewerRunnerTestAsync(agent);
+    }
+
     /// <summary>
     /// This test is to mimic the conversation among coder, reviewer and runner.
     /// The coder will write the code, the reviewer will review the code, and the runner will run the code.
diff --git a/dotnet/website/articles/Agent-overview.md b/dotnet/website/articles/Agent-overview.md
index 0b84cdc49ac7..586d231a6e7d 100644
--- a/dotnet/website/articles/Agent-overview.md
+++ b/dotnet/website/articles/Agent-overview.md
@@ -8,7 +8,6 @@
 - Create an @AutoGen.OpenAI.OpenAIChatAgent: [Create an OpenAI chat agent](./OpenAIChatAgent-simple-chat.md)
 - Create a @AutoGen.SemanticKernel.SemanticKernelAgent: [Create a semantic kernel agent](./AutoGen.SemanticKernel/SemanticKernelAgent-simple-chat.md)
 - Create a @AutoGen.LMStudio.LMStudioAgent: [Connect to LM Studio](./Consume-LLM-server-from-LM-Studio.md)
-- Create your own agent: [Create your own agent](./Create-your-own-agent.md)
 
 ## Chat with an agent
 To chat with an agent, typically you can invoke @AutoGen.Core.IAgent.GenerateReplyAsync*. On top of that, you can also use one of the extension methods like @AutoGen.Core.AgentExtension.SendAsync* as shortcuts.
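
The Agent-overview.md paragraph retained above describes two ways of chatting with an agent: the low-level `GenerateReplyAsync` and the `SendAsync` shortcut. A minimal, illustrative sketch (not part of the patch), using the shared `EchoAgent` test double so it runs without any API key:

```csharp
// Illustrative sketch, not part of the patch: GenerateReplyAsync vs. SendAsync.
using AutoGen.Core;
using AutoGen.Tests;

IAgent agent = new EchoAgent("assistant");

// Low-level: pass the conversation history explicitly.
var reply = await agent.GenerateReplyAsync([new TextMessage(Role.User, "Hello", from: "user")]);

// Shortcut: SendAsync wraps the text into a message for you.
var shortcutReply = await agent.SendAsync("Hello");
```
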
+- `AutoGen.AzureAIInference`: This package provides the integration agents for [Azure AI Inference](https://www.nuget.org/packages/Azure.AI.Inference). - `AutoGen.SourceGenerator`: This package carries a source generator that adds support for type-safe function definition generation. -- `AutoGen.DotnetInteractive`: This packages carries dotnet interactive support to execute dotnet code snippet. +- `AutoGen.DotnetInteractive`: This packages carries dotnet interactive support to execute code snippets. The current supported language is C#, F#, powershell and python. >[!Note] > Help me choose diff --git a/test/agentchat/contrib/capabilities/test_transforms.py b/test/agentchat/contrib/capabilities/test_transforms.py index 46c61d9adc6f..34094a0008b7 100644 --- a/test/agentchat/contrib/capabilities/test_transforms.py +++ b/test/agentchat/contrib/capabilities/test_transforms.py @@ -9,8 +9,8 @@ MessageHistoryLimiter, MessageTokenLimiter, TextMessageCompressor, - _count_tokens, ) +from autogen.agentchat.contrib.capabilities.transforms_util import count_text_tokens class _MockTextCompressor: @@ -40,6 +40,26 @@ def get_no_content_messages() -> List[Dict]: return [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}] +def get_tool_messages() -> List[Dict]: + return [ + {"role": "user", "content": "hello"}, + {"role": "tool_calls", "content": "calling_tool"}, + {"role": "tool", "content": "tool_response"}, + {"role": "user", "content": "how are you"}, + {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]}, + ] + + +def get_tool_messages_kept() -> List[Dict]: + return [ + {"role": "user", "content": "hello"}, + {"role": "tool_calls", "content": "calling_tool"}, + {"role": "tool", "content": "tool_response"}, + {"role": "tool_calls", "content": "calling_tool"}, + {"role": "tool", "content": "tool_response"}, + ] + + def get_text_compressors() -> List[TextCompressor]: compressors: List[TextCompressor] = [_MockTextCompressor()] try: @@ -57,6 +77,11 @@ def message_history_limiter() -> MessageHistoryLimiter: return MessageHistoryLimiter(max_messages=3) +@pytest.fixture +def message_history_limiter_keep_first() -> MessageHistoryLimiter: + return MessageHistoryLimiter(max_messages=3, keep_first_message=True) + + @pytest.fixture def message_token_limiter() -> MessageTokenLimiter: return MessageTokenLimiter(max_tokens_per_message=3) @@ -96,12 +121,43 @@ def _filter_dict_test( @pytest.mark.parametrize( "messages, expected_messages_len", - [(get_long_messages(), 3), (get_short_messages(), 3), (get_no_content_messages(), 2)], + [ + (get_long_messages(), 3), + (get_short_messages(), 3), + (get_no_content_messages(), 2), + (get_tool_messages(), 2), + (get_tool_messages_kept(), 2), + ], ) def test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_messages_len): transformed_messages = message_history_limiter.apply_transform(messages) assert len(transformed_messages) == expected_messages_len + if messages == get_tool_messages_kept(): + assert transformed_messages[0]["role"] == "tool_calls" + assert transformed_messages[1]["role"] == "tool" + + +@pytest.mark.parametrize( + "messages, expected_messages_len", + [ + (get_long_messages(), 3), + (get_short_messages(), 3), + (get_no_content_messages(), 2), + (get_tool_messages(), 3), + (get_tool_messages_kept(), 3), + ], +) +def test_message_history_limiter_apply_transform_keep_first( + message_history_limiter_keep_first, messages, expected_messages_len +): + transformed_messages = 
message_history_limiter_keep_first.apply_transform(messages) + assert len(transformed_messages) == expected_messages_len + + if messages == get_tool_messages_kept(): + assert transformed_messages[1]["role"] == "tool_calls" + assert transformed_messages[2]["role"] == "tool" + @pytest.mark.parametrize( "messages, expected_logs, expected_effect", @@ -109,6 +165,8 @@ def test_message_history_limiter_apply_transform(message_history_limiter, messag (get_long_messages(), "Removed 2 messages. Number of messages reduced from 5 to 3.", True), (get_short_messages(), "No messages were removed.", False), (get_no_content_messages(), "No messages were removed.", False), + (get_tool_messages(), "Removed 3 messages. Number of messages reduced from 5 to 2.", True), + (get_tool_messages_kept(), "Removed 3 messages. Number of messages reduced from 5 to 2.", True), ], ) def test_message_history_limiter_get_logs(message_history_limiter, messages, expected_logs, expected_effect): @@ -131,7 +189,8 @@ def test_message_token_limiter_apply_transform( ): transformed_messages = message_token_limiter.apply_transform(copy.deepcopy(messages)) assert ( - sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count + sum(count_text_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) + == expected_token_count ) assert len(transformed_messages) == expected_messages_len @@ -167,7 +226,8 @@ def test_message_token_limiter_with_threshold_apply_transform( ): transformed_messages = message_token_limiter_with_threshold.apply_transform(messages) assert ( - sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count + sum(count_text_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) + == expected_token_count ) assert len(transformed_messages) == expected_messages_len @@ -240,56 +300,31 @@ def test_text_compression_with_filter(messages, text_compressor): assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=False) -@pytest.mark.parametrize("text_compressor", get_text_compressors()) -def test_text_compression_cache(text_compressor): - messages = get_long_messages() - mock_compressed_content = (1, {"content": "mock"}) - - with patch( - "autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_get", - MagicMock(return_value=(1, {"content": "mock"})), - ) as mocked_get, patch( - "autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_set", MagicMock() - ) as mocked_set: - compressor = TextMessageCompressor(text_compressor=text_compressor) - - compressor.apply_transform(messages) - compressor.apply_transform(messages) - - assert mocked_get.call_count == len(messages) - assert mocked_set.call_count == len(messages) - - # We already populated the cache with the mock content - # We need to test if we retrieve the correct content - compressor = TextMessageCompressor(text_compressor=text_compressor) - compressed_messages = compressor.apply_transform(messages) - - for message in compressed_messages: - assert message["content"] == mock_compressed_content[1] - - if __name__ == "__main__": long_messages = get_long_messages() short_messages = get_short_messages() no_content_messages = get_no_content_messages() + tool_messages = get_tool_messages() msg_history_limiter = MessageHistoryLimiter(max_messages=3) + msg_history_limiter_keep_first = MessageHistoryLimiter(max_messages=3, keep_first=True) msg_token_limiter = 
MessageTokenLimiter(max_tokens_per_message=3)
     msg_token_limiter_with_threshold = MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10)

     # Test Parameters
     message_history_limiter_apply_transform_parameters = {
-        "messages": [long_messages, short_messages, no_content_messages],
-        "expected_messages_len": [3, 3, 2],
+        "messages": [long_messages, short_messages, no_content_messages, tool_messages],
+        "expected_messages_len": [3, 3, 2, 2],
     }

+    # keep_first_message pins the first message, so the expected length for the
+    # tool example differs from the default limiter (the orphaned "tool"
+    # response is still dropped in both cases).
+    message_history_limiter_apply_transform_keep_first_parameters = {
+        "messages": [long_messages, short_messages, no_content_messages, tool_messages],
+        "expected_messages_len": [3, 3, 2, 3],
+    }
+
     message_history_limiter_get_logs_parameters = {
-        "messages": [long_messages, short_messages, no_content_messages],
+        "messages": [long_messages, short_messages, no_content_messages, tool_messages],
         "expected_logs": [
             "Removed 2 messages. Number of messages reduced from 5 to 3.",
             "No messages were removed.",
             "No messages were removed.",
+            "Removed 3 messages. Number of messages reduced from 5 to 2.",
         ],
-        "expected_effect": [True, False, False],
+        "expected_effect": [True, False, False, True],
     }

     message_token_limiter_apply_transform_parameters = {
@@ -322,6 +357,14 @@ def test_text_compression_cache(text_compressor):
     ):
         test_message_history_limiter_apply_transform(msg_history_limiter, messages, expected_messages_len)

+    for messages, expected_messages_len in zip(
+        message_history_limiter_apply_transform_keep_first_parameters["messages"],
+        message_history_limiter_apply_transform_keep_first_parameters["expected_messages_len"],
+    ):
+        test_message_history_limiter_apply_transform_keep_first(
+            msg_history_limiter_keep_first, messages, expected_messages_len
+        )
+
     for messages, expected_logs, expected_effect in zip(
         message_history_limiter_get_logs_parameters["messages"],
         message_history_limiter_get_logs_parameters["expected_logs"],
diff --git a/test/agentchat/test_nested.py b/test/agentchat/test_nested.py
index ee8da793fdec..04fc84b5b399 100755
--- a/test/agentchat/test_nested.py
+++ b/test/agentchat/test_nested.py
@@ -2,10 +2,12 @@
 import os
 import sys
+from typing import List

 import pytest

 import autogen
+from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability

 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
@@ -13,6 +15,23 @@
 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST  # noqa: E402


+class MockAgentReplies(AgentCapability):
+    def __init__(self, mock_messages: List[str]):
+        self.mock_messages = mock_messages
+        self.mock_message_index = 0
+
+    def add_to_agent(self, agent: autogen.ConversableAgent):
+        def mock_reply(recipient, messages, sender, config):
+            if self.mock_message_index < len(self.mock_messages):
+                reply_msg = self.mock_messages[self.mock_message_index]
+                self.mock_message_index += 1
+                return [True, reply_msg]
+            else:
+                raise ValueError(f"No more mock messages available for {sender.name} to reply to {recipient.name}")
+
+        agent.register_reply([autogen.Agent, None], mock_reply, position=2)
+
+
 @pytest.mark.skipif(skip_openai, reason=reason)
 def test_nested():
     config_list = autogen.config_list_from_json(env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC)
@@ -142,5 +161,216 @@ def writing_message(recipient, messages, sender, config):
     )


+def test_sync_nested_chat():
+    def is_termination(msg):
+        if isinstance(msg, str) and msg == "FINAL_RESULT":
+            return True
+        elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
+            return True
+        return False
+
+    inner_assistant = autogen.AssistantAgent(
+        "Inner-assistant",
+        is_termination_msg=is_termination,
+    )
+    MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 
2"]).add_to_agent(inner_assistant) + + inner_assistant_2 = autogen.AssistantAgent( + "Inner-assistant-2", + ) + MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent( + inner_assistant_2 + ) + + assistant = autogen.AssistantAgent( + "Assistant", + ) + user = autogen.UserProxyAgent( + "User", + human_input_mode="NEVER", + is_termination_msg=is_termination, + ) + assistant.register_nested_chats( + [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}], trigger=user + ) + chat_result = user.initiate_chat(assistant, message="Start chat") + assert len(chat_result.chat_history) == 2 + chat_messages = [msg["content"] for msg in chat_result.chat_history] + assert chat_messages == ["Start chat", "FINAL_RESULT"] + + +@pytest.mark.asyncio +async def test_async_nested_chat(): + def is_termination(msg): + if isinstance(msg, str) and msg == "FINAL_RESULT": + return True + elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT": + return True + return False + + inner_assistant = autogen.AssistantAgent( + "Inner-assistant", + is_termination_msg=is_termination, + ) + MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant) + + inner_assistant_2 = autogen.AssistantAgent( + "Inner-assistant-2", + ) + MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent( + inner_assistant_2 + ) + + assistant = autogen.AssistantAgent( + "Assistant", + ) + user = autogen.UserProxyAgent( + "User", + human_input_mode="NEVER", + is_termination_msg=is_termination, + ) + assistant.register_nested_chats( + [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg", "chat_id": 1}], + trigger=user, + use_async=True, + ) + chat_result = await user.a_initiate_chat(assistant, message="Start chat") + assert len(chat_result.chat_history) == 2 + chat_messages = [msg["content"] for msg in chat_result.chat_history] + assert chat_messages == ["Start chat", "FINAL_RESULT"] + + +@pytest.mark.asyncio +async def test_async_nested_chat_chat_id_validation(): + def is_termination(msg): + if isinstance(msg, str) and msg == "FINAL_RESULT": + return True + elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT": + return True + return False + + inner_assistant = autogen.AssistantAgent( + "Inner-assistant", + is_termination_msg=is_termination, + ) + MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant) + + inner_assistant_2 = autogen.AssistantAgent( + "Inner-assistant-2", + ) + MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent( + inner_assistant_2 + ) + + assistant = autogen.AssistantAgent( + "Assistant", + ) + user = autogen.UserProxyAgent( + "User", + human_input_mode="NEVER", + is_termination_msg=is_termination, + ) + with pytest.raises(ValueError, match="chat_id is required for async nested chats"): + assistant.register_nested_chats( + [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}], + trigger=user, + use_async=True, + ) + + +def test_sync_nested_chat_in_group(): + def is_termination(msg): + if isinstance(msg, str) and msg == "FINAL_RESULT": + return True + elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT": + return True + return False + + inner_assistant = autogen.AssistantAgent( + "Inner-assistant", + 
is_termination_msg=is_termination, + ) + MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant) + + inner_assistant_2 = autogen.AssistantAgent( + "Inner-assistant-2", + ) + MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent( + inner_assistant_2 + ) + + assistant = autogen.AssistantAgent( + "Assistant_In_Group_1", + ) + MockAgentReplies(["Assistant_In_Group_1 message 1"]).add_to_agent(assistant) + assistant2 = autogen.AssistantAgent( + "Assistant_In_Group_2", + ) + user = autogen.UserProxyAgent("User", human_input_mode="NEVER", is_termination_msg=is_termination) + group = autogen.GroupChat( + agents=[assistant, assistant2, user], + messages=[], + speaker_selection_method="round_robin", + ) + group_manager = autogen.GroupChatManager(groupchat=group) + assistant2.register_nested_chats( + [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}], + trigger=group_manager, + ) + + chat_result = user.initiate_chat(group_manager, message="Start chat", summary_method="last_msg") + assert len(chat_result.chat_history) == 3 + chat_messages = [msg["content"] for msg in chat_result.chat_history] + assert chat_messages == ["Start chat", "Assistant_In_Group_1 message 1", "FINAL_RESULT"] + + +@pytest.mark.asyncio +async def test_async_nested_chat_in_group(): + def is_termination(msg): + if isinstance(msg, str) and msg == "FINAL_RESULT": + return True + elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT": + return True + return False + + inner_assistant = autogen.AssistantAgent( + "Inner-assistant", + is_termination_msg=is_termination, + ) + MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant) + + inner_assistant_2 = autogen.AssistantAgent( + "Inner-assistant-2", + ) + MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent( + inner_assistant_2 + ) + + assistant = autogen.AssistantAgent( + "Assistant_In_Group_1", + ) + MockAgentReplies(["Assistant_In_Group_1 message 1"]).add_to_agent(assistant) + assistant2 = autogen.AssistantAgent( + "Assistant_In_Group_2", + ) + user = autogen.UserProxyAgent("User", human_input_mode="NEVER", is_termination_msg=is_termination) + group = autogen.GroupChat( + agents=[assistant, assistant2, user], + messages=[], + speaker_selection_method="round_robin", + ) + group_manager = autogen.GroupChatManager(groupchat=group) + assistant2.register_nested_chats( + [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg", "chat_id": 1}], + trigger=group_manager, + use_async=True, + ) + + chat_result = await user.a_initiate_chat(group_manager, message="Start chat", summary_method="last_msg") + assert len(chat_result.chat_history) == 3 + chat_messages = [msg["content"] for msg in chat_result.chat_history] + assert chat_messages == ["Start chat", "Assistant_In_Group_1 message 1", "FINAL_RESULT"] + + if __name__ == "__main__": test_nested() diff --git a/website/docs/Use-Cases/agent_chat.md b/website/docs/Use-Cases/agent_chat.md index 59156c0eb046..46b555b3d7ca 100644 --- a/website/docs/Use-Cases/agent_chat.md +++ b/website/docs/Use-Cases/agent_chat.md @@ -83,7 +83,7 @@ With the pluggable auto-reply function, one can choose to invoke conversations w - Hierarchical chat like in [OptiGuide](https://github.com/microsoft/optiguide). 
- [Dynamic Group Chat](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat.ipynb) which is a special form of hierarchical chat. In the system, we register a reply function in the group chat manager, which broadcasts messages and decides who the next speaker will be in a group chat setting.
- [Finite State Machine graphs to set speaker transition constraints](https://microsoft.github.io/autogen/docs/notebooks/agentchat_groupchat_finite_state_machine) which is a special form of dynamic group chat. In this approach, a directed transition matrix is fed into group chat. Users can specify legal transitions or specify disallowed transitions.
-- Nested chat like in [conversational chess](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_chess.ipynb).
+- Nested chat like in [conversational chess](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_nested_chats_chess.ipynb).

2. LLM-Based Function Call

diff --git a/website/docs/topics/handling_long_contexts/intro_to_transform_messages.md b/website/docs/topics/handling_long_contexts/intro_to_transform_messages.md
index d0a53702c48b..52fea15d01e5 100644
--- a/website/docs/topics/handling_long_contexts/intro_to_transform_messages.md
+++ b/website/docs/topics/handling_long_contexts/intro_to_transform_messages.md
@@ -59,7 +59,28 @@ pprint.pprint(processed_messages)
 {'content': 'very very very very very very long string', 'role': 'user'}]
 ```

-By applying the `MessageHistoryLimiter`, we can see that we were able to limit the context history to the 3 most recent messages.
+By applying the `MessageHistoryLimiter`, we can see that we were able to limit the context history to the 3 most recent messages. However, if the cut-off point would fall between a "tool_calls" message and its "tool" response, the orphaned "tool" response is dropped as well, since the OpenAI API requires every "tool" message to follow its corresponding tool call. The truncated history may therefore contain fewer than `max_messages` messages.
+
+```python
+max_msg_transform = transforms.MessageHistoryLimiter(max_messages=3)
+
+messages = [
+    {"role": "user", "content": "hello"},
+    {"role": "tool_calls", "content": "calling_tool"},
+    {"role": "tool", "content": "tool_response"},
+    {"role": "user", "content": "how are you"},
+    {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
+]
+
+processed_messages = max_msg_transform.apply_transform(copy.deepcopy(messages))
+pprint.pprint(processed_messages)
+```
+```console
+[{'content': 'how are you', 'role': 'user'},
+{'content': [{'text': 'are you doing?', 'type': 'text'}], 'role': 'assistant'}]
+```
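+
+This PR also adds a `keep_first_message` option that pins the original first message while the rest of the history is truncated. Below is a minimal sketch reusing the `messages` list from the example above; the expected output is inferred from this PR's `keep_first_message` tests rather than presented as verified library output:
+
+```python
+# keep_first_message=True retains messages[0] and fills the remaining slots
+# with the most recent messages, still dropping an orphaned "tool" response.
+max_msg_transform = transforms.MessageHistoryLimiter(max_messages=3, keep_first_message=True)
+
+processed_messages = max_msg_transform.apply_transform(copy.deepcopy(messages))
+pprint.pprint(processed_messages)
+```
+```console
+[{'content': 'hello', 'role': 'user'},
+{'content': 'how are you', 'role': 'user'},
+{'content': [{'text': 'are you doing?', 'type': 'text'}], 'role': 'assistant'}]
+```

 #### Example 2: Limiting the Number of Tokens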