Skip to content

Commit

Permalink
message-context filter put messageGroup into context
Browse files Browse the repository at this point in the history
  • Loading branch information
sfwn committed Nov 1, 2023
1 parent 2d1afd5 commit c055b56
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 12 deletions.
41 changes: 41 additions & 0 deletions internal/apps/ai-proxy/common/ctxhelper/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) 2021 Terminus, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ctxhelper

import (
"context"
"sync"

"github.com/erda-project/erda/internal/apps/ai-proxy/models/message"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
"github.com/erda-project/erda/pkg/reverseproxy"
)

// GetMessageGroup returns the message group stored in the shared map carried
// by ctx under vars.MapKeyMessageGroup{}.
//
// The boolean result is false when ctx carries no *sync.Map under
// reverseproxy.CtxKeyMap{}, when nothing has been stored for the key, or when
// the stored value is nil or is not a *message.Group.
func GetMessageGroup(ctx context.Context) (*message.Group, bool) {
	// Use the comma-ok form: a bare type assertion would panic when the
	// context was not prepared with the shared map.
	m, ok := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map)
	if !ok || m == nil {
		return nil, false
	}
	value, ok := m.Load(vars.MapKeyMessageGroup{})
	if !ok {
		return nil, false
	}
	// A nil interface fails the assertion; a typed nil pointer is rejected
	// explicitly so callers can dereference the result whenever ok is true.
	mg, ok := value.(*message.Group)
	if !ok || mg == nil {
		return nil, false
	}
	return mg, true
}

// PutMessageGroup stores a pointer to its own copy of mg in the shared map
// carried by ctx, under vars.MapKeyMessageGroup{}.
//
// It panics if ctx does not carry a *sync.Map under reverseproxy.CtxKeyMap{}.
func PutMessageGroup(ctx context.Context, mg message.Group) {
	// mg is passed by value, so the stored pointer refers to this call's copy.
	group := &mg
	ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Store(vars.MapKeyMessageGroup{}, group)
}
40 changes: 40 additions & 0 deletions internal/apps/ai-proxy/common/ctxhelper/prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2021 Terminus, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ctxhelper

import (
"context"
"sync"

"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
"github.com/erda-project/erda/pkg/reverseproxy"
)

// GetUserPrompt returns the user prompt stored in the shared map carried by
// ctx under vars.MapKeyUserPrompt{}.
//
// The boolean result is false when ctx carries no *sync.Map under
// reverseproxy.CtxKeyMap{}, when nothing has been stored for the key, or when
// the stored value is not a string.
func GetUserPrompt(ctx context.Context) (string, bool) {
	// Use the comma-ok form: a bare type assertion would panic when the
	// context was not prepared with the shared map.
	m, ok := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map)
	if !ok || m == nil {
		return "", false
	}
	value, ok := m.Load(vars.MapKeyUserPrompt{})
	if !ok {
		return "", false
	}
	// A nil or non-string value fails this assertion and reports "not found".
	prompt, ok := value.(string)
	if !ok {
		return "", false
	}
	return prompt, true
}

// PutUserPrompt stores prompt in the shared map carried by ctx, under
// vars.MapKeyUserPrompt{}.
//
// It panics if ctx does not carry a *sync.Map under reverseproxy.CtxKeyMap{}.
func PutUserPrompt(ctx context.Context, prompt string) {
	ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Store(vars.MapKeyUserPrompt{}, prompt)
}
27 changes: 27 additions & 0 deletions internal/apps/ai-proxy/filters/azure-director/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package azure_director
import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"os"
Expand All @@ -26,11 +27,14 @@ import (
"sync"

"github.com/pkg/errors"
"github.com/sashabaranov/go-openai"
"github.com/sirupsen/logrus"
"sigs.k8s.io/yaml"

"github.com/erda-project/erda-infra/base/logs"
modelpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model/pb"
modelproviderpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model_provider/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/common/ctxhelper"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/metadata"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/model_provider"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
Expand Down Expand Up @@ -350,6 +354,29 @@ func (f *AzureDirector) handleQueries(ctx context.Context, funcName string) erro
return nil
}

// AddContextMessages registers a director that replaces the Messages field of
// the chat-completion request body with AllMessages from the message group
// found in ctx.
//
// When ctx holds no message group, no director is registered. The returned
// error is always nil; decode and marshal failures inside the director are
// logged and leave the request body untouched.
func (f *AzureDirector) AddContextMessages(ctx context.Context) error {
	group, found := ctxhelper.GetMessageGroup(ctx)
	if !found {
		return nil
	}

	rewriteBody := func(req *http.Request) {
		infor := reverseproxy.NewInfor(ctx, req)

		var chatReq openai.ChatCompletionRequest
		decodeErr := json.NewDecoder(infor.BodyBuffer()).Decode(&chatReq)
		// io.EOF is tolerated: the request is then built from the zero value.
		if decodeErr != nil && decodeErr != io.EOF {
			logrus.Errorf("failed to decode request body, err: %v", decodeErr)
			return
		}

		chatReq.Messages = group.AllMessages
		payload, marshalErr := json.Marshal(&chatReq)
		if marshalErr != nil {
			logrus.Errorf("failed to marshal request body, err: %v", marshalErr)
			return
		}

		body := io.NopCloser(strings.NewReader(string(payload)))
		infor.SetBody(body, int64(len(payload)))
	}

	reverseproxy.AppendDirectors(ctx, rewriteBody)
	return nil
}

// Config is the filter configuration; its field carries both JSON and YAML
// tags.
type Config struct {
// Directors is read from the "directors" key.
// NOTE(review): presumably the names of the director functions to apply —
// confirm against the code that consumes Config.
Directors []string `json:"directors" yaml:"directors"`
}
Expand Down
27 changes: 15 additions & 12 deletions internal/apps/ai-proxy/filters/message-context/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
package message_context

import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"sort"
"sync"
Expand Down Expand Up @@ -76,7 +74,7 @@ func (c *SessionContext) OnRequest(ctx context.Context, _ http.ResponseWriter, i

// judge use session-id or prompt-id
sessionValue, sessionOk := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Load(vars.MapKeySession{})
promptValue, promptOk := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Load(vars.MapKeyPrompt{})
promptValue, promptOk := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Load(vars.MapKeyPromptTemplate{})

if !sessionOk && !promptOk {
return reverseproxy.Continue, nil
Expand All @@ -88,6 +86,7 @@ func (c *SessionContext) OnRequest(ctx context.Context, _ http.ResponseWriter, i
}

var allMessages message.Messages
var systemMessage *message.Message
var sessionTopicMessage *message.Message
var promptMessages message.Messages
var sessionPreviousMessages message.Messages
Expand Down Expand Up @@ -165,12 +164,13 @@ func (c *SessionContext) OnRequest(ctx context.Context, _ http.ResponseWriter, i

// 0. add system message
if c.Config.SysMsg != "" {
allMessages = append(allMessages, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleSystem, Content: c.Config.SysMsg, Name: "Erda-AI-Assistant"})
systemMessage = &message.Message{Role: openai.ChatMessageRoleSystem, Content: c.Config.SysMsg, Name: "Erda-AI-Assistant"}
allMessages = append(allMessages, *systemMessage.ToOpenAI())
}

// 1. add session topic
if sessionTopicMessage != nil {
allMessages = append(allMessages, openai.ChatCompletionMessage(*sessionTopicMessage))
allMessages = append(allMessages, *sessionTopicMessage.ToOpenAI())
}
// 2. add prompt messages
allMessages = append(allMessages, promptMessages...)
Expand All @@ -179,14 +179,17 @@ func (c *SessionContext) OnRequest(ctx context.Context, _ http.ResponseWriter, i
// 4. add requested messages
allMessages = append(allMessages, requestedMessages...)

// set to request body
chatCompletionRequest.Messages = allMessages
b, err := json.Marshal(&chatCompletionRequest)
if err != nil {
l.Errorf("failed to marshal request body, err: %v", err)
return reverseproxy.Intercept, err
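// Request bodies differ between models, so the body cannot be set directly here; instead, put the message group into the context and let the actual model filters do the conversion.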
messageGroup := message.Group{
AllMessages: allMessages,
SystemMessage: systemMessage,
SessionTopicMessage: sessionTopicMessage,
PromptTemplateMessages: promptMessages,
SessionPreviousMessages: sessionPreviousMessages,
RequestedMessages: requestedMessages,
}
infor.SetBody(io.NopCloser(bytes.NewBuffer(b)), int64(len(b)))
ctxhelper.PutMessageGroup(ctx, messageGroup)
ctxhelper.PutUserPrompt(ctx, requestedMessages[len(requestedMessages)-1].Content)

return reverseproxy.Continue, nil
}
25 changes: 25 additions & 0 deletions internal/apps/ai-proxy/filters/openai-director/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ import (
"sync"

"github.com/pkg/errors"
"github.com/sashabaranov/go-openai"
"github.com/sirupsen/logrus"
"sigs.k8s.io/yaml"

"github.com/erda-project/erda-infra/base/logs"
modelpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model/pb"
modelproviderpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model_provider/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/common/ctxhelper"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/metadata"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
"github.com/erda-project/erda/pkg/reverseproxy"
Expand Down Expand Up @@ -205,6 +207,29 @@ func (f *OpenaiDirector) AddModelInRequestBody(ctx context.Context) error {
return nil
}

// AddContextMessages registers a director that overwrites the Messages field
// of the chat-completion request body with AllMessages from the message group
// found in ctx.
//
// When ctx holds no message group, no director is registered. The returned
// error is always nil; decode and marshal failures inside the director are
// logged and leave the request body untouched.
func (f *OpenaiDirector) AddContextMessages(ctx context.Context) error {
	group, found := ctxhelper.GetMessageGroup(ctx)
	if !found {
		return nil
	}

	director := func(req *http.Request) {
		infor := reverseproxy.NewInfor(ctx, req)

		var chatReq openai.ChatCompletionRequest
		decodeErr := json.NewDecoder(infor.BodyBuffer()).Decode(&chatReq)
		// io.EOF is tolerated: the request is then built from the zero value.
		if decodeErr != nil && decodeErr != io.EOF {
			logrus.Errorf("failed to decode request body, err: %v", decodeErr)
			return
		}

		chatReq.Messages = group.AllMessages
		encoded, marshalErr := json.Marshal(&chatReq)
		if marshalErr != nil {
			logrus.Errorf("failed to marshal request body, err: %v", marshalErr)
			return
		}

		reader := strings.NewReader(string(encoded))
		infor.SetBody(io.NopCloser(reader), int64(len(encoded)))
	}

	reverseproxy.AppendDirectors(ctx, director)
	return nil
}

func (f *OpenaiDirector) AllDirectors() map[string]func(ctx context.Context) error {
if len(f.funcs) > 0 {
return f.funcs
Expand Down
26 changes: 26 additions & 0 deletions internal/apps/ai-proxy/models/message/group.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2021 Terminus, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package message

// Group bundles the messages of one request: the complete message list plus
// the individual parts it is made of.
type Group struct {
// AllMessages is the complete message list; use it directly if you do not
// care about the details.
AllMessages Messages

// The parts below are listed in order; pay attention to the order.
SystemMessage *Message
SessionTopicMessage *Message
PromptTemplateMessages Messages
// from prompt template if provided in the http header
// NOTE(review): this description reads as if it belongs to
// PromptTemplateMessages rather than SessionPreviousMessages — confirm.
SessionPreviousMessages Messages
// RequestedMessages normally come from the body prompt field (user input).
RequestedMessages Messages
}

0 comments on commit c055b56

Please sign in to comment.