Skip to content

Commit

Permalink
Add the add_tools(), remove_tools() and remove_all_tools() meth…
Browse files Browse the repository at this point in the history
…ods for `AssistantAgent`
  • Loading branch information
Jean-Marc Le Roux committed Dec 4, 2024
1 parent 8b05e03 commit 8c25461
Showing 1 changed file with 42 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,24 +182,7 @@ def __init__(
else:
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
if tools is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
for tool in tools:
if isinstance(tool, Tool):
self._tools.append(tool)
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
self._tools.append(FunctionTool(tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
tool_names = [tool.name for tool in self._tools]
if len(tool_names) != len(set(tool_names)):
raise ValueError(f"Tool names must be unique: {tool_names}")
self._model_context: List[LLMMessage] = []
# Handoff tools.
self._handoff_tools: List[Tool] = []
self._handoffs: Dict[str, HandoffBase] = {}
Expand All @@ -214,6 +197,27 @@ def __init__(
self._handoffs[handoff.name] = handoff
else:
raise ValueError(f"Unsupported handoff type: {type(handoff)}")
if tools is not None:
self.add_tools(tools)

def add_tools(self, tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]) -> None:
if self._model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
for tool in tools:
if isinstance(tool, Tool):
self._tools.append(tool)
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
self._tools.append(FunctionTool(tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
tool_names = [tool.name for tool in self._tools]
if len(tool_names) != len(set(tool_names)):
raise ValueError(f"Tool names must be unique: {tool_names}")
# Check if handoff tool names are unique.
handoff_tool_names = [tool.name for tool in self._handoff_tools]
if len(handoff_tool_names) != len(set(handoff_tool_names)):
Expand All @@ -223,7 +227,26 @@ def __init__(
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
self._model_context: List[LLMMessage] = []

def remove_all_tools(self) -> None:
"""Remove all tools."""
self._tools = []

def remove_tools(self, tool_names: List[str]) -> None:
"""Remove tools by name."""
for name in tool_names:
for tool in self._tools:
if tool.name == name:
self._tools.remove(tool)
break
for tool in self._handoff_tools:
if tool.name == name:
self._handoff_tools.remove(tool)
break
for handoff in self._handoffs.values():
if handoff.name == name:
self._handoffs.pop(handoff.name)
break

@property
def produced_message_types(self) -> List[type[ChatMessage]]:
Expand Down

0 comments on commit 8c25461

Please sign in to comment.