Min tokens in token limiter (microsoft#2400)
* Add minimum token threshold in MessageHistoryLimiter

* Update transforms tests for the threshold

* Move min_threshold_tokens from Message to Token Limiter

* Optimize _check_tokens_threshold

Co-authored-by: Wael Karkoub <[email protected]>

* Apply requested changes (renaming, phrasing, validations)

* Fix format

* Fix _check_tokens_threshold logic

* Update docs and notebook

* Improve phrasing

* Add min_tokens example in notebook

* Add min_tokens example in website docs

* Add min_tokens example in notebook

* Update website docs to be in sync with get_logs change

---------

Co-authored-by: Wael Karkoub <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
3 people authored Apr 29, 2024
1 parent 5a007e0 commit 11a4342
Showing 4 changed files with 351 additions and 152 deletions.
45 changes: 39 additions & 6 deletions autogen/agentchat/contrib/capabilities/transforms.py
@@ -51,8 +51,7 @@ class MessageHistoryLimiter:
     def __init__(self, max_messages: Optional[int] = None):
         """
         Args:
-            max_messages (None or int): Maximum number of messages to keep in the context.
-                Must be greater than 0 if not None.
+            max_messages (Optional[int]): Maximum number of messages to keep in the context. Must be greater than 0 if not None.
         """
         self._validate_max_messages(max_messages)
         self._max_messages = max_messages
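
As a quick orientation, here is a minimal usage sketch of the transform documented in this hunk. The message dicts are illustrative; the class and apply_transform come from the file shown above.

from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter

# Keep only the 3 most recent messages in the context.
limiter = MessageHistoryLimiter(max_messages=3)
history = [{"role": "user", "content": f"message {i}"} for i in range(10)]
truncated = limiter.apply_transform(history)
assert len(truncated) == 3  # only the most recent messages survive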
@@ -70,6 +69,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         Returns:
             List[Dict]: A new list containing the most recent messages up to the specified maximum.
         """
+
         if self._max_messages is None:
             return messages
@@ -108,20 +108,23 @@ class MessageTokenLimiter:
     The truncation process follows these steps in order:

-    1. Messages are processed in reverse order (newest to oldest).
-    2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
+    1. The minimum tokens threshold (`min_tokens`) is checked (0 by default). If the total number of tokens in the
+       messages is less than this threshold, the messages are returned as is. Otherwise, the following process is applied.
+    2. Messages are processed in reverse order (newest to oldest).
+    3. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
        and other types of content, only the text content is truncated.
-    3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
+    4. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
        exceeds this limit, the current message being processed gets truncated to meet the total token count and any
        remaining messages get discarded.
-    4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
+    5. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
        original message order.
     """

     def __init__(
         self,
         max_tokens_per_message: Optional[int] = None,
         max_tokens: Optional[int] = None,
+        min_tokens: Optional[int] = None,
         model: str = "gpt-3.5-turbo-0613",
     ):
         """
Expand All @@ -130,11 +133,14 @@ def __init__(
Must be greater than or equal to 0 if not None.
max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history.
Must be greater than or equal to 0 if not None.
min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
Must be greater than or equal to 0 if not None.
model (str): The target OpenAI model for tokenization alignment.
"""
self._model = model
self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
self._max_tokens = self._validate_max_tokens(max_tokens)
self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies token truncation to the conversation history.
@@ -147,6 +153,11 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         """
         assert self._max_tokens_per_message is not None
         assert self._max_tokens is not None
+        assert self._min_tokens is not None
+
+        # if the total number of tokens in the messages is less than min_tokens, return the messages as is
+        if not self._are_min_tokens_reached(messages):
+            return messages

         temp_messages = copy.deepcopy(messages)
         processed_messages = []
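
The new short-circuit can be sketched as follows, with illustrative values; "hello" is far below 50 tokens under any supported tokenizer.

limiter = MessageTokenLimiter(max_tokens=100, min_tokens=50)
short_history = [{"role": "user", "content": "hello"}]
# Total tokens fall below min_tokens, so the history is returned unchanged.
assert limiter.apply_transform(short_history) == short_history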
@@ -194,6 +205,19 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
             return logs_str, True
         return "No tokens were truncated.", False

+    def _are_min_tokens_reached(self, messages: List[Dict]) -> bool:
+        """
+        Returns True if the minimum token threshold does not prevent the transformation:
+        either no minimum threshold is set, or the total number of tokens in the messages
+        is greater than or equal to `min_tokens`.
+        """
+        if not self._min_tokens:
+            return True
+
+        messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
+        return messages_tokens >= self._min_tokens
+
     def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
         if isinstance(contents, str):
             return self._truncate_tokens(contents, n_tokens)
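
The threshold check sums token counts over message contents. A hypothetical standalone illustration, assuming autogen's token_count_utils.count_token behaves like the module-level _count_tokens used in this file:

from autogen.token_count_utils import count_token

messages = [
    {"role": "user", "content": "What is the weather today?"},
    {"role": "assistant", "content": "I have no live weather data."},
]
total = sum(count_token(m["content"]) for m in messages if "content" in m)
print(total >= 100)  # False for a tiny history, so min_tokens=100 would skip truncation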
@@ -244,6 +268,15 @@ def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]:

         return max_tokens if max_tokens is not None else sys.maxsize

+    def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
+        if min_tokens is None:
+            return 0
+        if min_tokens < 0:
+            raise ValueError("min_tokens must be None or greater than or equal to 0.")
+        if max_tokens is not None and min_tokens > max_tokens:
+            raise ValueError("min_tokens must not be more than max_tokens.")
+        return min_tokens
+

 def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
     token_count = 0
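
Per the validation code above, the new argument can be exercised directly (values illustrative):

MessageTokenLimiter(min_tokens=-1)                    # raises ValueError (negative)
MessageTokenLimiter(max_tokens=100, min_tokens=200)   # raises ValueError (min > max)
MessageTokenLimiter(max_tokens=200, min_tokens=100)   # valid; an unset min_tokens defaults to 0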