Skip to content

Commit

Permalink
chore: ai proxy (#6041)
Browse files Browse the repository at this point in the history
* model&model_provider support paging

* azure-director: set default api-version from model.metadata.public.api_version

* fix: treat ui input `undefined` as empty string

* mysql debug default to false

* remove useless code

* update gohub image to 1.0.9.alpha.1
  • Loading branch information
sfwn authored Sep 15, 2023
1 parent 168c3a4 commit 442d083
Show file tree
Hide file tree
Showing 21 changed files with 188 additions and 63 deletions.
2 changes: 1 addition & 1 deletion .erda/pipelines/ci-build-ce.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ stages:
- custom-script:
alias: erda
description: 运行自定义命令
image: registry.erda.cloud/erda/gohub:1.0.8
image: registry.erda.cloud/erda/gohub:1.0.9.alpha.1
commands:
- cp -a ${{ dirs.raw-erda }}/. .
- make proto-go-in-local
Expand Down
2 changes: 1 addition & 1 deletion .erda/pipelines/ci-deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ stages:
- stage:
- custom-script:
alias: erda
image: registry.erda.cloud/erda/gohub:1.0.8
image: registry.erda.cloud/erda/gohub:1.0.9.alpha.1
commands:
- cp -a ${{ dirs.raw-erda }}/. .
- make proto-go-in-local
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-it-noneed.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ jobs:
needs: CHECK-CHANGED-FILES
if: ${{ needs.CHECK-CHANGED-FILES.outputs.have-dbmigration == 'true' }}
container:
image: registry.erda.cloud/erda/gohub:1.0.8
image: registry.erda.cloud/erda/gohub:1.0.9.alpha.1
steps:
- name: Clone repo
uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-it.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
PREPARE:
runs-on: ubuntu-latest
container:
image: registry.erda.cloud/erda/gohub:1.0.8
image: registry.erda.cloud/erda/gohub:1.0.9.alpha.1
steps:
- name: Clone repo
uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion api/proto-go/Makefile

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 28 additions & 8 deletions api/proto/apps/aiproxy/model/model.proto
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ service ModelService {
delete: "/api/ai-proxy/models/{id}"
};
}

// Paging returns one page of models matching the filters in
// ModelPagingRequest. It is exposed as an HTTP GET on /api/ai-proxy/models;
// the `?type={type}` suffix presumably binds the `type` query parameter to the
// request's `type` field (verify against the HTTP code generator in use).
rpc Paging(ModelPagingRequest) returns (ModelPagingResponse) {
option(google.api.http) = {
get: "/api/ai-proxy/models?type={type}"
};
}
}

message Model {
Expand All @@ -50,18 +56,19 @@ message Model {
}

enum ModelType {
text_generation = 0;
image = 1;
audio = 2;
embedding = 3;
text_moderation = 4;
multimodal_text__visual = 5;
TYPE_UNSPECIFIED = 0;
text_generation = 1;
image = 2;
audio = 3;
embedding = 4;
text_moderation = 5;
multimodal_text__visual = 6;
}

message ModelCreateRequest {
string name = 1 [(validate.rules).string = {min_len: 4, max_len: 191}];
string desc = 2 [(validate.rules).string.max_len = 1024];
ModelType type = 3 [(validate.rules).enum = {defined_only: true}];
ModelType type = 3 [(validate.rules).enum = {defined_only: true, not_in: [0]}];
string providerId = 4 [(validate.rules).string = {len: 36}];
string apiKey = 5;
Metadata metadata = 6;
Expand All @@ -71,7 +78,7 @@ message ModelUpdateRequest {
string id = 1;
string name = 2 [(validate.rules).string = {min_len: 4, max_len: 191}];
string desc = 3 [(validate.rules).string.max_len = 1024];
ModelType type = 4;
ModelType type = 4 [(validate.rules).enum = {defined_only: true, not_in: [0]}];
string providerId = 5 [(validate.rules).string = {len: 36}];
string apiKey = 6;
Metadata metadata = 7;
Expand All @@ -83,4 +90,17 @@ message ModelDeleteRequest {

message ModelGetRequest {
string id = 1 [(validate.rules).string = {len: 36}];
}

// ModelPagingRequest carries the paging arguments and optional filters for
// listing models. Every field may be left at its zero value.
message ModelPagingRequest {
// Page number; when set it must be >= 1. Zero skips validation.
uint64 pageNum = 1 [(validate.rules).uint64 = {ignore_empty: true, gte: 1}];
// Page size; when set it must be within [1, 1000]. Zero skips validation.
uint64 pageSize = 2 [(validate.rules).uint64 = {ignore_empty: true, gte: 1, lte: 1000}];
// Optional name filter; when set it must be 2 to 191 characters long.
string name = 3 [(validate.rules).string = {ignore_empty: true, min_len: 2, max_len: 191}];
// Optional type filter; must be a defined ModelType value. Unlike the
// create/update requests, value 0 is not rejected here.
ModelType type = 4 [(validate.rules).enum = {defined_only: true}];
// Optional provider filter; when set it must be exactly 36 characters long.
string providerId = 5 [(validate.rules).string = {ignore_empty: true, len: 36}];
}

// ModelPagingResponse is the result of ModelService.Paging.
message ModelPagingResponse {
// Count reported alongside the page; presumably the number of models
// matching the filters across all pages (confirm against the server).
int64 total = 1;
// Models contained in the requested page.
repeated Model list = 2;
}
27 changes: 23 additions & 4 deletions api/proto/apps/aiproxy/model_provider/model_provider.proto
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ service ModelProviderService {
put: "/api/ai-proxy/model-providers/{id}"
};
}

// Paging returns one page of model providers matching the filters in
// ModelProviderPagingRequest. It is exposed as an HTTP GET on
// /api/ai-proxy/model-providers; the `?type={type}` suffix presumably binds
// the `type` query parameter to the request's `type` field (verify against
// the HTTP code generator in use).
rpc Paging(ModelProviderPagingRequest) returns (ModelProviderPagingResponse) {
option(google.api.http) = {
get: "/api/ai-proxy/model-providers?type={type}"
};
}
}

message ModelProvider {
Expand All @@ -49,14 +55,15 @@ message ModelProvider {
}

enum ModelProviderType {
OpenAI = 0;
Azure = 1;
TYPE_UNSPECIFIED = 0;
OpenAI = 1;
Azure = 2;
}

message ModelProviderCreateRequest {
string name = 1 [(validate.rules).string = {min_len: 4, max_len: 191}];
string desc = 2 [(validate.rules).string = {min_len: 0, max_len: 1024}];
ModelProviderType type = 3 [(validate.rules).enum = {defined_only: true}];
ModelProviderType type = 3 [(validate.rules).enum = {defined_only: true, not_in: [0]}];
string apiKey = 4 [(validate.rules).string = {min_len: 0, max_len: 128}];
Metadata metadata = 5;
}
Expand All @@ -73,7 +80,19 @@ message ModelProviderUpdateRequest {
string id = 1 [(validate.rules).string = {len: 36}];
string name = 2 [(validate.rules).string = {min_len: 4, max_len: 191}];
string desc = 3 [(validate.rules).string = {min_len: 0, max_len: 1024}];
ModelProviderType type = 4 [(validate.rules).enum = {defined_only: true}];
ModelProviderType type = 4 [(validate.rules).enum = {defined_only: true, not_in: [0]}];
string apiKey = 5 [(validate.rules).string = {min_len: 0, max_len: 128}];
Metadata metadata = 6;
}

// ModelProviderPagingRequest carries the paging arguments and optional
// filters for listing model providers. Every field may be left at its zero
// value.
message ModelProviderPagingRequest {
// Page number; when set it must be >= 1. Zero skips validation.
uint64 pageNum = 1 [(validate.rules).uint64 = {ignore_empty: true, gte: 1}];
// Page size; when set it must be within [1, 1000]. Zero skips validation.
uint64 pageSize = 2 [(validate.rules).uint64 = {ignore_empty: true, gte: 1, lte: 1000}];
// Optional name filter; when set it must be 2 to 191 characters long.
string name = 3 [(validate.rules).string = {ignore_empty: true, min_len: 2, max_len: 191}];
// Optional type filter; must be a defined ModelProviderType value. Unlike the
// create/update requests, value 0 is not rejected here.
ModelProviderType type = 4 [(validate.rules).enum = {defined_only: true}];
}

// ModelProviderPagingResponse is the result of ModelProviderService.Paging.
message ModelProviderPagingResponse {
// Count reported alongside the page; presumably the number of providers
// matching the filters across all pages (confirm against the server).
int64 total = 1;
// Model providers contained in the requested page.
repeated ModelProvider list = 2;
}
7 changes: 1 addition & 6 deletions cmd/ai-proxy/bootstrap.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@ erda.app.ai-proxy:
openOnErda: ${OPEN_ON_ERDA:true} # whether to expose the API through Erda Openapi

gorm.v2:
host: "${MYSQL_HOST}"
port: "${MYSQL_PORT}"
username: "${MYSQL_USERNAME}"
password: "${MYSQL_PASSWORD}"
database: "${MYSQL_DATABASE}"
debug: true
debug: ${MYSQL_DEBUG:false}

erda.apps.ai-proxy.dao:
erda.app.ai-proxy.metrics:
Expand Down
22 changes: 22 additions & 0 deletions internal/apps/ai-proxy/filters/azure-director/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,28 @@ func (f *AzureDirector) SetAPIKeyIfNotSpecified(ctx context.Context) error {
return nil
}

// SetModelAPIVersionIfNotSpecified appends a director that fills in the
// "api-version" query parameter of the outgoing request from the model's
// public metadata key "api_version", but only when the request does not
// already carry a non-empty "api-version".
//
// It returns an error when the context map or the model cannot be found in
// ctx (or has an unexpected type) instead of panicking on a failed type
// assertion. It returns nil without appending a director when the model's
// public metadata has no "api_version" entry.
func (f *AzureDirector) SetModelAPIVersionIfNotSpecified(ctx context.Context) error {
	// Checked assertion: a missing or mistyped context map must surface as an
	// error, not as a panic in the middle of request handling.
	ctxMap, ok := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map)
	if !ok || ctxMap == nil {
		return errors.New("context map not set in context")
	}
	value, ok := ctxMap.Load(vars.MapKeyModel{})
	if !ok || value == nil {
		return errors.New("model not set in context map")
	}
	model, ok := value.(*modelpb.Model)
	if !ok || model == nil {
		return errors.New("invalid model in context map")
	}
	meta := metadata.FromProtobuf(model.Metadata)
	modelAPIVersion, ok := meta.GetPublicValueByKey("api_version")
	if !ok {
		return nil
	}
	reverseproxy.AppendDirectors(ctx, func(req *http.Request) {
		// Parse the query string once and reuse it for both the lookup and
		// the rewrite.
		queries := req.URL.Query()
		if queries.Get("api-version") != "" {
			// The caller chose an api-version explicitly; keep it.
			return
		}
		queries.Set("api-version", modelAPIVersion)
		req.URL.RawQuery = queries.Encode()
	})
	return nil
}

// AddQueries delegates to handleQueries, passing "AddQueries" as the name of
// the operation to perform, and returns its error unchanged.
func (f *AzureDirector) AddQueries(ctx context.Context) error {
return f.handleQueries(ctx, "AddQueries")
}
Expand Down
2 changes: 1 addition & 1 deletion internal/apps/ai-proxy/filters/context/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (f *Context) OnRequest(ctx context.Context, w http.ResponseWriter, infor re
// get from session if exists
headerSessionId := infor.Header().Get(vars.XAIProxySessionId)
headerModelId := infor.Header().Get(vars.XAIProxyModelId)
if headerSessionId != "" {
if headerSessionId != "" && headerSessionId != vars.UIValueUndefined {
_session, err := q.SessionClient().Get(ctx, &sessionpb.SessionGetRequest{Id: headerSessionId})
if err != nil {
l.Errorf("failed to get session, id: %s, err: %v", headerSessionId, err)
Expand Down
4 changes: 4 additions & 0 deletions internal/apps/ai-proxy/handlers/handler_model/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@ func (h *ModelHandler) Update(ctx context.Context, req *pb.ModelUpdateRequest) (
// Delete forwards the delete request to the DAO's model client and returns
// its response and error unchanged.
func (h *ModelHandler) Delete(ctx context.Context, req *pb.ModelDeleteRequest) (*commonpb.VoidResponse, error) {
return h.DAO.ModelClient().Delete(ctx, req)
}

// Paging forwards the paging request to the DAO's model client and returns
// its response and error unchanged.
func (h *ModelHandler) Paging(ctx context.Context, req *pb.ModelPagingRequest) (*pb.ModelPagingResponse, error) {
return h.DAO.ModelClient().Paging(ctx, req)
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@ func (h *ModelProviderHandler) Delete(ctx context.Context, req *pb.ModelProvider
// Update forwards the update request to the DAO's model-provider client and
// returns its response and error unchanged.
func (h *ModelProviderHandler) Update(ctx context.Context, req *pb.ModelProviderUpdateRequest) (*pb.ModelProvider, error) {
return h.DAO.ModelProviderClient().Update(ctx, req)
}

// Paging forwards the paging request to the DAO's model-provider client and
// returns its response and error unchanged.
func (h *ModelProviderHandler) Paging(ctx context.Context, req *pb.ModelProviderPagingRequest) (*pb.ModelProviderPagingResponse, error) {
return h.DAO.ModelProviderClient().Paging(ctx, req)
}
2 changes: 2 additions & 0 deletions internal/apps/ai-proxy/handlers/permission/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ var CheckModelProviderPerm = CheckPermissions(
&MethodPermission{Method: modelproviderpb.ModelProviderServiceServer.Get, OnlyAdmin: true},
&MethodPermission{Method: modelproviderpb.ModelProviderServiceServer.Update, OnlyAdmin: true},
&MethodPermission{Method: modelproviderpb.ModelProviderServiceServer.Delete, OnlyAdmin: true},
&MethodPermission{Method: modelproviderpb.ModelProviderServiceServer.Paging, OnlyAdmin: true},
)

var CheckModelPerm = CheckPermissions(
&MethodPermission{Method: modelpb.ModelServiceServer.Create, OnlyAdmin: true},
&MethodPermission{Method: modelpb.ModelServiceServer.Get, AdminOrAk: true},
&MethodPermission{Method: modelpb.ModelServiceServer.Update, OnlyAdmin: true},
&MethodPermission{Method: modelpb.ModelServiceServer.Delete, OnlyAdmin: true},
&MethodPermission{Method: modelpb.ModelServiceServer.Paging, OnlyAdmin: true},
)

var CheckClientModelRelationPerm = CheckPermissions(
Expand Down
22 changes: 12 additions & 10 deletions internal/apps/ai-proxy/models/metadata/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,26 +65,28 @@ func (m *Metadata) MergeMap() map[string]string {
return result
}

func (m *Metadata) GetPublicValueByKey(key string) interface{} {
func (m *Metadata) GetPublicValueByKey(key string) (string, bool) {
if m == nil {
return nil
return "", false
}
return m.Public[key]
v, ok := m.Public[key]
return v, ok
}

func (m *Metadata) GetSecretValueByKey(key string) interface{} {
func (m *Metadata) GetSecretValueByKey(key string) (string, bool) {
if m == nil {
return nil
return "", false
}
return m.Secret[key]
v, ok := m.Secret[key]
return v, ok
}

func (m *Metadata) GetValueByKey(key string) interface{} {
func (m *Metadata) GetValueByKey(key string) (string, bool) {
if m == nil {
return nil
return "", false
}
if v := m.GetPublicValueByKey(key); v != nil {
return v
if v, ok := m.GetPublicValueByKey(key); ok {
return v, ok
}
return m.GetSecretValueByKey(key)
}
Expand Down
32 changes: 32 additions & 0 deletions internal/apps/ai-proxy/models/model/dbclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,35 @@ func (dbClient *DBClient) Delete(ctx context.Context, req *pb.ModelDeleteRequest
}
return &commonpb.VoidResponse{}, nil
}

// Paging returns one page of models together with the total count reported by
// the same query.
//
// Filters: a non-empty req.Name is matched as a substring via SQL LIKE;
// req.Type (when not TYPE_UNSPECIFIED) and req.ProviderId are applied through
// a struct condition, which GORM documents as skipping zero-value fields.
//
// Paging: req.PageNum and req.PageSize are defaulted in place to 1 and 10 when
// they are zero, so the caller's request is modified. ctx is currently not
// used by this method.
func (dbClient *DBClient) Paging(ctx context.Context, req *pb.ModelPagingRequest) (*pb.ModelPagingResponse, error) {
	filter := &Model{}
	query := dbClient.DB.Model(filter)
	if req.Name != "" {
		query = query.Where("name LIKE ?", "%"+req.Name+"%")
	}
	if req.Type != pb.ModelType_TYPE_UNSPECIFIED {
		filter.Type = model_type.ModelType(req.Type)
	}
	filter.ProviderID = req.ProviderId
	query = query.Where(filter)

	// Fall back to the first page of ten rows when the caller gave no paging
	// arguments.
	if req.PageNum == 0 {
		req.PageNum = 1
	}
	if req.PageSize == 0 {
		req.PageSize = 10
	}
	limit := int(req.PageSize)
	skip := int((req.PageNum - 1) * req.PageSize)

	var (
		total int64
		rows  Models
	)
	if err := query.Count(&total).Limit(limit).Offset(skip).Find(&rows).Error; err != nil {
		return nil, err
	}
	resp := &pb.ModelPagingResponse{
		Total: total,
		List:  rows.ToProtobuf(),
	}
	return resp, nil
}
10 changes: 10 additions & 0 deletions internal/apps/ai-proxy/models/model/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,13 @@ func (m *Model) ToProtobuf() *pb.Model {
Metadata: m.Metadata.ToProtobuf(),
}
}

// Models is a list of Model rows.
type Models []*Model

// ToProtobuf converts each row to its protobuf message, preserving order.
// An empty or nil list yields a nil slice.
func (models Models) ToProtobuf() []*pb.Model {
	var pbModels []*pb.Model
	for i := range models {
		pbModels = append(pbModels, models[i].ToProtobuf())
	}
	return pbModels
}
31 changes: 31 additions & 0 deletions internal/apps/ai-proxy/models/model_provider/dbclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,34 @@ func (dbClient *DBClient) Update(ctx context.Context, req *pb.ModelProviderUpdat
}
return dbClient.Get(ctx, &pb.ModelProviderGetRequest{Id: req.Id})
}

// Paging returns one page of model providers together with the total count
// reported by the same query.
//
// Filters: a non-empty req.Name is matched as a substring via SQL LIKE;
// req.Type, when not TYPE_UNSPECIFIED, is matched exactly against the "type"
// column.
//
// Paging: req.PageNum and req.PageSize are defaulted in place to 1 and 10 when
// they are zero, so the caller's request is modified. ctx is currently not
// used by this method.
func (dbClient *DBClient) Paging(ctx context.Context, req *pb.ModelProviderPagingRequest) (*pb.ModelProviderPagingResponse, error) {
	filter := &ModelProvider{}
	query := dbClient.DB.Model(filter)
	if req.Name != "" {
		query = query.Where("name LIKE ?", "%"+req.Name+"%")
	}
	if req.Type != pb.ModelProviderType_TYPE_UNSPECIFIED {
		filter.Type = model_provider_type.GetModelProviderTypeFromProtobuf(req.Type)
		query = query.Where("type = ?", filter.Type)
	}

	// Fall back to the first page of ten rows when the caller gave no paging
	// arguments.
	if req.PageNum == 0 {
		req.PageNum = 1
	}
	if req.PageSize == 0 {
		req.PageSize = 10
	}
	limit := int(req.PageSize)
	skip := int((req.PageNum - 1) * req.PageSize)

	var (
		total int64
		rows  ModelProviders
	)
	if err := query.Count(&total).Limit(limit).Offset(skip).Find(&rows).Error; err != nil {
		return nil, err
	}
	resp := &pb.ModelProviderPagingResponse{
		Total: total,
		List:  rows.ToProtobuf(),
	}
	return resp, nil
}
10 changes: 10 additions & 0 deletions internal/apps/ai-proxy/models/model_provider/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,13 @@ func (m *ModelProvider) ToProtobuf() *pb.ModelProvider {
Metadata: m.Metadata.ToProtobuf(),
}
}

// ModelProviders is a list of ModelProvider rows.
type ModelProviders []*ModelProvider

// ToProtobuf converts each row to its protobuf message, preserving order.
// An empty or nil list yields a nil slice.
func (modelProviders ModelProviders) ToProtobuf() []*pb.ModelProvider {
	var pbProviders []*pb.ModelProvider
	for i := range modelProviders {
		pbProviders = append(pbProviders, modelProviders[i].ToProtobuf())
	}
	return pbProviders
}
Loading

0 comments on commit 442d083

Please sign in to comment.