From 11a43421e349113f699b9b706e6515d7023bf879 Mon Sep 17 00:00:00 2001 From: giorgossideris <56915448+giorgossideris@users.noreply.github.com> Date: Mon, 29 Apr 2024 05:11:16 +0300 Subject: [PATCH] Min tokens in token limiter (#2400) * Add minimum token threshold in MessageHistoryLimiter * Update transforms tests for the threshold * Move min_threshold_tokens from Message to Token Limiter * Optimize _check_tokens_threshold Co-authored-by: Wael Karkoub * Apply requested changes (renaming, phrasing, validations) * Fix format * Fix _check_tokens_threshold logic * Update docs and notebook * Improve phrasing * Add min_tokens example in notebook * Add min_tokens example in website docs * Add min_tokens example in notebook * Update website docs to be in sync with get_logs change --------- Co-authored-by: Wael Karkoub Co-authored-by: Chi Wang --- .../contrib/capabilities/transforms.py | 45 +++- notebook/agentchat_transform_messages.ipynb | 226 ++++++++++++------ .../contrib/capabilities/test_transforms.py | 115 +++++++-- website/docs/topics/long_contexts.md | 117 +++++---- 4 files changed, 351 insertions(+), 152 deletions(-) diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py index 6dc1d59fe9c7..279faed8c9d6 100644 --- a/autogen/agentchat/contrib/capabilities/transforms.py +++ b/autogen/agentchat/contrib/capabilities/transforms.py @@ -51,8 +51,7 @@ class MessageHistoryLimiter: def __init__(self, max_messages: Optional[int] = None): """ Args: - max_messages (None or int): Maximum number of messages to keep in the context. - Must be greater than 0 if not None. + max_messages Optional[int]: Maximum number of messages to keep in the context. Must be greater than 0 if not None. """ self._validate_max_messages(max_messages) self._max_messages = max_messages @@ -70,6 +69,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: Returns: List[Dict]: A new list containing the most recent messages up to the specified maximum. """ + if self._max_messages is None: return messages @@ -108,13 +108,15 @@ class MessageTokenLimiter: The truncation process follows these steps in order: - 1. Messages are processed in reverse order (newest to oldest). - 2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text + 1. The minimum tokens threshold (`min_tokens`) is checked (0 by default). If the total number of tokens in messages + are less than this threshold, then the messages are returned as is. In other case, the following process is applied. + 2. Messages are processed in reverse order (newest to oldest). + 3. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text and other types of content, only the text content is truncated. - 3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count + 4. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count exceeds this limit, the current message being processed get truncated to meet the total token count and any remaining messages get discarded. - 4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the + 5. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the original message order. """ @@ -122,6 +124,7 @@ def __init__( self, max_tokens_per_message: Optional[int] = None, max_tokens: Optional[int] = None, + min_tokens: Optional[int] = None, model: str = "gpt-3.5-turbo-0613", ): """ @@ -130,11 +133,14 @@ def __init__( Must be greater than or equal to 0 if not None. max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history. Must be greater than or equal to 0 if not None. + min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation. + Must be greater than or equal to 0 if not None. model (str): The target OpenAI model for tokenization alignment. """ self._model = model self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message) self._max_tokens = self._validate_max_tokens(max_tokens) + self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens) def apply_transform(self, messages: List[Dict]) -> List[Dict]: """Applies token truncation to the conversation history. @@ -147,6 +153,11 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: """ assert self._max_tokens_per_message is not None assert self._max_tokens is not None + assert self._min_tokens is not None + + # if the total number of tokens in the messages is less than the min_tokens, return the messages as is + if not self._are_min_tokens_reached(messages): + return messages temp_messages = copy.deepcopy(messages) processed_messages = [] @@ -194,6 +205,19 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: return logs_str, True return "No tokens were truncated.", False + def _are_min_tokens_reached(self, messages: List[Dict]) -> bool: + """ + Returns True if no minimum tokens restrictions are applied. + + Either if the total number of tokens in the messages is greater than or equal to the `min_theshold_tokens`, + or no minimum tokens threshold is set. + """ + if not self._min_tokens: + return True + + messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg) + return messages_tokens >= self._min_tokens + def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]: if isinstance(contents, str): return self._truncate_tokens(contents, n_tokens) @@ -244,6 +268,15 @@ def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int return max_tokens if max_tokens is not None else sys.maxsize + def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int: + if min_tokens is None: + return 0 + if min_tokens < 0: + raise ValueError("min_tokens must be None or greater than or equal to 0.") + if max_tokens is not None and min_tokens > max_tokens: + raise ValueError("min_tokens must not be more than max_tokens.") + return min_tokens + def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int: token_count = 0 diff --git a/notebook/agentchat_transform_messages.ipynb b/notebook/agentchat_transform_messages.ipynb index ab8bc762fc76..d0216e05dd2d 100644 --- a/notebook/agentchat_transform_messages.ipynb +++ b/notebook/agentchat_transform_messages.ipynb @@ -24,16 +24,15 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "47773f79-c0fd-4993-bc6e-3d1a57690118", "metadata": {}, "outputs": [], "source": [ "import copy\n", - "import os\n", "import pprint\n", "import re\n", - "from typing import Dict, List\n", + "from typing import Dict, List, Tuple\n", "\n", "import autogen\n", "from autogen.agentchat.contrib.capabilities import transform_messages, transforms" @@ -41,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "id": "9f09246b-a7d0-4238-b62c-1e72c7d815b3", "metadata": {}, "outputs": [], @@ -95,7 +94,7 @@ "Imagine a scenario where the LLM generates an extensive amount of text, surpassing the token limit imposed by your API provider. To address this issue, you can leverage `TransformMessages` along with its constituent transformations, `MessageHistoryLimiter` and `MessageTokenLimiter`.\n", "\n", "- `MessageHistoryLimiter`: You can restrict the total number of messages considered as context history. This transform is particularly useful when you want to limit the conversational context to a specific number of recent messages, ensuring efficient processing and response generation.\n", - "- `MessageTokenLimiter`: Enables you to cap the total number of tokens, either on a per-message basis or across the entire context history (or both). This transformation is invaluable when you need to adhere to strict token limits imposed by your API provider, preventing unnecessary costs or errors caused by exceeding the allowed token count." + "- `MessageTokenLimiter`: Enables you to cap the total number of tokens, either on a per-message basis or across the entire context history (or both). This transformation is invaluable when you need to adhere to strict token limits imposed by your API provider, preventing unnecessary costs or errors caused by exceeding the allowed token count. Additionally, a `min_tokens` threshold can be applied, ensuring that the transformation is only applied when the number of tokens is not less than the specified threshold." ] }, { @@ -109,7 +108,7 @@ "max_msg_transfrom = transforms.MessageHistoryLimiter(max_messages=3)\n", "\n", "# Limit the token limit per message to 10 tokens\n", - "token_limit_transform = transforms.MessageTokenLimiter(max_tokens_per_message=3)" + "token_limit_transform = transforms.MessageTokenLimiter(max_tokens_per_message=3, min_tokens=10)" ] }, { @@ -170,7 +169,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33mTruncated 6 tokens. Tokens reduced from 15 to 9\u001b[0m\n", "[{'content': 'hello', 'role': 'user'},\n", " {'content': [{'text': 'there', 'type': 'text'}], 'role': 'assistant'},\n", " {'content': 'how', 'role': 'user'},\n", @@ -185,6 +183,40 @@ "pprint.pprint(processed_messages)" ] }, + { + "cell_type": "markdown", + "id": "86a98e08", + "metadata": {}, + "source": [ + "Also, the `min_tokens` threshold is set to 10, indicating that the transformation will not be applied if the total number of tokens in the messages is less than that. This is especially beneficial when the transformation should only occur after a certain number of tokens has been reached, such as in the context window of the model. An example is provided below." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "05c42ffc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'content': 'hello there, how are you?', 'role': 'user'},\n", + " {'content': [{'text': 'hello', 'type': 'text'}], 'role': 'assistant'}]\n" + ] + } + ], + "source": [ + "short_messages = [\n", + " {\"role\": \"user\", \"content\": \"hello there, how are you?\"},\n", + " {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"hello\"}]},\n", + "]\n", + "\n", + "processed_short_messages = token_limit_transform.apply_transform(copy.deepcopy(short_messages))\n", + "\n", + "pprint.pprint(processed_short_messages)" + ] + }, { "cell_type": "markdown", "id": "35fa2844-bd83-42ac-8275-959f093b7bc7", @@ -197,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "80e53623-2830-41b7-8ae2-bf3668071657", "metadata": {}, "outputs": [ @@ -211,7 +243,7 @@ "\n", "--------------------------------------------------------------------------------\n", "Encountered an error with the base assistant\n", - "Error code: 429 - {'error': {'message': 'Request too large for gpt-3.5-turbo in organization org-U58JZBsXUVAJPlx2MtPYmdx1 on tokens per min (TPM): Limit 60000, Requested 1252546. The input or output tokens must be reduced in order to run successfully. Visit https://platform.openai.com/account/rate-limits to learn more.', 'type': 'tokens', 'param': None, 'code': 'rate_limit_exceeded'}}\n", + "Error code: 400 - {'error': {'message': \"This model's maximum context length is 16385 tokens. However, your messages resulted in 1009487 tokens. Please reduce the length of the messages.\", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}}\n", "\n", "\n", "\n", @@ -220,38 +252,42 @@ "plot and save a graph of x^2 from -10 to 10\n", "\n", "--------------------------------------------------------------------------------\n", - "\u001b[33mTruncated 3804 tokens. Tokens reduced from 4019 to 215\u001b[0m\n", + "\u001b[33mRemoved 1991 messages. Number of messages reduced from 2001 to 10.\u001b[0m\n", + "\u001b[33mTruncated 3804 tokens. Number of tokens reduced from 4019 to 215\u001b[0m\n", "\u001b[33massistant\u001b[0m (to user_proxy):\n", "\n", - "To plot the graph of \\( x^2 \\) from -10 to 10 and save it, we can use Python with the matplotlib library. Here is the code to achieve this:\n", - "\n", "```python\n", - "# filename: plot_graph.py\n", + "# filename: plot_x_squared.py\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "x = np.linspace(-10, 10, 100)\n", + "# Generate an array of x values from -10 to 10\n", + "x = np.linspace(-10, 10, 400)\n", + "# Calculate the y values by squaring the x values\n", "y = x**2\n", "\n", + "# Create the plot\n", + "plt.figure()\n", "plt.plot(x, y)\n", + "\n", + "# Title and labels\n", + "plt.title('Graph of y = x^2')\n", "plt.xlabel('x')\n", - "plt.ylabel('x^2')\n", - "plt.title('Graph of x^2')\n", - "plt.grid(True)\n", - "plt.savefig('x_squared_graph.png')\n", + "plt.ylabel('y')\n", + "\n", + "# Save the plot as a file\n", + "plt.savefig('x_squared_plot.png')\n", + "\n", + "# Show the plot\n", "plt.show()\n", "```\n", "\n", - "After executing this code, you should see the graph of \\( x^2 \\) displayed and saved as `x_squared_graph.png`.\n", - "\n", - "Please make sure you have matplotlib installed. If not, you can install it using pip:\n", + "Please save the above code into a file named `plot_x_squared.py`. After saving the code, you can execute it to generate and save the graph of y = x^2 from -10 to 10. The graph will also be displayed to you and the file `x_squared_plot.png` will be created in the current directory. Make sure you have `matplotlib` and `numpy` libraries installed in your Python environment before executing the code. If they are not installed, you can install them using `pip`:\n", "\n", "```sh\n", - "pip install matplotlib\n", + "pip install matplotlib numpy\n", "```\n", "\n", - "Go ahead and execute the Python script provided above to plot and save the graph of \\( x^2 \\). Let me know if you encounter any issues.\n", - "\n", "--------------------------------------------------------------------------------\n", "\u001b[31m\n", ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n", @@ -263,36 +299,83 @@ "Code output: \n", "Figure(640x480)\n", "\n", - "Requirement already satisfied: matplotlib in /home/wael/workspaces/autogen/.venv/lib/python3.11/site-packages (3.8.2)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /home/wael/workspaces/autogen/.venv/lib/python3.11/site-packages (from matplotlib) (1.2.0)\n", - "Requirement already satisfied: cycler>=0.10 in /home/wael/workspaces/autogen/.venv/lib/python3.11/site-packages (from matplotlib) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /home/wael/workspaces/autogen/.venv/lib/python3.11/site-packages (from matplotlib) (4.48.1)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /home/wael/workspaces/autogen/.venv/lib/python3.11/site-packages (from matplotlib) (1.4.5)\n", - "Requirement already satisfied: numpy<2,>=1.21 in /home/wael/workspaces/autogen/.venv/lib/python3.11/site-packages (from matplotlib) (1.26.4)\n", - "Requirement already satisfied: packaging>=20.0 in /home/wael/workspaces/autogen/.venv/lib/python3.11/site-packages (from matplotlib) (23.2)\n", - "Requirement already satisfied: pillow>=8 in /home/wael/workspaces/autogen/.venv/lib/python3.11/site-packages (from matplotlib) (10.2.0)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /home/wael/workspaces/autogen/.venv/lib/python3.11/site-packages (from matplotlib) (3.1.1)\n", - "Requirement already satisfied: python-dateutil>=2.7 in /home/wael/workspaces/autogen/.venv/lib/python3.11/site-packages (from matplotlib) (2.8.2)\n", - "Requirement already satisfied: six>=1.5 in /home/wael/workspaces/autogen/.venv/lib/python3.11/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n", + "Requirement already satisfied: matplotlib in c:\\users\\bt314mc\\appdata\\local\\programs\\python\\python311\\lib\\site-packages (3.8.0)\n", + "Requirement already satisfied: numpy in c:\\users\\bt314mc\\appdata\\local\\programs\\python\\python311\\lib\\site-packages (1.26.0)\n", + "Requirement already satisfied: contourpy>=1.0.1 in c:\\users\\bt314mc\\appdata\\local\\programs\\python\\python311\\lib\\site-packages (from matplotlib) (1.1.1)\n", + "Requirement already satisfied: cycler>=0.10 in c:\\users\\bt314mc\\appdata\\local\\programs\\python\\python311\\lib\\site-packages (from matplotlib) (0.11.0)\n", + "Requirement already satisfied: fonttools>=4.22.0 in c:\\users\\bt314mc\\appdata\\local\\programs\\python\\python311\\lib\\site-packages (from matplotlib) (4.42.1)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in c:\\users\\bt314mc\\appdata\\local\\programs\\python\\python311\\lib\\site-packages (from matplotlib) (1.4.5)\n", + "Requirement already satisfied: packaging>=20.0 in c:\\users\\bt314mc\\appdata\\local\\programs\\python\\python311\\lib\\site-packages (from matplotlib) (23.2)\n", + "Requirement already satisfied: pillow>=6.2.0 in c:\\users\\bt314mc\\appdata\\local\\programs\\python\\python311\\lib\\site-packages (from matplotlib) (10.0.1)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in c:\\users\\bt314mc\\appdata\\local\\programs\\python\\python311\\lib\\site-packages (from matplotlib) (3.1.1)\n", + "Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\bt314mc\\appdata\\local\\programs\\python\\python311\\lib\\site-packages (from matplotlib) (2.8.2)\n", + "Requirement already satisfied: six>=1.5 in c:\\users\\bt314mc\\appdata\\local\\programs\\python\\python311\\lib\\site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n", + "\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mRemoved 1993 messages. Number of messages reduced from 2003 to 10.\u001b[0m\n", + "\u001b[33mTruncated 3523 tokens. Number of tokens reduced from 3788 to 265\u001b[0m\n", + "\u001b[33massistant\u001b[0m (to user_proxy):\n", + "\n", + "It appears that the matplotlib library is already installed on your system, and the previous script started successfully but did not finish because the plotting code was incomplete.\n", + "\n", + "I will provide you with the full code to plot and save the graph of \\( x^2 \\) from -10 to 10.\n", + "\n", + "```python\n", + "# filename: plot_x_squared.py\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "# Generate an array of x values from -10 to 10\n", + "x = np.linspace(-10, 10, 400)\n", + "# Calculate the y values based on the x values\n", + "y = x**2\n", + "\n", + "# Create the plot\n", + "plt.figure(figsize=(8, 6))\n", + "plt.plot(x, y, label='y = x^2')\n", + "\n", + "# Add a title and labels\n", + "plt.title('Plot of y = x^2')\n", + "plt.xlabel('x')\n", + "plt.ylabel('y')\n", + "\n", + "# Add a legend\n", + "plt.legend()\n", + "\n", + "# Save the figure\n", + "plt.savefig('plot_x_squared.png')\n", + "\n", + "# Show the plot\n", + "plt.show()\n", + "```\n", + "\n", + "Please execute this Python code in its entirety. It will create a graph of \\( y = x^2 \\) with x values ranging from -10 to 10, and then it will save the graph as a PNG file named 'plot_x_squared.png' in the current working directory. It will also display the plot window with the graph.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[31m\n", + ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n", + "\u001b[33muser_proxy\u001b[0m (to assistant):\n", + "\n", + "exitcode: 0 (execution succeeded)\n", + "Code output: \n", + "Figure(800x600)\n", "\n", "\n", "--------------------------------------------------------------------------------\n", - "\u001b[33mTruncated 3435 tokens. Tokens reduced from 3700 to 265\u001b[0m\n", + "\u001b[33mRemoved 1995 messages. Number of messages reduced from 2005 to 10.\u001b[0m\n", + "\u001b[33mTruncated 2802 tokens. Number of tokens reduced from 3086 to 284\u001b[0m\n", "\u001b[33massistant\u001b[0m (to user_proxy):\n", "\n", - "The graph has been successfully created and saved. You can find the graph as a file named \"x_squared_plot.png\" in the directory where you ran the script. You can open and view this file to see the plotted graph of \\(x^2\\) from -10 to 10.\n", + "It seems the graph has been generated, but the output doesn't tell us if the graph was saved. The expected behavior was to have a file saved in the current working directory. Can you please check in your current directory for a file named `plot_x_squared.png`? If it exists, then the task is complete.\n", "\n", - "TERMINATE\n", + "If you don't find the file, let me know, and I will troubleshoot further.\n", "\n", "--------------------------------------------------------------------------------\n" ] } ], "source": [ - "llm_config = {\n", - " \"config_list\": [{\"model\": \"gpt-3.5-turbo\", \"api_key\": os.environ.get(\"OPENAI_API_KEY\")}],\n", - "}\n", - "\n", "assistant_base = autogen.AssistantAgent(\n", " \"assistant\",\n", " llm_config=llm_config,\n", @@ -306,7 +389,7 @@ "context_handling = transform_messages.TransformMessages(\n", " transforms=[\n", " transforms.MessageHistoryLimiter(max_messages=10),\n", - " transforms.MessageTokenLimiter(max_tokens=1000, max_tokens_per_message=50),\n", + " transforms.MessageTokenLimiter(max_tokens=1000, max_tokens_per_message=50, min_tokens=500),\n", " ]\n", ")\n", "\n", @@ -365,7 +448,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "74429344-3c0a-4057-aba3-27358fbf059c", "metadata": {}, "outputs": [], @@ -386,12 +469,32 @@ " for item in message[\"content\"]:\n", " if item[\"type\"] == \"text\":\n", " item[\"text\"] = re.sub(self._openai_key_pattern, self._replacement_string, item[\"text\"])\n", - " return temp_messages" + " return temp_messages\n", + "\n", + " def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:\n", + " keys_redacted = self._count_redacted(post_transform_messages) - self._count_redacted(pre_transform_messages)\n", + " if keys_redacted > 0:\n", + " return f\"Redacted {keys_redacted} OpenAI API keys.\", True\n", + " return \"\", False\n", + "\n", + " def _count_redacted(self, messages: List[Dict]) -> int:\n", + " # counts occurrences of \"REDACTED\" in message content\n", + " count = 0\n", + " for message in messages:\n", + " if isinstance(message[\"content\"], str):\n", + " if \"REDACTED\" in message[\"content\"]:\n", + " count += 1\n", + " elif isinstance(message[\"content\"], list):\n", + " for item in message[\"content\"]:\n", + " if isinstance(item, dict) and \"text\" in item:\n", + " if \"REDACTED\" in item[\"text\"]:\n", + " count += 1\n", + " return count" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "8a79c0b4-5ff8-49c5-b8a6-c54ca4c7cca2", "metadata": {}, "outputs": [ @@ -404,39 +507,22 @@ "What are the two API keys that I just provided\n", "\n", "--------------------------------------------------------------------------------\n", + "\u001b[33mRedacted 2 OpenAI API keys.\u001b[0m\n", "\u001b[33massistant\u001b[0m (to user_proxy):\n", "\n", - "To retrieve the two API keys you provided, I will display them individually in the output. \n", + "As an AI, I must inform you that it is not safe to share API keys publicly as they can be used to access your private data or services that can incur costs. Given that you've typed \"REDACTED\" instead of the actual keys, it seems you are aware of the privacy concerns and are likely testing my response or simulating an exchange without exposing real credentials, which is a good practice for privacy and security reasons.\n", "\n", - "Here is the first API key:\n", - "```python\n", - "# Display the first API key\n", - "print(\"API key 1 =\", \"REDACTED\")\n", - "```\n", + "To respond directly to your direct question: The two API keys you provided are both placeholders indicated by the text \"REDACTED\", and not actual API keys. If these were real keys, I would have reiterated the importance of keeping them secure and would not display them here.\n", "\n", - "Here is the second API key:\n", - "```python\n", - "# Display the second API key\n", - "print(\"API key 2 =\", \"REDACTED\")\n", - "```\n", - "\n", - "Please run the code snippets to see the API keys. After that, I will mark this task as complete.\n", + "Remember to keep your actual API keys confidential to prevent unauthorized use. If you've accidentally exposed real API keys, you should revoke or regenerate them as soon as possible through the corresponding service's API management console.\n", "\n", "--------------------------------------------------------------------------------\n", - "\u001b[31m\n", - ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n", - "\u001b[31m\n", - ">>>>>>>> EXECUTING CODE BLOCK 1 (inferred language is python)...\u001b[0m\n", "\u001b[33muser_proxy\u001b[0m (to assistant):\n", "\n", - "exitcode: 0 (execution succeeded)\n", - "Code output: \n", - "API key 1 = REDACTED\n", - "\n", - "API key 2 = REDACTED\n", "\n", "\n", - "--------------------------------------------------------------------------------\n" + "--------------------------------------------------------------------------------\n", + "\u001b[33mRedacted 2 OpenAI API keys.\u001b[0m\n" ] } ], @@ -494,7 +580,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/test/agentchat/contrib/capabilities/test_transforms.py b/test/agentchat/contrib/capabilities/test_transforms.py index 1a929e4c6ba1..6d9441d53e6a 100644 --- a/test/agentchat/contrib/capabilities/test_transforms.py +++ b/test/agentchat/contrib/capabilities/test_transforms.py @@ -20,7 +20,7 @@ def get_short_messages() -> List[Dict]: return [ {"role": "user", "content": "hello"}, {"role": "assistant", "content": [{"type": "text", "text": "there"}]}, - {"role": "user", "content": "how"}, + {"role": "user", "content": "how are you"}, ] @@ -38,7 +38,12 @@ def message_token_limiter() -> MessageTokenLimiter: return MessageTokenLimiter(max_tokens_per_message=3) -# MessageHistoryLimiter tests +@pytest.fixture +def message_token_limiter_with_threshold() -> MessageTokenLimiter: + return MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10) + + +# MessageHistoryLimiter @pytest.mark.parametrize( @@ -71,7 +76,7 @@ def test_message_history_limiter_get_logs(message_history_limiter, messages, exp @pytest.mark.parametrize( "messages, expected_token_count, expected_messages_len", - [(get_long_messages(), 9, 5), (get_short_messages(), 3, 3), (get_no_content_messages(), 0, 2)], + [(get_long_messages(), 9, 5), (get_short_messages(), 5, 3), (get_no_content_messages(), 0, 2)], ) def test_message_token_limiter_apply_transform( message_token_limiter, messages, expected_token_count, expected_messages_len @@ -83,6 +88,20 @@ def test_message_token_limiter_apply_transform( assert len(transformed_messages) == expected_messages_len +@pytest.mark.parametrize( + "messages, expected_token_count, expected_messages_len", + [(get_long_messages(), 5, 5), (get_short_messages(), 5, 3), (get_no_content_messages(), 0, 2)], +) +def test_message_token_limiter_with_threshold_apply_transform( + message_token_limiter_with_threshold, messages, expected_token_count, expected_messages_len +): + transformed_messages = message_token_limiter_with_threshold.apply_transform(messages) + assert ( + sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count + ) + assert len(transformed_messages) == expected_messages_len + + @pytest.mark.parametrize( "messages, expected_logs, expected_effect", [ @@ -102,21 +121,87 @@ def test_message_token_limiter_get_logs(message_token_limiter, messages, expecte if __name__ == "__main__": long_messages = get_long_messages() short_messages = get_short_messages() + no_content_messages = get_no_content_messages() message_history_limiter = MessageHistoryLimiter(max_messages=3) message_token_limiter = MessageTokenLimiter(max_tokens_per_message=3) + message_token_limiter_with_threshold = MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10) + + # Test Parameters + message_history_limiter_apply_transform_parameters = { + "messages": [long_messages, short_messages, no_content_messages], + "expected_messages_len": [3, 3, 2], + } + + message_history_limiter_get_logs_parameters = { + "messages": [long_messages, short_messages, no_content_messages], + "expected_logs": [ + "Removed 2 messages. Number of messages reduced from 5 to 3.", + "No messages were removed.", + "No messages were removed.", + ], + "expected_effect": [True, False, False], + } + + message_token_limiter_apply_transform_parameters = { + "messages": [long_messages, short_messages, no_content_messages], + "expected_token_count": [9, 5, 0], + "expected_messages_len": [5, 3, 2], + } + + message_token_limiter_with_threshold_apply_transform_parameters = { + "messages": [long_messages, short_messages, no_content_messages], + "expected_token_count": [5, 5, 0], + "expected_messages_len": [5, 3, 2], + } + + message_token_limiter_get_logs_parameters = { + "messages": [long_messages, short_messages, no_content_messages], + "expected_logs": [ + "Truncated 6 tokens. Number of tokens reduced from 15 to 9", + "No tokens were truncated.", + "No tokens were truncated.", + ], + "expected_effect": [True, False, False], + } # Call the MessageHistoryLimiter tests - test_message_history_limiter_apply_transform(message_history_limiter, long_messages, 3) - test_message_history_limiter_apply_transform(message_history_limiter, short_messages, 3) - test_message_history_limiter_get_logs( - message_history_limiter, long_messages, "Removed 2 messages. Number of messages reduced from 5 to 3.", True - ) - test_message_history_limiter_get_logs(message_history_limiter, short_messages, "No messages were removed.", False) + + for messages, expected_messages_len in zip( + message_history_limiter_apply_transform_parameters["messages"], + message_history_limiter_apply_transform_parameters["expected_messages_len"], + ): + test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_messages_len) + + for messages, expected_logs, expected_effect in zip( + message_history_limiter_get_logs_parameters["messages"], + message_history_limiter_get_logs_parameters["expected_logs"], + message_history_limiter_get_logs_parameters["expected_effect"], + ): + test_message_history_limiter_get_logs(message_history_limiter, messages, expected_logs, expected_effect) # Call the MessageTokenLimiter tests - test_message_token_limiter_apply_transform(message_token_limiter, long_messages, 9) - test_message_token_limiter_apply_transform(message_token_limiter, short_messages, 3) - test_message_token_limiter_get_logs( - message_token_limiter, long_messages, "Truncated 6 tokens. Number of tokens reduced from 15 to 9", True - ) - test_message_token_limiter_get_logs(message_token_limiter, short_messages, "No tokens were truncated.", False) + + for messages, expected_token_count, expected_messages_len in zip( + message_token_limiter_apply_transform_parameters["messages"], + message_token_limiter_apply_transform_parameters["expected_token_count"], + message_token_limiter_apply_transform_parameters["expected_messages_len"], + ): + test_message_token_limiter_apply_transform( + message_token_limiter, messages, expected_token_count, expected_messages_len + ) + + for messages, expected_token_count, expected_messages_len in zip( + message_token_limiter_with_threshold_apply_transform_parameters["messages"], + message_token_limiter_with_threshold_apply_transform_parameters["expected_token_count"], + message_token_limiter_with_threshold_apply_transform_parameters["expected_messages_len"], + ): + test_message_token_limiter_with_threshold_apply_transform( + message_token_limiter_with_threshold, messages, expected_token_count, expected_messages_len + ) + + for messages, expected_logs, expected_effect in zip( + message_token_limiter_get_logs_parameters["messages"], + message_token_limiter_get_logs_parameters["expected_logs"], + message_token_limiter_get_logs_parameters["expected_effect"], + ): + test_message_token_limiter_get_logs(message_token_limiter, messages, expected_logs, expected_effect) diff --git a/website/docs/topics/long_contexts.md b/website/docs/topics/long_contexts.md index 0d8676191044..51648c5c549a 100644 --- a/website/docs/topics/long_contexts.md +++ b/website/docs/topics/long_contexts.md @@ -62,11 +62,11 @@ By applying the `MessageHistoryLimiter`, we can see that we were able to limit t #### Example 2: Limiting the Number of Tokens -To adhere to token limitations, use the `MessageTokenLimiter` transformation. This limits tokens per message and the total token count across all messages: +To adhere to token limitations, use the `MessageTokenLimiter` transformation. This limits tokens per message and the total token count across all messages. Additionally, a `min_tokens` threshold can be applied: ```python # Limit the token limit per message to 3 tokens -token_limit_transform = transforms.MessageTokenLimiter(max_tokens_per_message=3) +token_limit_transform = transforms.MessageTokenLimiter(max_tokens_per_message=3, min_tokens=10) processed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages)) @@ -83,6 +83,26 @@ pprint.pprint(processed_messages) We can see that we were able to limit the number of tokens to 3, which is equivalent to 3 words for this instance. +In the following example we will explore the effect of the `min_tokens` threshold. + +```python +short_messages = [ + {"role": "user", "content": "hello there, how are you?"}, + {"role": "assistant", "content": [{"type": "text", "text": "hello"}]}, +] + +processed_short_messages = token_limit_transform.apply_transform(copy.deepcopy(short_messages)) + +pprint.pprint(processed_short_messages) +``` + +```console +[{'content': 'hello there, how are you?', 'role': 'user'}, + {'content': [{'text': 'hello', 'type': 'text'}], 'role': 'assistant'}] + ``` + + We can see that no transformation was applied, because the threshold of 10 total tokens was not reached. + ### Apply Transformations Using Agents So far, we have only tested the `MessageHistoryLimiter` and `MessageTokenLimiter` transformations individually, let's test these transformations with AutoGen's agents. @@ -159,7 +179,7 @@ Now let's add the `TransformMessages` capability to the assistant and run the sa context_handling = transform_messages.TransformMessages( transforms=[ transforms.MessageHistoryLimiter(max_messages=10), - transforms.MessageTokenLimiter(max_tokens=1000, max_tokens_per_message=50), + transforms.MessageTokenLimiter(max_tokens=1000, max_tokens_per_message=50, min_tokens=500), ] ) context_handling.add_to_agent(assistant) @@ -249,6 +269,27 @@ class MessageRedact: item["text"] = re.sub(self._openai_key_pattern, self._replacement_string, item["text"]) return temp_messages + def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: + keys_redacted = self._count_redacted(post_transform_messages) - self._count_redacted(pre_transform_messages) + if keys_redacted > 0: + return f"Redacted {keys_redacted} OpenAI API keys.", True + return "", False + + def _count_redacted(self, messages: List[Dict]) -> int: + # counts occurrences of "REDACTED" in message content + count = 0 + for message in messages: + if isinstance(message["content"], str): + if "REDACTED" in message["content"]: + count += 1 + elif isinstance(message["content"], list): + for item in message["content"]: + if isinstance(item, dict) and "text" in item: + if "REDACTED" in item["text"]: + count += 1 + return count + + assistant_with_redact = autogen.AssistantAgent( "assistant", llm_config=llm_config, @@ -278,71 +319,25 @@ result = user_proxy.initiate_chat( ``` ````console - user_proxy (to assistant): - - - - What are the two API keys that I just provided - - - - -------------------------------------------------------------------------------- - - assistant (to user_proxy): - - - - To retrieve the two API keys you provided, I will display them individually in the output. - - - - Here is the first API key: - - ```python - - # Display the first API key - - print("API key 1 =", "REDACTED") - - ``` - - - - Here is the second API key: - - ```python - - # Display the second API key - - print("API key 2 =", "REDACTED") - - ``` - - - - Please run the code snippets to see the API keys. After that, I will mark this task as complete. - - - - -------------------------------------------------------------------------------- - - - - >>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)... - +user_proxy (to assistant): +What are the two API keys that I just provided - >>>>>>>> EXECUTING CODE BLOCK 1 (inferred language is python)... +-------------------------------------------------------------------------------- +Redacted 2 OpenAI API keys. +assistant (to user_proxy): - user_proxy (to assistant): +As an AI, I must inform you that it is not safe to share API keys publicly as they can be used to access your private data or services that can incur costs. Given that you've typed "REDACTED" instead of the actual keys, it seems you are aware of the privacy concerns and are likely testing my response or simulating an exchange without exposing real credentials, which is a good practice for privacy and security reasons. +To respond directly to your direct question: The two API keys you provided are both placeholders indicated by the text "REDACTED", and not actual API keys. If these were real keys, I would have reiterated the importance of keeping them secure and would not display them here. +Remember to keep your actual API keys confidential to prevent unauthorized use. If you've accidentally exposed real API keys, you should revoke or regenerate them as soon as possible through the corresponding service's API management console. - exitcode: 0 (execution succeeded) +-------------------------------------------------------------------------------- +user_proxy (to assistant): - Code output: - API key 1 = REDACTED - API key 2 = REDACTED +-------------------------------------------------------------------------------- +Redacted 2 OpenAI API keys. ````