Skip to content

Commit

Permalink
Add the add_tool(), remove_tool() and remove_all_tools() method…
Browse files Browse the repository at this point in the history
…s for `AssistantAgent`
  • Loading branch information
Jean-Marc Le Roux committed Dec 16, 2024
1 parent 7eaffa8 commit 6f3b0bb
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,24 +231,10 @@ def __init__(
else:
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
if tools is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
for tool in tools:
if isinstance(tool, Tool):
self._tools.append(tool)
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
self._tools.append(FunctionTool(tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
tool_names = [tool.name for tool in self._tools]
if len(tool_names) != len(set(tool_names)):
raise ValueError(f"Tool names must be unique: {tool_names}")
self._model_context: List[LLMMessage] = []
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
# Handoff tools.
self._handoff_tools: List[Tool] = []
self._handoffs: Dict[str, HandoffBase] = {}
Expand All @@ -258,24 +244,54 @@ def __init__(
for handoff in handoffs:
if isinstance(handoff, str):
handoff = HandoffBase(target=handoff)
if handoff.name in self._handoffs:
raise ValueError(f"Handoff name {handoff.name} already exists.")
if isinstance(handoff, HandoffBase):
self._handoff_tools.append(handoff.handoff_tool)
self._handoffs[handoff.name] = handoff
else:
raise ValueError(f"Unsupported handoff type: {type(handoff)}")
# Check if handoff tool names are unique.
handoff_tool_names = [tool.name for tool in self._handoff_tools]
if len(handoff_tool_names) != len(set(handoff_tool_names)):
raise ValueError(f"Handoff names must be unique: {handoff_tool_names}")
if tools is not None:
for tool in tools:
self.add_tool(tool)

def add_tool(self, tool: Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]) -> None:
new_tool = None
if self._model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
if isinstance(tool, Tool):
new_tool = tool
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
new_tool = FunctionTool(tool, description=description)
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
if any(tool.name == new_tool.name for tool in self._tools):
raise ValueError(f"Tool names must be unique: {new_tool.name}")
# Check if handoff tool names not in tool names.
if any(name in tool_names for name in handoff_tool_names):
handoff_tool_names = [handoff.name for handoff in self._handoffs.values()]
if new_tool.name in handoff_tool_names:
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; "
f"tool names: {new_tool.name}"
)
self._model_context: List[LLMMessage] = []
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
self._tools.append(new_tool)

def remove_all_tools(self) -> None:
"""Remove all tools."""
self._tools.clear()

def remove_tool(self, tool_name: str) -> None:
"""Remove tools by name."""
for tool in self._tools:
if tool.name == tool_name:
self._tools.remove(tool)
return
raise ValueError(f"Tool {tool_name} not found.")

@property
def produced_message_types(self) -> List[type[ChatMessage]]:
Expand Down
75 changes: 75 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,78 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
else:
assert message == result.messages[index]
index += 1


def test_tool_management():
model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="")
agent = AssistantAgent(name="test_assistant", model_client=model_client)

# Test function to be used as a tool
def sample_tool() -> str:
return "sample result"

# Test adding a tool
tool = FunctionTool(sample_tool, description="Sample tool")
agent.add_tool(tool)
assert len(agent._tools) == 1

# Test adding duplicate tool
with pytest.raises(ValueError, match="Tool names must be unique"):
agent.add_tool(tool)

# Test tool collision with handoff
agent_with_handoff = AssistantAgent(
name="test_assistant", model_client=model_client, handoffs=[Handoff(target="other_agent")]
)

conflicting_tool = FunctionTool(sample_tool, name="transfer_to_other_agent", description="Sample tool")
with pytest.raises(ValueError, match="Handoff names must be unique from tool names"):
agent_with_handoff.add_tool(conflicting_tool)

# Test removing a tool
agent.remove_tool(tool.name)
assert len(agent._tools) == 0

# Test removing non-existent tool
with pytest.raises(ValueError, match="Tool non_existent_tool not found"):
agent.remove_tool("non_existent_tool")

# Test removing all tools
agent.add_tool(tool)
assert len(agent._tools) == 1
agent.remove_all_tools()
assert len(agent._tools) == 0

# Test idempotency of remove_all_tools
agent.remove_all_tools()
assert len(agent._tools) == 0


def test_callable_tool_addition():
model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="")
agent = AssistantAgent(name="test_assistant", model_client=model_client)

# Test adding a callable directly
def documented_tool() -> str:
"""This is a documented tool"""
return "result"

agent.add_tool(documented_tool)
assert len(agent._tools) == 1
assert agent._tools[0].description == "This is a documented tool"

# Test adding async callable
async def async_tool() -> str:
return "async result"

agent.add_tool(async_tool)
assert len(agent._tools) == 2


def test_invalid_tool_addition():
model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="")
agent = AssistantAgent(name="test_assistant", model_client=model_client)

# Test adding invalid tool type
with pytest.raises(ValueError, match="Unsupported tool type"):
agent.add_tool("not a tool")

0 comments on commit 6f3b0bb

Please sign in to comment.