-
Notifications
You must be signed in to change notification settings - Fork 65
/
workflow.py
357 lines (284 loc) · 11.9 KB
/
workflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from llama_index.core.llms import ChatMessage, LLM
from llama_index.core.program.function_program import get_function_tool
from llama_index.core.tools import (
BaseTool,
ToolSelection,
)
from llama_index.core.workflow import (
Event,
StartEvent,
StopEvent,
Workflow,
step,
Context,
)
from llama_index.core.workflow.events import InputRequiredEvent, HumanResponseEvent
from llama_index.llms.openai import OpenAI
from utils import FunctionToolWithContext
# ---- Pydantic models for config/llm prediction ----
class AgentConfig(BaseModel):
    """Used to configure an agent."""

    # allow non-pydantic field types such as BaseTool instances
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str  # unique name; used as the key in the workflow's agent registry
    description: str  # shown to the orchestrator when it picks an agent
    system_prompt: str | None = None  # system message prepended to the chat
    tools: list[BaseTool] | None = None  # tools this agent may call
    # names of tools that must be approved by a human before execution
    tools_requiring_human_confirmation: list[str] = Field(default_factory=list)
class TransferToAgent(BaseModel):
    """Used to transfer the user to a specific agent."""

    # must match an AgentConfig.name; converted into a tool for the orchestrator
    agent_name: str
class RequestTransfer(BaseModel):
    """Used to signal that either you don't have the tools to complete the task, or you've finished your task and want to transfer to another agent."""

    # no fields: the tool call itself is the signal
    pass
# ---- Events used to orchestrate the workflow ----
class ActiveSpeakerEvent(Event):
    """Routes control to the currently active sub-agent (read from context)."""

    pass
class OrchestratorEvent(Event):
    """Routes control to the orchestrator to pick the next active agent."""

    pass
class ToolCallEvent(Event):
    """Requests execution of a single tool call."""

    tool_call: ToolSelection  # name, id, and kwargs chosen by the LLM
    tools: list[BaseTool]  # tool set to resolve the call against
class ToolCallResultEvent(Event):
    """Carries the tool-role chat message produced by one tool call."""

    chat_message: ChatMessage
class ToolRequestEvent(InputRequiredEvent):
    """Streamed to the user to request human approval of a gated tool call."""

    tool_name: str  # name of the tool awaiting approval
    tool_id: str  # echoed back in ToolApprovedEvent to match the call
    tool_kwargs: dict  # arguments the agent wants to call the tool with
class ToolApprovedEvent(HumanResponseEvent):
    """Human's answer to a ToolRequestEvent."""

    tool_name: str  # mirrored from the request
    tool_id: str  # mirrored from the request; matches the pending call
    tool_kwargs: dict  # mirrored from the request
    approved: bool  # True -> execute the tool; False -> reject it
    response: str | None = None  # optional rejection reason fed back to the LLM
class ProgressEvent(Event):
    """Streamed status update for the user (transfers, tool results, etc.)."""

    msg: str
# ---- Workflow ----
# Template for the orchestrator's system prompt.  Expects two placeholders:
# {agent_context_str} (one "name: description" line per agent) and
# {user_state_str} ("key: value" lines of the current user state).
# NOTE: typos fixed vs. the original ("on orchestration" -> "an orchestration",
# "Here the the" -> "Here are the") -- cleaner prompts help the LLM.
DEFAULT_ORCHESTRATOR_PROMPT = (
    "You are an orchestration agent.\n"
    "Your job is to decide which agent to run based on the current state of the user and what they've asked to do.\n"
    "You do not need to figure out dependencies between agents; the agents will handle that themselves.\n"
    "Here are the agents you can choose from:\n{agent_context_str}\n\n"
    "Here is the current user state:\n{user_state_str}\n\n"
    "Please assist the user and transfer them as needed."
)

# Tool-role message used when a human rejects a tool call without a reason.
DEFAULT_TOOL_REJECT_STR = "The tool call was not approved, likely due to a mistake or preconditions not being met."
class ConciergeAgent(Workflow):
    """Multi-agent "concierge" workflow.

    An orchestrator LLM routes the user to one of several configured
    sub-agents.  Each sub-agent chats with the user, may call its own tools
    (optionally gated behind human approval), and can hand control back to
    the orchestrator via the RequestTransfer tool.
    """

    def __init__(
        self,
        orchestrator_prompt: str | None = None,
        default_tool_reject_str: str | None = None,
        **kwargs: Any,
    ):
        """Store prompt overrides and forward remaining kwargs to Workflow.

        Args:
            orchestrator_prompt: template with {agent_context_str} and
                {user_state_str} placeholders; defaults to
                DEFAULT_ORCHESTRATOR_PROMPT.
            default_tool_reject_str: tool message used when a human rejects
                a tool call without providing a reason.
        """
        super().__init__(**kwargs)
        self.orchestrator_prompt = orchestrator_prompt or DEFAULT_ORCHESTRATOR_PROMPT
        self.default_tool_reject_str = (
            default_tool_reject_str or DEFAULT_TOOL_REJECT_STR
        )

    @step
    async def setup(
        self, ctx: Context, ev: StartEvent
    ) -> ActiveSpeakerEvent | OrchestratorEvent:
        """Sets up the workflow, validates inputs, and stores them in the context."""
        active_speaker = await ctx.get("active_speaker", default="")
        user_msg = ev.get("user_msg")
        agent_configs = ev.get("agent_configs", default=[])
        llm: LLM = ev.get("llm", default=OpenAI(model="gpt-4o", temperature=0.3))
        chat_history = ev.get("chat_history", default=[])
        initial_state = ev.get("initial_state", default={})
        if (
            user_msg is None
            or agent_configs is None
            or llm is None
            or chat_history is None
        ):
            raise ValueError(
                "User message, agent configs, llm, and chat_history are required!"
            )

        if not llm.metadata.is_function_calling_model:
            raise ValueError("LLM must be a function calling model!")

        # store the agent configs in the context, keyed by agent name
        agent_configs_dict = {ac.name: ac for ac in agent_configs}
        await ctx.set("agent_configs", agent_configs_dict)
        await ctx.set("llm", llm)

        chat_history.append(ChatMessage(role="user", content=user_msg))
        await ctx.set("chat_history", chat_history)
        await ctx.set("user_state", initial_state)

        # if there is an active speaker, forward the user to them directly
        if active_speaker:
            return ActiveSpeakerEvent()

        # otherwise, the orchestrator decides who the next active speaker is
        return OrchestratorEvent(user_msg=user_msg)

    @step
    async def speak_with_sub_agent(
        self, ctx: Context, ev: ActiveSpeakerEvent
    ) -> ToolCallEvent | ToolRequestEvent | OrchestratorEvent | StopEvent | None:
        """Speaks with the active sub-agent and handles tool calls (if any).

        Returns None after dispatching tool-call events; their results are
        gathered later by aggregate_tool_results.
        NOTE: return annotation fixed to include OrchestratorEvent/None,
        both of which this step actually returns -- the workflow wires and
        validates steps from these annotations.
        """
        # Setup the agent for the active speaker
        active_speaker = await ctx.get("active_speaker")

        agent_config: AgentConfig = (await ctx.get("agent_configs"))[active_speaker]
        chat_history = await ctx.get("chat_history")
        llm = await ctx.get("llm")

        user_state = await ctx.get("user_state")
        user_state_str = "\n".join([f"{k}: {v}" for k, v in user_state.items()])
        # BUGFIX: system_prompt is Optional in AgentConfig -- tolerate None
        system_prompt = (
            (agent_config.system_prompt or "").strip()
            + f"\n\nHere is the current user state:\n{user_state_str}"
        )

        llm_input = [ChatMessage(role="system", content=system_prompt)] + chat_history

        # inject the request-transfer tool into the agent's tool list
        # BUGFIX: tools is Optional in AgentConfig -- tolerate None
        agent_tools = agent_config.tools or []
        tools = [get_function_tool(RequestTransfer)] + agent_tools

        response = await llm.achat_with_tools(tools, chat_history=llm_input)

        tool_calls: list[ToolSelection] = llm.get_tool_calls_from_response(
            response, error_on_no_tool_call=False
        )
        if len(tool_calls) == 0:
            # plain text answer -- end the run and surface the response
            chat_history.append(response.message)
            await ctx.set("chat_history", chat_history)
            return StopEvent(
                result={
                    "response": response.message.content,
                    "chat_history": chat_history,
                }
            )

        await ctx.set("num_tool_calls", len(tool_calls))

        for tool_call in tool_calls:
            if tool_call.tool_name == "RequestTransfer":
                # agent gave up control; clear the speaker and re-orchestrate
                await ctx.set("active_speaker", None)
                ctx.write_event_to_stream(
                    ProgressEvent(msg="Agent is requesting a transfer. Please hold.")
                )
                return OrchestratorEvent()
            elif tool_call.tool_name in agent_config.tools_requiring_human_confirmation:
                # gated tool: ask the human for approval via the event stream
                ctx.write_event_to_stream(
                    ToolRequestEvent(
                        prefix=f"Tool {tool_call.tool_name} requires human approval.",
                        tool_name=tool_call.tool_name,
                        tool_kwargs=tool_call.tool_kwargs,
                        tool_id=tool_call.tool_id,
                    )
                )
            else:
                # safe tool: execute it concurrently via handle_tool_call
                ctx.send_event(
                    ToolCallEvent(tool_call=tool_call, tools=agent_tools)
                )

        chat_history.append(response.message)
        await ctx.set("chat_history", chat_history)

    @step
    async def handle_tool_approval(
        self, ctx: Context, ev: ToolApprovedEvent
    ) -> ToolCallEvent | ToolCallResultEvent:
        """Handles the approval or rejection of a tool call."""
        if ev.approved:
            # approved: replay the call through the normal tool-call path
            active_speaker = await ctx.get("active_speaker")
            agent_config = (await ctx.get("agent_configs"))[active_speaker]
            return ToolCallEvent(
                tools=agent_config.tools,
                tool_call=ToolSelection(
                    tool_id=ev.tool_id,
                    tool_name=ev.tool_name,
                    tool_kwargs=ev.tool_kwargs,
                ),
            )
        else:
            # rejected: synthesize a tool message so the LLM sees the outcome
            return ToolCallResultEvent(
                chat_message=ChatMessage(
                    role="tool",
                    content=ev.response or self.default_tool_reject_str,
                )
            )

    @step(num_workers=4)
    async def handle_tool_call(
        self, ctx: Context, ev: ToolCallEvent
    ) -> ToolCallResultEvent:
        """Handles the execution of a tool call.

        NOTE: return annotation fixed (was ActiveSpeakerEvent) -- this step
        emits ToolCallResultEvent, which aggregate_tool_results consumes.
        """
        tool_call = ev.tool_call
        tools_by_name = {tool.metadata.get_name(): tool for tool in ev.tools}

        tool = tools_by_name.get(tool_call.tool_name)
        additional_kwargs = {
            "tool_call_id": tool_call.tool_id,
            # BUGFIX: original dereferenced `tool.metadata` before the
            # missing-tool check, raising AttributeError on unknown tools;
            # fall back to the requested name instead
            "name": tool.metadata.get_name() if tool else tool_call.tool_name,
        }
        if tool is None:
            # BUGFIX: original fell through into the try block and called
            # `tool.acall` on None, clobbering this message with an error
            tool_msg = ChatMessage(
                role="tool",
                content=f"Tool {tool_call.tool_name} does not exist",
                additional_kwargs=additional_kwargs,
            )
        else:
            try:
                # context-aware tools get the workflow Context as first arg
                if isinstance(tool, FunctionToolWithContext):
                    tool_output = await tool.acall(ctx, **tool_call.tool_kwargs)
                else:
                    tool_output = await tool.acall(**tool_call.tool_kwargs)

                tool_msg = ChatMessage(
                    role="tool",
                    content=tool_output.content,
                    additional_kwargs=additional_kwargs,
                )
            except Exception as e:
                # surface the failure to the LLM rather than crashing the run
                tool_msg = ChatMessage(
                    role="tool",
                    content=f"Encountered error in tool call: {e}",
                    additional_kwargs=additional_kwargs,
                )

        ctx.write_event_to_stream(
            ProgressEvent(
                msg=f"Tool {tool_call.tool_name} called with {tool_call.tool_kwargs} returned {tool_msg.content}"
            )
        )

        return ToolCallResultEvent(chat_message=tool_msg)

    @step
    async def aggregate_tool_results(
        self, ctx: Context, ev: ToolCallResultEvent
    ) -> ActiveSpeakerEvent | None:
        """Collects the results of all tool calls and updates the chat history."""
        num_tool_calls = await ctx.get("num_tool_calls")
        results = ctx.collect_events(ev, [ToolCallResultEvent] * num_tool_calls)
        if not results:
            # still waiting on outstanding tool results
            return None

        chat_history = await ctx.get("chat_history")
        for result in results:
            chat_history.append(result.chat_message)
        await ctx.set("chat_history", chat_history)

        # hand the tool results back to the active sub-agent
        return ActiveSpeakerEvent()

    @step
    async def orchestrator(
        self, ctx: Context, ev: OrchestratorEvent
    ) -> ActiveSpeakerEvent | StopEvent:
        """Decides which agent to run next, if any."""
        agent_configs = await ctx.get("agent_configs")
        chat_history = await ctx.get("chat_history")

        # describe each available agent for the orchestrator prompt
        agent_context_str = ""
        for agent_name, agent_config in agent_configs.items():
            agent_context_str += f"{agent_name}: {agent_config.description}\n"

        user_state = await ctx.get("user_state")
        user_state_str = "\n".join([f"{k}: {v}" for k, v in user_state.items()])
        system_prompt = self.orchestrator_prompt.format(
            agent_context_str=agent_context_str, user_state_str=user_state_str
        )

        llm_input = [ChatMessage(role="system", content=system_prompt)] + chat_history
        llm = await ctx.get("llm")

        # convert the TransferToAgent pydantic model to a tool
        tools = [get_function_tool(TransferToAgent)]

        response = await llm.achat_with_tools(tools, chat_history=llm_input)
        tool_calls = llm.get_tool_calls_from_response(
            response, error_on_no_tool_call=False
        )

        # if no tool calls were made, the orchestrator probably needs more information
        if len(tool_calls) == 0:
            chat_history.append(response.message)
            # keep the stored history consistent with the returned one
            # (matches speak_with_sub_agent's behavior)
            await ctx.set("chat_history", chat_history)
            return StopEvent(
                result={
                    "response": response.message.content,
                    "chat_history": chat_history,
                }
            )

        # only the first transfer request is honored
        tool_call = tool_calls[0]
        selected_agent = tool_call.tool_kwargs["agent_name"]
        await ctx.set("active_speaker", selected_agent)

        ctx.write_event_to_stream(
            ProgressEvent(msg=f"Transferring to agent {selected_agent}")
        )

        return ActiveSpeakerEvent()