Skip to content

Commit

Permalink
feat: caller supports setting system prompt and its operation
Browse files Browse the repository at this point in the history
  • Loading branch information
woorui committed Oct 10, 2024
1 parent 7fec9e9 commit 3594014
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 33 deletions.
29 changes: 24 additions & 5 deletions pkg/bridge/ai/caller.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,36 @@ func reduceFunc(messages chan ReduceMessage, logger *slog.Logger) core.AsyncHand
}
}

// promptOperation bundles a system prompt together with the operation that
// describes how it is applied to a chat completion request. Prompt and
// operation are stored as one unit in Caller.systemPrompt so concurrent
// readers always observe a consistent pair.
type promptOperation struct {
	prompt    string         // the system prompt text
	operation SystemPromptOp // how the prompt is applied to a request
}

// SystemPromptOp defines how a caller-level system prompt is applied to an
// incoming chat completion request.
type SystemPromptOp int

const (
	// SystemPromptOpOverwrite replaces the content of the request's first
	// system message (or prepends a new one if the request has none).
	SystemPromptOpOverwrite SystemPromptOp = iota
	// SystemPromptOpDisabled leaves the request untouched.
	SystemPromptOpDisabled
	// SystemPromptOpPrefix prepends the prompt to the content of the
	// request's first system message (or prepends a new one if none).
	SystemPromptOpPrefix
)

// SetSystemPrompt sets the system prompt
func (c *Caller) SetSystemPrompt(prompt string) {
c.systemPrompt.Store(prompt)
func (c *Caller) SetSystemPrompt(prompt string, op SystemPromptOp) {
p := &promptOperation{
prompt: prompt,
operation: op,
}
c.systemPrompt.Store(p)
}

// SetSystemPrompt gets the system prompt
func (c *Caller) GetSystemPrompt() string {
func (c *Caller) GetSystemPrompt() (prompt string, op SystemPromptOp) {
if v := c.systemPrompt.Load(); v != nil {
return v.(string)
pop := v.(*promptOperation)
return pop.prompt, pop.operation
}
return ""
return "", SystemPromptOpOverwrite
}

// Metadata returns the metadata of caller.
Expand Down
11 changes: 8 additions & 3 deletions pkg/bridge/ai/caller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@ func TestCaller(t *testing.T) {

assert.Equal(t, md, caller.Metadata())

sysPrompt := "hello system prompt"
caller.SetSystemPrompt(sysPrompt)
assert.Equal(t, sysPrompt, caller.GetSystemPrompt())
var (
prompt = "hello system prompt"
op = SystemPromptOpPrefix
)
caller.SetSystemPrompt(prompt, op)
gotPrompt, gotOp := caller.GetSystemPrompt()
assert.Equal(t, prompt, gotPrompt)
assert.Equal(t, op, gotOp)
}

type testComponentCreator struct {
Expand Down
47 changes: 30 additions & 17 deletions pkg/bridge/ai/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,9 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
// 2. add those tools to request
req = srv.addToolsToRequest(req, tagTools)

// 3. over write system prompt to request
req = srv.overWriteSystemPrompt(req, caller.GetSystemPrompt())
// 3. operate system prompt to request
prompt, op := caller.GetSystemPrompt()
req = srv.opSystemPrompt(req, prompt, op)

var (
promptUsage = 0
Expand Down Expand Up @@ -537,32 +538,44 @@ func (srv *Service) addToolsToRequest(req openai.ChatCompletionRequest, tagTools
return req
}

func (srv *Service) overWriteSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string) openai.ChatCompletionRequest {
// do nothing if system prompt is empty
if sysPrompt == "" {
func (srv *Service) opSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string, op SystemPromptOp) openai.ChatCompletionRequest {
if op == SystemPromptOpDisabled {
return req
}
// over write system prompt
isOverWrite := false
for i, msg := range req.Messages {
var (
systemCount = 0
messages = []openai.ChatCompletionMessage{}
)
for _, msg := range req.Messages {
if msg.Role != "system" {
messages = append(messages, msg)
continue
}
req.Messages[i] = openai.ChatCompletionMessage{
Role: msg.Role,
Content: sysPrompt,
if systemCount == 0 {
content := ""
switch op {
case SystemPromptOpPrefix:
content = sysPrompt + "\n" + msg.Content
case SystemPromptOpOverwrite:
content = sysPrompt

Check warning on line 560 in pkg/bridge/ai/service.go

View check run for this annotation

Codecov / codecov/patch

pkg/bridge/ai/service.go#L554-L560

Added lines #L554 - L560 were not covered by tests
}
messages = append(messages, openai.ChatCompletionMessage{
Role: msg.Role,
Content: content,
})

Check warning on line 565 in pkg/bridge/ai/service.go

View check run for this annotation

Codecov / codecov/patch

pkg/bridge/ai/service.go#L562-L565

Added lines #L562 - L565 were not covered by tests
}
isOverWrite = true
systemCount++

Check warning on line 567 in pkg/bridge/ai/service.go

View check run for this annotation

Codecov / codecov/patch

pkg/bridge/ai/service.go#L567

Added line #L567 was not covered by tests
}
// append system prompt
if !isOverWrite {
req.Messages = append(req.Messages, openai.ChatCompletionMessage{
if systemCount == 0 {
message := openai.ChatCompletionMessage{
Role: "system",
Content: sysPrompt,
})
}
messages = append([]openai.ChatCompletionMessage{message}, req.Messages...)
}
req.Messages = messages

srv.logger.Debug(" #1 first call after overwrite", "request", fmt.Sprintf("%+v", req))
srv.logger.Debug(" #1 first call after operating", "request", fmt.Sprintf("%+v", req))

return req
}
Expand Down
19 changes: 11 additions & 8 deletions pkg/bridge/ai/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func TestServiceInvoke(t *testing.T) {
caller, err := service.LoadOrCreateCaller(&http.Request{})
assert.NoError(t, err)

caller.SetSystemPrompt(tt.args.systemPrompt)
caller.SetSystemPrompt(tt.args.systemPrompt, SystemPromptOpOverwrite)

resp, err := service.GetInvoke(context.TODO(), tt.args.userInstruction, tt.args.baseSystemMessage, "transID", caller, true)
assert.NoError(t, err)
Expand Down Expand Up @@ -151,15 +151,15 @@ func TestServiceChatCompletion(t *testing.T) {
wantRequest: []openai.ChatCompletionRequest{
{
Messages: []openai.ChatCompletionMessage{
{Role: "user", Content: "How is the weather today in Boston, MA?"},
{Role: "system", Content: "this is a system prompt"},
{Role: "user", Content: "How is the weather today in Boston, MA?"},
},
Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}},
},
{
Messages: []openai.ChatCompletionMessage{
{Role: "user", Content: "How is the weather today in Boston, MA?"},
{Role: "system", Content: "this is a system prompt"},
{Role: "user", Content: "How is the weather today in Boston, MA?"},
{Role: "assistant", ToolCalls: []openai.ToolCall{{ID: "call_abc123", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\n\"location\": \"Boston, MA\"\n}"}}}},
{Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_abc123"},
},
Expand All @@ -184,8 +184,8 @@ func TestServiceChatCompletion(t *testing.T) {
wantRequest: []openai.ChatCompletionRequest{
{
Messages: []openai.ChatCompletionMessage{
{Role: "user", Content: "How are you"},
{Role: "system", Content: "You are an assistant."},
{Role: "user", Content: "How are you"},
},
Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}},
},
Expand All @@ -211,16 +211,16 @@ func TestServiceChatCompletion(t *testing.T) {
{
Stream: true,
Messages: []openai.ChatCompletionMessage{
{Role: "user", Content: "How is the weather today in Boston, MA?"},
{Role: "system", Content: "You are a weather assistant"},
{Role: "user", Content: "How is the weather today in Boston, MA?"},
},
Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}},
},
{
Stream: true,
Messages: []openai.ChatCompletionMessage{
{Role: "user", Content: "How is the weather today in Boston, MA?"},
{Role: "system", Content: "You are a weather assistant"},
{Role: "user", Content: "How is the weather today in Boston, MA?"},
{Role: "assistant", ToolCalls: []openai.ToolCall{{Index: toInt(0), ID: "call_9ctHOJqO3bYrpm2A6S7nHd5k", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\"location\":\"Boston, MA\"}"}}}},
{Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_9ctHOJqO3bYrpm2A6S7nHd5k"},
},
Expand All @@ -247,8 +247,8 @@ func TestServiceChatCompletion(t *testing.T) {
{
Stream: true,
Messages: []openai.ChatCompletionMessage{
{Role: "user", Content: "How is the weather today in Boston, MA?"},
{Role: "system", Content: "You are a weather assistant"},
{Role: "user", Content: "How is the weather today in Boston, MA?"},
},
Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}},
},
Expand Down Expand Up @@ -279,7 +279,7 @@ func TestServiceChatCompletion(t *testing.T) {
caller, err := service.LoadOrCreateCaller(&http.Request{})
assert.NoError(t, err)

caller.SetSystemPrompt(tt.args.systemPrompt)
caller.SetSystemPrompt(tt.args.systemPrompt, SystemPromptOpOverwrite)

w := httptest.NewRecorder()
err = service.GetChatCompletions(context.TODO(), tt.args.request, "transID", caller, w)
Expand Down Expand Up @@ -501,3 +501,6 @@ var toolCallResp = `{
"total_tokens": 99
}
}`

// []openai.ChatCompletionRequest{openai.ChatCompletionRequest{Model:"", Messages:[]openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role:"user", Content:"How is the weather today in Boston, MA?", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:""}, openai.ChatCompletionMessage{Role:"system", Content:"this is a system prompt", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:""}}, MaxTokens:0, Temperature:0, TopP:0, N:0, Stream:false, Stop:[]string(nil), PresencePenalty:0, ResponseFormat:(*openai.ChatCompletionResponseFormat)(nil), Seed:(*int)(nil), FrequencyPenalty:0, LogitBias:map[string]int(nil), LogProbs:false, TopLogProbs:0, User:"", Functions:[]openai.FunctionDefinition(nil), FunctionCall:interface {}(nil), Tools:[]openai.Tool{openai.Tool{Type:"function", Function:(*openai.FunctionDefinition)(0xc00004f9c0)}}, ToolChoice:interface {}(nil), StreamOptions:(*openai.StreamOptions)(nil), ParallelToolCalls:interface {}(nil)}, openai.ChatCompletionRequest{Model:"", Messages:[]openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role:"user", Content:"How is the weather today in Boston, MA?", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:""}, openai.ChatCompletionMessage{Role:"system", Content:"this is a system prompt", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:""}, openai.ChatCompletionMessage{Role:"assistant", Content:"", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall{openai.ToolCall{Index:(*int)(nil), ID:"call_abc123", Type:"function", Function:openai.FunctionCall{Name:"get_current_weather", 
Arguments:"{\n\"location\": \"Boston, MA\"\n}"}}}, ToolCallID:""}, openai.ChatCompletionMessage{Role:"tool", Content:"temperature: 31°C", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:"call_abc123"}}, MaxTokens:0, Temperature:0, TopP:0, N:0, Stream:false, Stop:[]string(nil), PresencePenalty:0, ResponseFormat:(*openai.ChatCompletionResponseFormat)(nil), Seed:(*int)(nil), FrequencyPenalty:0, LogitBias:map[string]int(nil), LogProbs:false, TopLogProbs:0, User:"", Functions:[]openai.FunctionDefinition(nil), FunctionCall:interface {}(nil), Tools:[]openai.Tool(nil), ToolChoice:interface {}(nil), StreamOptions:(*openai.StreamOptions)(nil), ParallelToolCalls:interface {}(nil)}}
// []openai.ChatCompletionRequest{openai.ChatCompletionRequest{Model:"", Messages:[]openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role:"user", Content:"How is the weather today in Boston, MA?", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:""}}, MaxTokens:0, Temperature:0, TopP:0, N:0, Stream:false, Stop:[]string(nil), PresencePenalty:0, ResponseFormat:(*openai.ChatCompletionResponseFormat)(nil), Seed:(*int)(nil), FrequencyPenalty:0, LogitBias:map[string]int(nil), LogProbs:false, TopLogProbs:0, User:"", Functions:[]openai.FunctionDefinition(nil), FunctionCall:interface {}(nil), Tools:[]openai.Tool{openai.Tool{Type:"function", Function:(*openai.FunctionDefinition)(0xc00004fbc0)}}, ToolChoice:interface {}(nil), StreamOptions:(*openai.StreamOptions)(nil), ParallelToolCalls:interface {}(nil)}, openai.ChatCompletionRequest{Model:"", Messages:[]openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role:"user", Content:"How is the weather today in Boston, MA?", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:""}, openai.ChatCompletionMessage{Role:"assistant", Content:"", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall{openai.ToolCall{Index:(*int)(nil), ID:"call_abc123", Type:"function", Function:openai.FunctionCall{Name:"get_current_weather", Arguments:"{\n\"location\": \"Boston, MA\"\n}"}}}, ToolCallID:""}, openai.ChatCompletionMessage{Role:"tool", Content:"temperature: 31°C", MultiContent:[]openai.ChatMessagePart(nil), Name:"", FunctionCall:(*openai.FunctionCall)(nil), ToolCalls:[]openai.ToolCall(nil), ToolCallID:"call_abc123"}}, MaxTokens:0, Temperature:0, TopP:0, N:0, Stream:false, Stop:[]string(nil), PresencePenalty:0, ResponseFormat:(*openai.ChatCompletionResponseFormat)(nil), Seed:(*int)(nil), 
FrequencyPenalty:0, LogitBias:map[string]int(nil), LogProbs:false, TopLogProbs:0, User:"", Functions:[]openai.FunctionDefinition(nil), FunctionCall:interface {}(nil), Tools:[]openai.Tool(nil), ToolChoice:interface {}(nil), StreamOptions:(*openai.StreamOptions)(nil), ParallelToolCalls:interface {}(nil)}}

0 comments on commit 3594014

Please sign in to comment.