Skip to content

Commit

Permalink
feat(ai-proxy): add client token; add rate-limit filter (#6045)
Browse files Browse the repository at this point in the history
* add api for client-token

* support check client token

* polish token create: add metadata; support createOrGet

* remove useless configs: providers/platforms

* fulfill audit from client token

* add rate-limit filter for client token

* fix golangci-lint

* comment out CODE-TEST cmp temporarily
  • Loading branch information
sfwn authored Sep 19, 2023
1 parent 442d083 commit 3b2bb5c
Show file tree
Hide file tree
Showing 22 changed files with 715 additions and 126 deletions.
22 changes: 22 additions & 0 deletions .erda/ai-proxy/migrations/ai-proxy/20230823-ai-proxy.sql
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,25 @@ CREATE TABLE `ai_proxy_filter_audit`
INDEX `idx_job_number` (`job_number`),
INDEX `idx_dingtalk_staff_id` (`dingtalk_staff_id`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COMMENT 'AI 审计表';

CREATE TABLE `ai_proxy_client_token`
(
`id` CHAR(36) NOT NULL COMMENT 'primary key',
`created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
`updated_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
`deleted_at` DATETIME NOT NULL DEFAULT '1970-01-01 00:00:00' COMMENT '删除时间, 1970-01-01 00:00:00 表示未删除',

`client_id` CHAR(36) NOT NULL COMMENT '会话所属的客户端 id',
`user_id` VARCHAR(191) NOT NULL COMMENT '客户端传入的自定义 user_id,客户端用来区分用户',
`token` CHAR(34) NOT NULL COMMENT 't_ 前缀,len: uuid(32)+2',
`expired_at` DATETIME NOT NULL DEFAULT '1970-01-01 00:00:00',
`metadata` MEDIUMTEXT NOT NULL COMMENT 'Token 元数据,主要包含 user 额外信息,用于审计',

PRIMARY KEY (`id`),
INDEX `idx_token` (`token`),
UNIQUE INDEX `unique_clientid_userid` (`client_id`, `user_id`, `deleted_at`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COMMENT 'AI 客户端 Token 表';
2 changes: 1 addition & 1 deletion .github/workflows/ci-it.yml
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ jobs:
- ./internal/apps/dop/...
- ./internal/apps/admin/... ./internal/apps/devflow/... ./internal/apps/gallery/...
- ./internal/apps/msp/...
- ./internal/apps/cmp/...
# - ./internal/apps/cmp/...
steps:
- name: Clone repo
uses: actions/checkout@v3
Expand Down
94 changes: 94 additions & 0 deletions api/proto/apps/aiproxy/client_token/client_token.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
syntax = "proto3";

package erda.apps.aiproxy;
option go_package = "github.com/erda-project/erda-proto-go/apps/aiproxy/client-token/pb";

import "google/api/annotations.proto";
import "apps/aiproxy/metadata/metadata.proto";
import "google/protobuf/timestamp.proto";
import "github.com/envoyproxy/protoc-gen-validate/validate/validate.proto";
import "common/http.proto";

service ClientTokenService {
rpc Create(ClientTokenCreateRequest) returns (ClientToken) {
option(google.api.http) = {
post: "/api/ai-proxy/clients/{clientId}/tokens?createOrGet={createOrGet}"
};
}

rpc Get(ClientTokenGetRequest) returns (ClientToken) {
option(google.api.http) = {
get: "/api/ai-proxy/clients/{clientId}/tokens/{token}"
};
}

rpc Delete(ClientTokenDeleteRequest) returns (common.VoidResponse) {
option(google.api.http) = {
delete: "/api/ai-proxy/clients/{clientId}/tokens/{token}"
};
}

rpc Update(ClientTokenUpdateRequest) returns (ClientToken) {
option(google.api.http) = {
put: "/api/ai-proxy/clients/{clientId}/tokens/{token}"
};
}

rpc Paging(ClientTokenPagingRequest) returns (ClientTokenPagingResponse) {
option(google.api.http) = {
get: "/api/ai-proxy/clients/{clientId}/tokens"
};
}
}

message ClientToken {
string id = 1;
google.protobuf.Timestamp createdAt = 2;
google.protobuf.Timestamp updatedAt = 3;
google.protobuf.Timestamp deletedAt = 4;

string clientId = 5 [(validate.rules).string = {len: 36}];
string userId = 6 [(validate.rules).string = {min_len:1, max_len: 191}];
string token = 7 [(validate.rules).string = {len: 34}];
google.protobuf.Timestamp expireAt = 8;
Metadata metadata = 9;
}

message ClientTokenCreateRequest {
string clientId = 1 [(validate.rules).string = {len: 36}];
string userId = 2 [(validate.rules).string = {min_len:1, max_len: 191}];
uint64 expireInHours = 3 [(validate.rules).uint64 = {ignore_empty: true, gte: 1, lte: 720}]; // max 30 days
Metadata metadata = 4;

bool createOrGet = 5; // get if token exists
}

message ClientTokenGetRequest {
string clientId = 1 [(validate.rules).string = {len: 36}];
string token = 2 [(validate.rules).string = {len: 34}];
}

message ClientTokenDeleteRequest {
string clientId = 1 [(validate.rules).string = {len: 36}];
string token = 2 [(validate.rules).string = {len: 34}];
}

message ClientTokenUpdateRequest {
string clientId = 1 [(validate.rules).string = {len: 36}];
string token = 2 [(validate.rules).string = {len: 34}];
uint64 expireInHours = 3 [(validate.rules).uint64 = {gte: 0, lte: 720}]; // max 30 days
Metadata metadata = 4;
}

message ClientTokenPagingRequest {
string clientId = 1 [(validate.rules).string = {len: 36}];
string userId = 2 [(validate.rules).string = {ignore_empty: true, max_len: 191}];
string token = 3 [(validate.rules).string = {ignore_empty: true, max_len: 34}];
uint64 pageNum = 4 [(validate.rules).uint64 = {ignore_empty: true, gte: 1}];
uint64 pageSize = 5 [(validate.rules).uint64 = {ignore_empty: true, gte: 1, lte: 1000}];
}

message ClientTokenPagingResponse {
int64 total = 1;
repeated ClientToken list = 2;
}
2 changes: 0 additions & 2 deletions cmd/ai-proxy/bootstrap.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ grpc-server@ai:

erda.app.ai-proxy:
routesRef: conf/routes.yml
providersRef: conf/providers.yml
platformsRef: conf/erda-platforms.yml
logLevel: ${LOG_LEVEL:debug}
openOnErda: ${OPEN_ON_ERDA:true} # 是否将 API 通过 Erda Openapi 暴露出来

Expand Down
1 change: 1 addition & 0 deletions internal/apps/ai-proxy/dependent_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ import (
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/message-context"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/openai-director"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/prometheus-collector"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/rate-limit"
)
38 changes: 36 additions & 2 deletions internal/apps/ai-proxy/filters/audit/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ import (
"github.com/pkg/errors"

"github.com/erda-project/erda-infra/base/logs"
clientpb "github.com/erda-project/erda-proto-go/apps/aiproxy/client/pb"
clienttokenpb "github.com/erda-project/erda-proto-go/apps/aiproxy/client_token/pb"
modelpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model/pb"
modelproviderpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model_provider/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/models"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/metadata"
"github.com/erda-project/erda/internal/apps/ai-proxy/providers/dao"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
"github.com/erda-project/erda/pkg/http/httputil"
Expand Down Expand Up @@ -86,7 +89,8 @@ func (f *Audit) OnRequest(ctx context.Context, w http.ResponseWriter, infor reve
f.SetXRequestId,
f.SetRequestAt,
f.SetSource,
f.SetUserInfo,
f.SetUserInfoFromHeader,
f.SetUserInfoFromClientToken,
f.SetProvider,
f.SetModel,
f.SetOperationId,
Expand Down Expand Up @@ -225,7 +229,7 @@ func (f *Audit) SetSource(_ context.Context, header http.Header) error {
return nil
}

func (f *Audit) SetUserInfo(ctx context.Context, header http.Header) error {
func (f *Audit) SetUserInfoFromHeader(ctx context.Context, header http.Header) error {
f.Audit.Username = header.Get(vars.XAIProxyName)
if f.Audit.Username == "" {
f.Audit.Username = header.Get(vars.XAIProxyUsername)
Expand Down Expand Up @@ -253,6 +257,36 @@ func (f *Audit) SetUserInfo(ctx context.Context, header http.Header) error {
return nil
}

func (f *Audit) SetUserInfoFromClientToken(ctx context.Context) error {
_clientToken, ok := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Load(vars.MapKeyClientToken{})
if !ok || _clientToken == nil {
return nil
}
clientToken := _clientToken.(*clienttokenpb.ClientToken)
meta := metadata.FromProtobuf(clientToken.Metadata)
metaCfg := metadata.Config{IgnoreCase: true}
f.Audit.DingtalkStaffID = meta.MustGetValueByKey(vars.XAIProxyDingTalkStaffID, metaCfg)
f.Audit.Email = meta.MustGetValueByKey(vars.XAIProxyEmail, metaCfg)
f.Audit.JobNumber = meta.MustGetValueByKey(vars.XAIProxyJobNumber, metaCfg)
f.Audit.Username = meta.MustGetValueByKey(vars.XAIProxyName, metaCfg)
f.Audit.PhoneNumber = meta.MustGetValueByKey(vars.XAIProxyPhone, metaCfg)
if f.Audit.Source == "" { // use token's client's name
_client, ok := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Load(vars.MapKeyClient{})
if ok && _client != nil {
client := _client.(*clientpb.Client)
f.Audit.Source = client.Name
}
}
if f.Audit.Model == "" {
_model, ok := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Load(vars.MapKeyModel{})
if ok && _model != nil {
model := _model.(*modelpb.Model)
f.Audit.Model = model.Name
}
}
return nil
}

func (f *Audit) SetProvider(ctx context.Context) error {
prov, ok := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Load(vars.MapKeyModelProvider{})
if !ok || prov == nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/apps/ai-proxy/filters/audit/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestAudit_SetUserInfo(t *testing.T) {

f, _ := audit.New(nil)
a := f.(*audit.Audit)
if err := a.SetUserInfo(context.Background(), header); err != nil {
if err := a.SetUserInfoFromHeader(context.Background(), header); err != nil {
t.Fatal(err)
}
if a.Audit.Username != m[vars.XAIProxyName] {
Expand Down
49 changes: 36 additions & 13 deletions internal/apps/ai-proxy/filters/context/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@ import (

"github.com/erda-project/erda-infra/base/logs"
clientpb "github.com/erda-project/erda-proto-go/apps/aiproxy/client/pb"
clienttokenpb "github.com/erda-project/erda-proto-go/apps/aiproxy/client_token/pb"
modelpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model/pb"
modelproviderpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model_provider/pb"
promptpb "github.com/erda-project/erda-proto-go/apps/aiproxy/prompt/pb"
sessionpb "github.com/erda-project/erda-proto-go/apps/aiproxy/session/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/client_token"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/metadata"
"github.com/erda-project/erda/internal/apps/ai-proxy/providers/dao"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
"github.com/erda-project/erda/pkg/http/httputil"
"github.com/erda-project/erda/pkg/reverseproxy"
"github.com/erda-project/erda/pkg/strutil"
)
Expand Down Expand Up @@ -62,24 +65,44 @@ func (f *Context) OnRequest(ctx context.Context, w http.ResponseWriter, infor re
)
// find client
var client *clientpb.Client
ak := vars.TrimBearer(infor.Header().Get("Authorization"))
ak := vars.TrimBearer(infor.Header().Get(httputil.HeaderKeyAuthorization))
if ak == "" {
http.Error(w, "Authorization is required", http.StatusUnauthorized)
return reverseproxy.Intercept, nil
}
// try to remove Bearer
ak = strings.TrimPrefix(ak, "Bearer ")
clientPagingResult, err := q.ClientClient().Paging(ctx, &clientpb.ClientPagingRequest{
AccessKeyIds: []string{ak},
PageNum: 1,
PageSize: 1,
})
if err != nil || clientPagingResult.Total < 1 {
l.Errorf("failed to get client, access_key_id: %s, err: %v", ak, err)
http.Error(w, "Authorization is invalid", http.StatusForbidden)
return reverseproxy.Intercept, err
if strings.HasPrefix(ak, client_token.TokenPrefix) {
tokenPagingResp, err := q.ClientTokenClient().Paging(ctx, &clienttokenpb.ClientTokenPagingRequest{
PageSize: 1,
PageNum: 1,
Token: ak,
})
if err != nil || tokenPagingResp.Total < 1 {
l.Errorf("failed to get client token, token: %s, err: %v", ak, err)
http.Error(w, "Authorization is invalid", http.StatusForbidden)
return reverseproxy.Intercept, err
}
token := tokenPagingResp.List[0]
clientResp, err := q.ClientClient().Get(ctx, &clientpb.ClientGetRequest{ClientId: token.ClientId})
if err != nil {
l.Errorf("failed to get client, id: %s, err: %v", tokenPagingResp.List[0].ClientId, err)
http.Error(w, "Authorization is invalid", http.StatusForbidden)
return reverseproxy.Intercept, err
}
client = clientResp
m.Store(vars.MapKeyClientToken{}, token)
} else {
clientPagingResult, err := q.ClientClient().Paging(ctx, &clientpb.ClientPagingRequest{
AccessKeyIds: []string{ak},
PageNum: 1,
PageSize: 1,
})
if err != nil || clientPagingResult.Total < 1 {
l.Errorf("failed to get client, access_key_id: %s, err: %v", ak, err)
http.Error(w, "Authorization is invalid", http.StatusForbidden)
return reverseproxy.Intercept, err
}
client = clientPagingResult.List[0]
}
client = clientPagingResult.List[0]

// find model
var model *modelpb.Model
Expand Down
Empty file.
Loading

0 comments on commit 3b2bb5c

Please sign in to comment.