diff --git a/lagent/agents/aggregator/tool_aggregator.py b/lagent/agents/aggregator/tool_aggregator.py index 58e492c2..0ea47379 100644 --- a/lagent/agents/aggregator/tool_aggregator.py +++ b/lagent/agents/aggregator/tool_aggregator.py @@ -2,8 +2,7 @@ from lagent.agents.aggregator.default_aggregator import DefaultAggregator from lagent.memory.base_memory import Memory -from lagent.prompts.parsers.tool_parser import MixedToolParser, ToolParser -from lagent.schema import AgentStatusCode +from lagent.prompts.parsers.tool_parser import MixedToolParser, ToolParser, ToolStatusCode class InternLMToolAggregator(DefaultAggregator): @@ -84,7 +83,7 @@ def aggregate(self, if message.sender == name: if isinstance(message.formatted, dict): parsed = message.formatted - if parsed['status'] == AgentStatusCode.SESSION_INVALID_ARG: + if parsed['status'] == ToolStatusCode.PARSING_ERROR: continue _message.append( dict( diff --git a/lagent/agents/stream.py b/lagent/agents/stream.py index f7a86074..cd41fc0a 100644 --- a/lagent/agents/stream.py +++ b/lagent/agents/stream.py @@ -9,8 +9,8 @@ from lagent.hooks import InternLMActionProcessor from lagent.llms import BaseLLM from lagent.memory import Memory -from lagent.prompts.parsers import InterpreterParser, MixedToolParser, PluginParser -from lagent.schema import AgentMessage, AgentStatusCode +from lagent.prompts.parsers import InterpreterParser, MixedToolParser, PluginParser, ToolStatusCode +from lagent.schema import AgentMessage from lagent.utils import create_object API_PREFIX = ( @@ -81,7 +81,7 @@ def __init__( action_hooks: List = [dict(type=InternLMActionProcessor)], finish_condition: Callable[ [AgentMessage], - bool] = lambda m: m.formatted['status'] == AgentStatusCode.END, + bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, max_turn: int = 4, **kwargs, ): @@ -165,7 +165,7 @@ def __init__( action_hooks: List = [dict(type=InternLMActionProcessor)], finish_condition: Callable[ [AgentMessage], - bool] = lambda m: m.formatted['status'] == AgentStatusCode.END, + bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, max_turn: int = 6, **kwargs, ): @@ -205,7 +205,7 @@ def __init__( action_hooks: List = [dict(type=InternLMActionProcessor)], finish_condition: Callable[ [AgentMessage], - bool] = lambda m: m.formatted['status'] == AgentStatusCode.END, + bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, max_turn: int = 4, **kwargs, ): @@ -289,7 +289,7 @@ def __init__( action_hooks: List = [dict(type=InternLMActionProcessor)], finish_condition: Callable[ [AgentMessage], - bool] = lambda m: m.formatted['status'] == AgentStatusCode.END, + bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, max_turn: int = 6, **kwargs, ): diff --git a/lagent/prompts/parsers/__init__.py b/lagent/prompts/parsers/__init__.py index de99b797..d9b2ce25 100644 --- a/lagent/prompts/parsers/__init__.py +++ b/lagent/prompts/parsers/__init__.py @@ -1,9 +1,9 @@ from .custom_parser import CustomFormatParser from .json_parser import JSONParser from .str_parser import StrParser -from .tool_parser import InterpreterParser, MixedToolParser, PluginParser, ToolParser +from .tool_parser import InterpreterParser, MixedToolParser, PluginParser, ToolParser, ToolStatusCode __all__ = [ 'CustomFormatParser', 'JSONParser', 'StrParser', 'ToolParser', - 'InterpreterParser', 'PluginParser', 'MixedToolParser' + 'InterpreterParser', 'PluginParser', 'MixedToolParser', 'ToolStatusCode' ] diff --git a/lagent/prompts/parsers/tool_parser.py b/lagent/prompts/parsers/tool_parser.py index 3836dcb3..53433127 100644 --- a/lagent/prompts/parsers/tool_parser.py +++ b/lagent/prompts/parsers/tool_parser.py @@ -1,10 +1,10 @@ import json +from enum import IntEnum # import re from typing import Any, Callable, List, Optional from lagent.prompts.parsers import StrParser -from lagent.schema import AgentStatusCode from lagent.utils import create_object, load_class_from_string @@ -15,6 +15,12 @@ def default_plugin_validate(plugin: str): return json.loads(plugin) +class ToolStatusCode(IntEnum): + NO_TOOL = 0 + VALID_TOOL = 1 + PARSING_ERROR = -1 + + class ToolParser(StrParser): def __init__(self, @@ -34,28 +40,20 @@ def __init__(self, validate, str) else validate def parse_response(self, data: str) -> dict: - # match = self.pattern.search(data) - # if not match: - # return dict( - # tool_type=None, - # thought=data, - # action=None, - # status=AgentStatusCode.END) - # thought, action = match.group(1), match.group(2).strip() if self.format_field['begin'] not in data: return dict( tool_type=None, thought=data, action=None, - status=AgentStatusCode.END) + status=ToolStatusCode.NO_TOOL) thought, action, *_ = data.split(self.format_field["begin"]) action = action.split(self.format_field['end'])[0] - status = AgentStatusCode.STREAM_ING + status = ToolStatusCode.VALID_TOOL if self.validate: try: action = self.validate(action) except Exception: - status = AgentStatusCode.SESSION_INVALID_ARG + status = ToolStatusCode.PARSING_ERROR return dict( tool_type=self.tool_type, thought=thought, @@ -131,7 +129,7 @@ def parse_response(self, data: str) -> dict: tool_type=None, thought=data, action=None, - status=AgentStatusCode.END) + status=ToolStatusCode.NO_TOOL) for name, parser in self.parsers.items(): res = parser.parse_response(data) if res['tool_type'] == name: diff --git a/lagent/schema.py b/lagent/schema.py index cb16c9c4..668846fb 100644 --- a/lagent/schema.py +++ b/lagent/schema.py @@ -88,7 +88,6 @@ class AgentStatusCode(IntEnum): class AgentMessage(BaseModel): - content: Any sender: str = 'user' formatted: Optional[Any] = None