From 35940149b139a39d07c54290c0d9aae073e5ddb9 Mon Sep 17 00:00:00 2001 From: woorui Date: Fri, 11 Oct 2024 03:00:55 +0800 Subject: [PATCH] feat: caller supports setting system prompt and its opration --- pkg/bridge/ai/caller.go | 29 +++++++++++++++++---- pkg/bridge/ai/caller_test.go | 11 +++++--- pkg/bridge/ai/service.go | 47 ++++++++++++++++++++++------------- pkg/bridge/ai/service_test.go | 19 ++++++++------ 4 files changed, 73 insertions(+), 33 deletions(-) diff --git a/pkg/bridge/ai/caller.go b/pkg/bridge/ai/caller.go index 44c5b1256..19a317910 100644 --- a/pkg/bridge/ai/caller.go +++ b/pkg/bridge/ai/caller.go @@ -108,17 +108,36 @@ func reduceFunc(messages chan ReduceMessage, logger *slog.Logger) core.AsyncHand } } +type promptOperation struct { + prompt string + operation SystemPromptOp +} + +// SystemPromptOp defines the operation of system prompt +type SystemPromptOp int + +const ( + SystemPromptOpOverwrite SystemPromptOp = 0 + SystemPromptOpDisabled SystemPromptOp = 1 + SystemPromptOpPrefix SystemPromptOp = 2 +) + // SetSystemPrompt sets the system prompt -func (c *Caller) SetSystemPrompt(prompt string) { - c.systemPrompt.Store(prompt) +func (c *Caller) SetSystemPrompt(prompt string, op SystemPromptOp) { + p := &promptOperation{ + prompt: prompt, + operation: op, + } + c.systemPrompt.Store(p) } // SetSystemPrompt gets the system prompt -func (c *Caller) GetSystemPrompt() string { +func (c *Caller) GetSystemPrompt() (prompt string, op SystemPromptOp) { if v := c.systemPrompt.Load(); v != nil { - return v.(string) + pop := v.(*promptOperation) + return pop.prompt, pop.operation } - return "" + return "", SystemPromptOpOverwrite } // Metadata returns the metadata of caller. diff --git a/pkg/bridge/ai/caller_test.go b/pkg/bridge/ai/caller_test.go index 655311a97..8abfd3056 100644 --- a/pkg/bridge/ai/caller_test.go +++ b/pkg/bridge/ai/caller_test.go @@ -22,9 +22,14 @@ func TestCaller(t *testing.T) { assert.Equal(t, md, caller.Metadata()) - sysPrompt := "hello system prompt" - caller.SetSystemPrompt(sysPrompt) - assert.Equal(t, sysPrompt, caller.GetSystemPrompt()) + var ( + prompt = "hello system prompt" + op = SystemPromptOpPrefix + ) + caller.SetSystemPrompt(prompt, op) + gotPrompt, gotOp := caller.GetSystemPrompt() + assert.Equal(t, prompt, gotPrompt) + assert.Equal(t, op, gotOp) } type testComponentCreator struct { diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go index 4cef9e8ab..b2b6639f5 100644 --- a/pkg/bridge/ai/service.go +++ b/pkg/bridge/ai/service.go @@ -240,8 +240,9 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl // 2. add those tools to request req = srv.addToolsToRequest(req, tagTools) - // 3. over write system prompt to request - req = srv.overWriteSystemPrompt(req, caller.GetSystemPrompt()) + // 3. operate system prompt to request + prompt, op := caller.GetSystemPrompt() + req = srv.opSystemPrompt(req, prompt, op) var ( promptUsage = 0 @@ -537,32 +538,44 @@ func (srv *Service) addToolsToRequest(req openai.ChatCompletionRequest, tagTools return req } -func (srv *Service) overWriteSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string) openai.ChatCompletionRequest { - // do nothing if system prompt is empty - if sysPrompt == "" { +func (srv *Service) opSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string, op SystemPromptOp) openai.ChatCompletionRequest { + if op == SystemPromptOpDisabled { return req } - // over write system prompt - isOverWrite := false - for i, msg := range req.Messages { + var ( + systemCount = 0 + messages = []openai.ChatCompletionMessage{} + ) + for _, msg := range req.Messages { if msg.Role != "system" { + messages = append(messages, msg) continue } - req.Messages[i] = openai.ChatCompletionMessage{ - Role: msg.Role, - Content: sysPrompt, + if systemCount == 0 { + content := "" + switch op { + case SystemPromptOpPrefix: + content = sysPrompt + "\n" + msg.Content + case SystemPromptOpOverwrite: + content = sysPrompt + } + messages = append(messages, openai.ChatCompletionMessage{ + Role: msg.Role, + Content: content, + }) } - isOverWrite = true + systemCount++ } - // append system prompt - if !isOverWrite { - req.Messages = append(req.Messages, openai.ChatCompletionMessage{ + if systemCount == 0 { + message := openai.ChatCompletionMessage{ Role: "system", Content: sysPrompt, - }) + } + messages = append([]openai.ChatCompletionMessage{message}, req.Messages...) } + req.Messages = messages - srv.logger.Debug(" #1 first call after overwrite", "request", fmt.Sprintf("%+v", req)) + srv.logger.Debug(" #1 first call after operating", "request", fmt.Sprintf("%+v", req)) return req } diff --git a/pkg/bridge/ai/service_test.go b/pkg/bridge/ai/service_test.go index 474593a31..be2c5e7e1 100644 --- a/pkg/bridge/ai/service_test.go +++ b/pkg/bridge/ai/service_test.go @@ -110,7 +110,7 @@ func TestServiceInvoke(t *testing.T) { caller, err := service.LoadOrCreateCaller(&http.Request{}) assert.NoError(t, err) - caller.SetSystemPrompt(tt.args.systemPrompt) + caller.SetSystemPrompt(tt.args.systemPrompt, SystemPromptOpOverwrite) resp, err := service.GetInvoke(context.TODO(), tt.args.userInstruction, tt.args.baseSystemMessage, "transID", caller, true) assert.NoError(t, err) @@ -151,15 +151,15 @@ func TestServiceChatCompletion(t *testing.T) { wantRequest: []openai.ChatCompletionRequest{ { Messages: []openai.ChatCompletionMessage{ - {Role: "user", Content: "How is the weather today in Boston, MA?"}, {Role: "system", Content: "this is a system prompt"}, + {Role: "user", Content: "How is the weather today in Boston, MA?"}, }, Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, }, { Messages: []openai.ChatCompletionMessage{ - {Role: "user", Content: "How is the weather today in Boston, MA?"}, {Role: "system", Content: "this is a system prompt"}, + {Role: "user", Content: "How is the weather today in Boston, MA?"}, {Role: "assistant", ToolCalls: []openai.ToolCall{{ID: "call_abc123", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\n\"location\": \"Boston, MA\"\n}"}}}}, {Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_abc123"}, }, @@ -184,8 +184,8 @@ func TestServiceChatCompletion(t *testing.T) { wantRequest: []openai.ChatCompletionRequest{ { Messages: []openai.ChatCompletionMessage{ - {Role: "user", Content: "How are you"}, {Role: "system", Content: "You are an assistant."}, + {Role: "user", Content: "How are you"}, }, Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, }, @@ -211,16 +211,16 @@ func TestServiceChatCompletion(t *testing.T) { { Stream: true, Messages: []openai.ChatCompletionMessage{ - {Role: "user", Content: "How is the weather today in Boston, MA?"}, {Role: "system", Content: "You are a weather assistant"}, + {Role: "user", Content: "How is the weather today in Boston, MA?"}, }, Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, }, { Stream: true, Messages: []openai.ChatCompletionMessage{ - {Role: "user", Content: "How is the weather today in Boston, MA?"}, {Role: "system", Content: "You are a weather assistant"}, + {Role: "user", Content: "How is the weather today in Boston, MA?"}, {Role: "assistant", ToolCalls: []openai.ToolCall{{Index: toInt(0), ID: "call_9ctHOJqO3bYrpm2A6S7nHd5k", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\"location\":\"Boston, MA\"}"}}}}, {Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_9ctHOJqO3bYrpm2A6S7nHd5k"}, }, @@ -247,8 +247,8 @@ func TestServiceChatCompletion(t *testing.T) { { Stream: true, Messages: []openai.ChatCompletionMessage{ - {Role: "user", Content: "How is the weather today in Boston, MA?"}, {Role: "system", Content: "You are a weather assistant"}, + {Role: "user", Content: "How is the weather today in Boston, MA?"}, }, Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, }, @@ -279,7 +279,7 @@ func TestServiceChatCompletion(t *testing.T) { caller, err := service.LoadOrCreateCaller(&http.Request{}) assert.NoError(t, err) - caller.SetSystemPrompt(tt.args.systemPrompt) + caller.SetSystemPrompt(tt.args.systemPrompt, SystemPromptOpOverwrite) w := httptest.NewRecorder() err = service.GetChatCompletions(context.TODO(), tt.args.request, "transID", caller, w) @@ -501,3 +501,6 @@ var toolCallResp = `{ "total_tokens": 99 } }` + +// []openai.ChatCompletionRequest{openai.ChatCompletionRequest{Model:"", Messages:[]openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role:"user", Content:"How is the weather today in Boston, MA?", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:""}, openai.ChatCompletionMessage{Role:"system", Content:"this is a system prompt", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:""}}, MaxTokens:0, Temperature:0, TopP:0, N:0, Stream:false, Stop:[]string(nil), PresencePenalty:0, ResponseFormat:(*openai.ChatCompletionResponseFormat)(nil), Seed:(*int)(nil), FrequencyPenalty:0, LogitBias:map[string]int(nil), LogProbs:false, TopLogProbs:0, User:"", Functions:[]openai.FunctionDefinition(nil), FunctionCall:interface {}(nil), Tools:[]openai.Tool{openai.Tool{Type:"function", Function:(*openai.FunctionDefinition)(0xc00004f9c0)}}, ToolChoice:interface {}(nil), StreamOptions:(*openai.StreamOptions)(nil), ParallelToolCalls:interface {}(nil)}, openai.ChatCompletionRequest{Model:"", Messages:[]openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role:"user", Content:"How is the weather today in Boston, MA?", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:""}, openai.ChatCompletionMessage{Role:"system", Content:"this is a system prompt", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:""}, openai.ChatCompletionMessage{Role:"assistant", Content:"", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall{openai.ToolCall{Index:(*int)(nil), ID:"call_abc123", Type:"function", Function:openai.FunctionCall{Name:"get_current_weather", Arguments:"{\n\"location\": \"Boston, MA\"\n}"}}}, ToolCallID:""}, openai.ChatCompletionMessage{Role:"tool", Content:"temperature: 31°C", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:"call_abc123"}}, MaxTokens:0, Temperature:0, TopP:0, N:0, Stream:false, Stop:[]string(nil), PresencePenalty:0, ResponseFormat:(*openai.ChatCompletionResponseFormat)(nil), Seed:(*int)(nil), FrequencyPenalty:0, LogitBias:map[string]int(nil), LogProbs:false, TopLogProbs:0, User:"", Functions:[]openai.FunctionDefinition(nil), FunctionCall:interface {}(nil), Tools:[]openai.Tool(nil), ToolChoice:interface {}(nil), StreamOptions:(*openai.StreamOptions)(nil), ParallelToolCalls:interface {}(nil)}} +// []openai.ChatCompletionRequest{openai.ChatCompletionRequest{Model:"", Messages:[]openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role:"user", Content:"How is the weather today in Boston, MA?", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:""}}, MaxTokens:0, Temperature:0, TopP:0, N:0, Stream:false, Stop:[]string(nil), PresencePenalty:0, ResponseFormat:(*openai.ChatCompletionResponseFormat)(nil), Seed:(*int)(nil), FrequencyPenalty:0, LogitBias:map[string]int(nil), LogProbs:false, TopLogProbs:0, User:"", Functions:[]openai.FunctionDefinition(nil), FunctionCall:interface {}(nil), Tools:[]openai.Tool{openai.Tool{Type:"function", Function:(*openai.FunctionDefinition)(0xc00004fbc0)}}, ToolChoice:interface {}(nil), StreamOptions:(*openai.StreamOptions)(nil), ParallelToolCalls:interface {}(nil)}, openai.ChatCompletionRequest{Model:"", Messages:[]openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role:"user", Content:"How is the weather today in Boston, MA?", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:""}, openai.ChatCompletionMessage{Role:"assistant", Content:"", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall{openai.ToolCall{Index:(*int)(nil), ID:"call_abc123", Type:"function", Function:openai.FunctionCall{Name:"get_current_weather", Arguments:"{\n\"location\": \"Boston, MA\"\n}"}}}, ToolCallID:""}, openai.ChatCompletionMessage{Role:"tool", Content:"temperature: 31°C", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:"call_abc123"}}, MaxTokens:0, Temperature:0, TopP:0, N:0, Stream:false, Stop:[]string(nil), PresencePenalty:0, ResponseFormat:(*openai.ChatCompletionResponseFormat)(nil), Seed:(*int)(nil), FrequencyPenalty:0, LogitBias:map[string]int(nil), LogProbs:false, TopLogProbs:0, User:"", Functions:[]openai.FunctionDefinition(nil), FunctionCall:interface {}(nil), Tools:[]openai.Tool(nil), ToolChoice:interface {}(nil), StreamOptions:(*openai.StreamOptions)(nil), ParallelToolCalls:interface {}(nil)}}