Skip to content

Commit

Permalink
Add the add_tool(), remove_tool() and remove_all_tools() method…
Browse files Browse the repository at this point in the history
…s for `AssistantAgent`
  • Loading branch information
Jean-Marc Le Roux committed Dec 16, 2024
1 parent 7eaffa8 commit 14e4dc1
Showing 1 changed file with 44 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,24 +231,10 @@ def __init__(
else:
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
if tools is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
for tool in tools:
if isinstance(tool, Tool):
self._tools.append(tool)
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
self._tools.append(FunctionTool(tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
tool_names = [tool.name for tool in self._tools]
if len(tool_names) != len(set(tool_names)):
raise ValueError(f"Tool names must be unique: {tool_names}")
self._model_context: List[LLMMessage] = []
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
# Handoff tools.
self._handoff_tools: List[Tool] = []
self._handoffs: Dict[str, HandoffBase] = {}
Expand All @@ -258,24 +244,54 @@ def __init__(
for handoff in handoffs:
if isinstance(handoff, str):
handoff = HandoffBase(target=handoff)
if handoff.name in self._handoffs:
raise ValueError(f"Handoff name {handoff.name} already exists.")
if isinstance(handoff, HandoffBase):
self._handoff_tools.append(handoff.handoff_tool)
self._handoffs[handoff.name] = handoff
else:
raise ValueError(f"Unsupported handoff type: {type(handoff)}")
# Check if handoff tool names are unique.
handoff_tool_names = [tool.name for tool in self._handoff_tools]
if len(handoff_tool_names) != len(set(handoff_tool_names)):
raise ValueError(f"Handoff names must be unique: {handoff_tool_names}")
if tools is not None:
for tool in tools:
self.add_tool(tool)

def add_tool(self, tool: Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]) -> None:
new_tool = None
if self._model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
if isinstance(tool, Tool):
new_tool = tool
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
new_tool = FunctionTool(tool, description=description)
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
if new_tool.name in self._tools:
raise ValueError(f"Tool names must be unique: {new_tool.name}")
# Check if handoff tool names not in tool names.
if any(name in tool_names for name in handoff_tool_names):
handoff_tool_names = [handoff.name for handoff in self._handoffs.values()]
if new_tool.name in handoff_tool_names:
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; "
f"tool names: {new_tool.name}"
)
self._model_context: List[LLMMessage] = []
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
self._tools.append(new_tool)

def remove_all_tools(self) -> None:
"""Remove all tools."""
self._tools.clear()

def remove_tool(self, tool_name: str) -> None:
"""Remove tools by name."""
for tool in self._tools:
if tool.name == tool_name:
self._tools.remove(tool)
return
raise ValueError(f"Tool {tool_name} not found.")

@property
def produced_message_types(self) -> List[type[ChatMessage]]:
Expand Down

0 comments on commit 14e4dc1

Please sign in to comment.