Skip to content

Commit

Permalink
add bailian-diretor for qwenv1
Browse files Browse the repository at this point in the history
  • Loading branch information
sfwn committed Nov 1, 2023
1 parent c055b56 commit 2f0f651
Show file tree
Hide file tree
Showing 14 changed files with 552 additions and 30 deletions.
1 change: 1 addition & 0 deletions api/proto/apps/aiproxy/model_provider/model_provider.proto
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ enum ModelProviderType {
TYPE_UNSPECIFIED = 0;
OpenAI = 1;
Azure = 2;
AliyunBailian = 3;
}

message ModelProviderCreateRequest {
Expand Down
10 changes: 7 additions & 3 deletions cmd/ai-proxy/conf/routes.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ routes:
maxSize: 102400
message: { "messages": [ { "role": "ai-proxy", "content": "问题超长啦, 请重置会话", "name": "ai-proxy" } ] }
- name: context
- name: message-context
config:
sysMsg: "background: Your name is Erda AI Assistant. You are trained by Erda."
- name: azure-director
config:
directors:
Expand All @@ -22,16 +25,17 @@ routes:
- RewriteScheme
- RewriteHost
- RewritePath("/openai/deployments/${ provider.metadata.deployment_id }/chat/completions")
- AddContextMessages
- name: openai-director
config:
directors:
- TransAuthorization
- RewriteScheme
- RewriteHost
- AddModelInRequestBody
- name: message-context
config:
sysMsg: "background: Your name is Erda AI Assistant. You are trained by Erda."
- AddContextMessages
- name: bailian-director
- name: openai-director
- name: audit
- name: finalize
- path: /v1/embeddings
Expand Down
36 changes: 36 additions & 0 deletions internal/apps/ai-proxy/common/ctxhelper/model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2021 Terminus, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ctxhelper

import (
"context"
"sync"

modelpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
"github.com/erda-project/erda/pkg/reverseproxy"
)

func GetModel(ctx context.Context) (*modelpb.Model, bool) {
value, ok := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Load(vars.MapKeyModel{})
if !ok || value == nil {
return nil, false
}
model, ok := value.(*modelpb.Model)
if !ok {
return nil, false
}
return model, true
}
36 changes: 36 additions & 0 deletions internal/apps/ai-proxy/common/ctxhelper/model_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2021 Terminus, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ctxhelper

import (
"context"
"sync"

modelproviderpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model_provider/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
"github.com/erda-project/erda/pkg/reverseproxy"
)

func GetModelProvider(ctx context.Context) (*modelproviderpb.ModelProvider, bool) {
value, ok := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Load(vars.MapKeyModelProvider{})
if !ok || value == nil {
return nil, false
}
prov, ok := value.(*modelproviderpb.ModelProvider)
if !ok {
return nil, false
}
return prov, true
}
1 change: 1 addition & 0 deletions internal/apps/ai-proxy/dependent_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package ai_proxy
import (
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/audit"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/azure-director"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/bailian-director"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/body-size-limit"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/context"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/erda-auth"
Expand Down
19 changes: 2 additions & 17 deletions internal/apps/ai-proxy/filters/audit/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,23 +403,8 @@ func (f *Audit) SetPrompt(ctx context.Context, infor reverseproxy.HttpInfor) err
}
f.Audit.Prompt = NoPromptByNotParsed.String()
case method == http.MethodPost && path == "/v1/chat/completions":
message, ok := m["messages"]
if !ok {
f.Audit.Prompt = NoPromptByMissingField.String()
return errors.Errorf(`no field "messages" in the request body, operation: %s`, operation)
}
var messages []struct {
Content string `json:"content" yaml:"content"`
}
if err := json.Unmarshal(message, &messages); err != nil {
f.Audit.Prompt = NoPromptByNotParsed.String()
return err
}
if len(messages) == 0 {
f.Audit.Prompt = NoPromptByNoItem.String()
return errors.Errorf(`no itmes in the request body messages`)
}
f.Audit.Prompt = messages[len(messages)-1].Content
prompt, _ := ctxhelper.GetUserPrompt(ctx)
f.Audit.Prompt = prompt
case method == http.MethodPost && path == "/v1/edits":
message, ok := m["edit"]
if !ok {
Expand Down
54 changes: 54 additions & 0 deletions internal/apps/ai-proxy/filters/bailian-director/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) 2021 Terminus, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bailian_director

import (
"github.com/erda-project/erda-infra/providers/component-protocol/utils/cputil"
modelpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model/pb"
modelproviderpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model_provider/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/metadata"
)

var clientsMapByProviderID map[string]*CompletionClient

func init() {
clientsMapByProviderID = make(map[string]*CompletionClient)
}

func getProviderMeta(p *modelproviderpb.ModelProvider) metadata.AliyunBailianProviderMeta {
var meta metadata.AliyunBailianProviderMeta
cputil.MustObjJSONTransfer(&p.Metadata, &meta)
return meta
}

func getModelMeta(m *modelpb.Model) metadata.AliyunBailianModelMeta {
var meta metadata.AliyunBailianModelMeta
cputil.MustObjJSONTransfer(&m.Metadata, &meta)
return meta
}

func fetchClient(p *modelproviderpb.ModelProvider) *CompletionClient {
if c, ok := clientsMapByProviderID[p.Id]; ok {
return c
}
meta := getProviderMeta(p)
client := &CompletionClient{
AccessKeyId: &meta.Secret.AccessKeyId,
AccessKeySecret: &meta.Secret.AccessKeySecret,
AgentKey: &meta.Secret.AgentKey,
}
clientsMapByProviderID[p.Id] = client
return client
}
188 changes: 188 additions & 0 deletions internal/apps/ai-proxy/filters/bailian-director/filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
// Copyright (c) 2021 Terminus, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bailian_director

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"

"github.com/sashabaranov/go-openai"

modelproviderpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model_provider/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/common/ctxhelper"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/message"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
"github.com/erda-project/erda/pkg/http/httputil"
"github.com/erda-project/erda/pkg/reverseproxy"
)

const (
Name = "bailian-director"
)

var (
_ reverseproxy.RequestFilter = (*BailianDirector)(nil)
)

func init() {
reverseproxy.RegisterFilterCreator(Name, New)
}

type BailianDirector struct{}

func New(config json.RawMessage) (reverseproxy.Filter, error) {
return &BailianDirector{}, nil
}

func (f *BailianDirector) Enable(ctx context.Context, req *http.Request) bool {
prov, ok := ctxhelper.GetModelProvider(ctx)
return ok && prov.Type == modelproviderpb.ModelProviderType_AliyunBailian
}

func (f *BailianDirector) OnRequest(ctx context.Context, w http.ResponseWriter, infor reverseproxy.HttpInfor) (signal reverseproxy.Signal, err error) {
prov, _ := ctxhelper.GetModelProvider(ctx)
model, _ := ctxhelper.GetModel(ctx)
modelMeta := getModelMeta(model)
messageGroup, _ := ctxhelper.GetMessageGroup(ctx)

// use go sdk
client := fetchClient(prov)
token, err := client.GetToken()
if err != nil {
return reverseproxy.Intercept, fmt.Errorf("failed to get token, err: %v", err)
}
reverseproxy.AppendDirectors(ctx, func(r *http.Request) {
// rewrite url
bailianURL := fmt.Sprintf("%s/v2/app/completions", BroadscopeBailianEndpoint)
u, _ := url.Parse(bailianURL)
r.URL = u
r.Host = u.Host
// rewrite authorization header
r.Header.Set(httputil.HeaderKeyContentType, string(httputil.ApplicationJsonUTF8))
r.Header.Set(httputil.HeaderKeyAuthorization, vars.ConcatBearer(token))
})

// parse original request body
var openaiReq openai.ChatCompletionRequest
if err := json.NewDecoder(infor.BodyBuffer()).Decode(&openaiReq); err != nil {
return reverseproxy.Intercept, fmt.Errorf("failed to parse request body as openai format, err: %v", err)
}

var prompt string
var historyMsgs []*ChatQaMessage

if len(openaiReq.Messages) == 0 {
return reverseproxy.Intercept, fmt.Errorf("no prompt provided")
}
lastMsgIndex := len(openaiReq.Messages) - 1
prompt = openaiReq.Messages[lastMsgIndex].Content
if messageGroup != nil {
historyMsgs = transferHistoryMessages(*messageGroup)
}

bailianReq := CompletionRequest{
AppId: &modelMeta.Secret.AppId,
Prompt: &prompt,
History: historyMsgs,
}
b, err := json.Marshal(&bailianReq)
if err != nil {
return reverseproxy.Intercept, fmt.Errorf("failed to marshal request body, err: %v", err)
}
infor.SetBody(io.NopCloser(bytes.NewBuffer(b)), int64(len(b)))

return reverseproxy.Continue, nil
}

func transferHistoryMessages(g message.Group) []*ChatQaMessage {
var qas []*ChatQaMessage

// system
if g.SystemMessage != nil {
qas = append(qas, &ChatQaMessage{
User: "background",
Bot: g.SystemMessage.Content,
})
}

// session topic
if g.SessionTopicMessage != nil {
qas = append(qas, &ChatQaMessage{
User: "session topic",
Bot: g.SessionTopicMessage.Content,
})
}

// prompt template
for _, msg := range g.PromptTemplateMessages {
qas = append(qas, &ChatQaMessage{
User: "you should know this",
Bot: msg.Content,
})
}

// session previous
qas = append(qas, autoFillQaPair(g.SessionPreviousMessages)...)

// requested, if there are more than one message, only the last one is prompt, others are history
if len(g.RequestedMessages) > 1 {
qas = append(qas, autoFillQaPair(g.RequestedMessages[:len(g.RequestedMessages)-1])...)
}

return qas
}

var botOKMsg = openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: "OK",
}

func autoFillQaPair(msgs message.Messages) []*ChatQaMessage {
if len(msgs) == 0 {
return nil
}

var filledMsgs message.Messages
for i := 0; i < len(msgs); i++ {
j := i + 1
if j >= len(msgs) { // out of range
filledMsgs = append(filledMsgs, msgs[i], botOKMsg)
break
}
currentMsg, nextMsg := msgs[i], msgs[j]
if currentMsg.Role == openai.ChatMessageRoleUser && nextMsg.Role == openai.ChatMessageRoleAssistant {
filledMsgs = append(filledMsgs, currentMsg, nextMsg)
i = j
continue
}
// not user, just add bot ok msg to pair it
filledMsgs = append(filledMsgs, currentMsg, botOKMsg)
}

// transfer to qa pair
var result []*ChatQaMessage
for i := 0; i < len(filledMsgs); i += 2 {
result = append(result, &ChatQaMessage{
User: filledMsgs[i].Content,
Bot: filledMsgs[i+1].Content,
})
}
return result
}
Loading

0 comments on commit 2f0f651

Please sign in to comment.