[Draft, Feedback Needed] Memory in AgentChat #4438

Draft · wants to merge 6 commits into main

Changes from 2 commits
@@ -29,6 +29,7 @@
ToolCallResultMessage,
)
from ._base_chat_agent import BaseChatAgent
from ..memory._base_memory import Memory, MemoryQueryResult

event_logger = logging.getLogger(EVENT_LOGGER_NAME)

@@ -60,10 +61,12 @@ def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]:
else:
name = values["name"]
if not isinstance(name, str):
raise ValueError(f"Handoff name must be a string: {values['name']}")
raise ValueError(
f"Handoff name must be a string: {values['name']}")
# Check if name is a valid identifier.
if not name.isidentifier():
raise ValueError(f"Handoff name must be a valid identifier: {values['name']}")
raise ValueError(
f"Handoff name must be a valid identifier: {values['name']}")
if values.get("message") is None:
values["message"] = (
f"Transferred to {values['target']}, adopting the role of {values['target']} immediately."
@@ -203,22 +206,29 @@ def __init__(
name: str,
model_client: ChatCompletionClient,
*,
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
tools: List[Tool | Callable[..., Any] |
Callable[..., Awaitable[Any]]] | None = None,
handoffs: List[Handoff | str] | None = None,
memory: Memory | None = None,
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str
| None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
):
super().__init__(name=name, description=description)
self._model_client = model_client
self._memory = memory

self._system_messages: List[SystemMessage | UserMessage |
AssistantMessage | FunctionExecutionResultMessage] = []
if system_message is None:
self._system_messages = []
else:
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
if tools is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
raise ValueError(
"The model does not support function calling.")
for tool in tools:
if isinstance(tool, Tool):
self._tools.append(tool)
@@ -227,7 +237,8 @@ def __init__(
description = tool.__doc__
else:
description = ""
self._tools.append(FunctionTool(tool, description=description))
self._tools.append(FunctionTool(
tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
@@ -239,26 +250,42 @@ def __init__(
self._handoffs: Dict[str, Handoff] = {}
if handoffs is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling, which is needed for handoffs.")
raise ValueError(
"The model does not support function calling, which is needed for handoffs.")
for handoff in handoffs:
if isinstance(handoff, str):
handoff = Handoff(target=handoff)
if isinstance(handoff, Handoff):
self._handoff_tools.append(handoff.handoff_tool)
self._handoffs[handoff.name] = handoff
else:
raise ValueError(f"Unsupported handoff type: {type(handoff)}")
raise ValueError(
f"Unsupported handoff type: {type(handoff)}")
# Check if handoff tool names are unique.
handoff_tool_names = [tool.name for tool in self._handoff_tools]
if len(handoff_tool_names) != len(set(handoff_tool_names)):
raise ValueError(f"Handoff names must be unique: {handoff_tool_names}")
raise ValueError(
f"Handoff names must be unique: {handoff_tool_names}")
# Check if handoff tool names not in tool names.
if any(name in tool_names for name in handoff_tool_names):
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
self._model_context: List[LLMMessage] = []

def _format_memory_context(self, results: List[MemoryQueryResult]) -> str:
if not results or not self._memory: # Guard against no memory
return ""

context_lines = []
for i, result in enumerate(results, 1):
context_lines.append(
self._memory.config.context_format.format(
i=i, content=result.entry.content, score=result.score)
)

return "".join(context_lines)

@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
@@ -270,44 +297,70 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")
raise AssertionError(
"The stream should have returned the final result.")

async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
# Query memory if available with the last message
memory_context = ""
if self._memory is not None and messages:
try:
last_message = messages[-1]
# ensure the last message is a text message or multimodal message
if not isinstance(last_message, TextMessage) and not isinstance(last_message, MultiModalMessage):
raise ValueError(
"Memory query failed: Last message must be a text message or multimodal message.")
results: List[MemoryQueryResult] = await self._memory.query(messages[-1].content, cancellation_token=cancellation_token)
Contributor:
I guess this is saying that the default memory implementation is RAG with the last message.

I feel that using the last message for RAG is reasonable but restrictive; could we just pass messages and have the memory query decide?

A high-level comment: I'm not sure whether to call this memory or a datastore/database.

Collaborator (Author):

@husseinmozannar, I agree about passing the entire context and letting the query method decide.

 query: ContentItem | List[ContentItem] 

I am flexible on naming.
I like Memory because it connotes "just in time" retrieval/recall of content relevant to a step (in a task) an agent is about to take. Memory also gives a sense of what is stored inside the memory - in this case it really should be content relevant to task completion (not just anything that can be in a database).
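
To make that suggestion concrete, a query signature that takes the whole message sequence could look like the sketch below. The exact names and types are illustrative only and not part of this PR; the idea is that the memory implementation, rather than the agent, decides what to retrieve against.

async def query(
    self,
    messages: Sequence[ChatMessage],  # full conversation; the implementation decides what to embed or search
    cancellation_token: CancellationToken | None = None,
    **kwargs: Any,
) -> List[MemoryQueryResult]:
    ...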

memory_context = self._format_memory_context(results)
except Exception as e:
event_logger.warning(f"Memory query failed: {e}")

# Add messages to the model context.
for msg in messages:
if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
raise ValueError("The model does not support vision.")
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
self._model_context.append(UserMessage(
content=msg.content, source=msg.source))

# Inner messages.
inner_messages: List[AgentMessage] = []

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
# Prepare messages for model with memory context if available
llm_messages = self._system_messages
if memory_context:
llm_messages = llm_messages + \
[SystemMessage(content=memory_context)]
llm_messages = llm_messages + self._model_context

# Generate inference result
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
self._model_context.append(AssistantMessage(
content=result.content, source=self.name))

# Run tool calls until the model produces a string response.
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
tool_call_msg = ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage)
tool_call_msg = ToolCallMessage(
content=result.content, source=self.name, models_usage=result.usage)
event_logger.debug(tool_call_msg)
# Add the tool call message to the output.
inner_messages.append(tool_call_msg)
yield tool_call_msg

# Execute the tool calls.
results = await asyncio.gather(
execution_results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
tool_call_result_msg = ToolCallResultMessage(content=results, source=self.name)
tool_call_result_msg = ToolCallResultMessage(
content=execution_results, source=self.name)
event_logger.debug(tool_call_result_msg)
self._model_context.append(FunctionExecutionResultMessage(content=results))
self._model_context.append(
FunctionExecutionResultMessage(content=execution_results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

Expand All @@ -318,7 +371,8 @@ async def on_messages_stream(
handoffs.append(self._handoffs[call.name])
if len(handoffs) > 0:
if len(handoffs) > 1:
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
raise ValueError(
f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
# Return the output messages to signal the handoff.
yield Response(
chat_message=HandoffMessage(
@@ -329,15 +383,22 @@ async def on_messages_stream(
return

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
llm_messages = (
self._system_messages
+ ([SystemMessage(content=memory_context)]
if memory_context else [])
+ self._model_context
)
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
self._model_context.append(AssistantMessage(
content=result.content, source=self.name))

assert isinstance(result.content, str)
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
chat_message=TextMessage(
content=result.content, source=self.name, models_usage=result.usage),
inner_messages=inner_messages,
)

@@ -348,9 +409,11 @@ async def _execute_tool_call(
try:
if not self._tools + self._handoff_tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None)
tool = next((t for t in self._tools +
self._handoff_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
raise ValueError(
f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
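For orientation before the new memory module below, here is a rough usage sketch of the memory parameter added to AssistantAgent above. ListMemory is a hypothetical stand-in for a concrete Memory implementation (this diff only defines the protocol), and model_client stands in for any ChatCompletionClient; the constructor call otherwise follows the existing AgentChat API.

# Sketch only: ListMemory is hypothetical; the Memory protocol is defined in the file below.
memory = ListMemory()
await memory.add(MemoryEntry(content="The user prefers meters over feet.", source="preferences"))

agent = AssistantAgent(
    name="assistant",
    model_client=model_client,  # any ChatCompletionClient
    memory=memory,              # new parameter introduced in this PR
)
# On each turn the agent queries memory with the latest message and, when results come back,
# injects the formatted context as an extra SystemMessage before calling the model.
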
@@ -0,0 +1,101 @@
from datetime import datetime
from typing import Any, Dict, List, Protocol, Union, runtime_checkable

from autogen_core.base import CancellationToken
from autogen_core.components import Image
from pydantic import BaseModel, ConfigDict, Field


class BaseMemoryConfig(BaseModel):
"""Base configuration for memory implementations."""

k: int = Field(default=5, description="Number of results to return")
score_threshold: float | None = Field(default=None, description="Minimum relevance score")
context_format: str = Field(
default="Context {i}: {content} (score: {score:.2f})\n Use this information to address relevant tasks.",
description="Format string for memory results in prompt",
)

model_config = ConfigDict(arbitrary_types_allowed=True)


class MemoryEntry(BaseModel):
"""A memory entry containing content and metadata."""

content: Union[str, List[Union[str, Image]]]
"""The content of the memory entry - can be text or multimodal."""

metadata: Dict[str, Any] = Field(default_factory=dict)
"""Optional metadata associated with the memory entry."""

timestamp: datetime = Field(default_factory=datetime.now)
"""When the memory was created."""

source: str | None = None
"""Optional source identifier for the memory."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class MemoryQueryResult(BaseModel):
"""Result from a memory query including the entry and its relevance score."""

entry: MemoryEntry
"""The memory entry."""

score: float
"""Relevance score for this result. Higher means more relevant."""

model_config = ConfigDict(arbitrary_types_allowed=True)


@runtime_checkable
class Memory(Protocol):
"""Protocol defining the interface for memory implementations."""

@property
def name(self) -> str | None:
"""The name of this memory implementation."""
...

@property
def config(self) -> BaseMemoryConfig:
"""The configuration for this memory implementation."""
...

async def query(
self,
query: Union[str, Image, List[Union[str, Image]]],
Collaborator:
query could also benefit from mime types

cancellation_token: CancellationToken | None = None,
**kwargs: Any,
) -> List[MemoryQueryResult]:
"""
Query the memory store and return relevant entries.

Args:
query: Text, image or multimodal query
cancellation_token: Optional token to cancel operation
**kwargs: Additional implementation-specific parameters

Returns:
List of memory entries with relevance scores
"""
...

async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None:
"""
Add a new entry to memory.

Args:
entry: The memory entry to add
cancellation_token: Optional token to cancel operation
"""
...

async def clear(self) -> None:
"""Clear all entries from memory."""
...

async def cleanup(self) -> None:
"""Clean up any resources used by the memory implementation."""
...
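
For reference, a minimal in-process implementation of this protocol could look like the following sketch. It reuses the imports and models defined in this file, scores entries with a naive substring match instead of embeddings, and keeps everything in a Python list; it is illustrative only and not part of this PR.

class ListMemory:
    """Toy Memory implementation backed by a plain Python list (sketch only)."""

    def __init__(self, config: BaseMemoryConfig | None = None) -> None:
        self._config = config or BaseMemoryConfig()
        self._entries: List[MemoryEntry] = []

    @property
    def name(self) -> str | None:
        return "list_memory"

    @property
    def config(self) -> BaseMemoryConfig:
        return self._config

    async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None:
        self._entries.append(entry)

    async def query(
        self,
        query: Union[str, Image, List[Union[str, Image]]],
        cancellation_token: CancellationToken | None = None,
        **kwargs: Any,
    ) -> List[MemoryQueryResult]:
        # Naive relevance: score 1.0 when the query text occurs in a text entry.
        text = query if isinstance(query, str) else ""
        hits = [
            MemoryQueryResult(entry=entry, score=1.0)
            for entry in self._entries
            if text and isinstance(entry.content, str) and text.lower() in entry.content.lower()
        ]
        return hits[: self._config.k]

    async def clear(self) -> None:
        self._entries.clear()

    async def cleanup(self) -> None:
        pass

Because Memory is a runtime-checkable Protocol, isinstance(ListMemory(), Memory) holds without inheriting from it. With the default context_format, each returned result renders in the agent prompt as "Context 1: The user prefers meters over feet. (score: 1.00)" followed by the fixed "Use this information to address relevant tasks." suffix.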