Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ai-proxy): add client token; add rate-limit filter #6045

Merged
merged 8 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions .erda/ai-proxy/migrations/ai-proxy/20230823-ai-proxy.sql
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,25 @@ CREATE TABLE `ai_proxy_filter_audit`
INDEX `idx_job_number` (`job_number`),
INDEX `idx_dingtalk_staff_id` (`dingtalk_staff_id`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COMMENT 'AI 审计表';

-- Tokens issued to the users of an ai-proxy client.
-- Rows are soft-deleted: `deleted_at` stays at the epoch sentinel while the row
-- is live, so the unique index below allows a single live row per
-- (client_id, user_id) while keeping any number of deleted ones.
CREATE TABLE `ai_proxy_client_token`
(
    `id`         CHAR(36)     NOT NULL COMMENT 'primary key',
    `created_at` DATETIME     NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
    `updated_at` DATETIME     NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
    `deleted_at` DATETIME     NOT NULL DEFAULT '1970-01-01 00:00:00' COMMENT '删除时间, 1970-01-01 00:00:00 表示未删除',

    -- The column comment previously described a "session"; this table stores tokens.
    `client_id`  CHAR(36)     NOT NULL COMMENT 'Token 所属的客户端 id',
    `user_id`    VARCHAR(191) NOT NULL COMMENT '客户端传入的自定义 user_id,客户端用来区分用户',
    `token`      CHAR(34)     NOT NULL COMMENT 't_ 前缀,len: uuid(32)+2',
    -- Every other column carries a COMMENT; this one was missing it.
    `expired_at` DATETIME     NOT NULL DEFAULT '1970-01-01 00:00:00' COMMENT 'Token 过期时间',
    `metadata`   MEDIUMTEXT   NOT NULL COMMENT 'Token 元数据,主要包含 user 额外信息,用于审计',

    PRIMARY KEY (`id`),
    -- Lookup by token value.
    INDEX `idx_token` (`token`),
    UNIQUE INDEX `unique_clientid_userid` (`client_id`, `user_id`, `deleted_at`)
) ENGINE = InnoDB
  DEFAULT CHARSET = utf8mb4
    COMMENT 'AI 客户端 Token 表';
2 changes: 1 addition & 1 deletion .github/workflows/ci-it.yml
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ jobs:
- ./internal/apps/dop/...
- ./internal/apps/admin/... ./internal/apps/devflow/... ./internal/apps/gallery/...
- ./internal/apps/msp/...
- ./internal/apps/cmp/...
# - ./internal/apps/cmp/...
steps:
- name: Clone repo
uses: actions/checkout@v3
Expand Down
94 changes: 94 additions & 0 deletions api/proto/apps/aiproxy/client_token/client_token.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
syntax = "proto3";

package erda.apps.aiproxy;
// The Go import path follows this file's directory name (client_token, with an
// underscore); Go code imports the generated package as
// ".../apps/aiproxy/client_token/pb", so a hyphenated path would not resolve.
option go_package = "github.com/erda-project/erda-proto-go/apps/aiproxy/client_token/pb";

import "google/api/annotations.proto";
import "apps/aiproxy/metadata/metadata.proto";
import "google/protobuf/timestamp.proto";
import "github.com/envoyproxy/protoc-gen-validate/validate/validate.proto";
import "common/http.proto";

// ClientTokenService exposes CRUD and paging operations for client tokens.
// All routes are nested under /api/ai-proxy/clients/{clientId}/tokens.
service ClientTokenService {
  // Create makes a token for the given client. The createOrGet request field
  // is bound from the query string of the route.
  rpc Create(ClientTokenCreateRequest) returns (ClientToken) {
    option (google.api.http) = {
      post: "/api/ai-proxy/clients/{clientId}/tokens?createOrGet={createOrGet}"
    };
  }

  // Get returns the token identified by {clientId} and {token}.
  rpc Get(ClientTokenGetRequest) returns (ClientToken) {
    option (google.api.http) = {
      get: "/api/ai-proxy/clients/{clientId}/tokens/{token}"
    };
  }

  // Delete removes the token identified by {clientId} and {token}.
  rpc Delete(ClientTokenDeleteRequest) returns (common.VoidResponse) {
    option (google.api.http) = {
      delete: "/api/ai-proxy/clients/{clientId}/tokens/{token}"
    };
  }

  // Update modifies the token identified by {clientId} and {token} and
  // returns the resulting token.
  rpc Update(ClientTokenUpdateRequest) returns (ClientToken) {
    option (google.api.http) = {
      put: "/api/ai-proxy/clients/{clientId}/tokens/{token}"
    };
  }

  // Paging lists tokens of a client page by page; see
  // ClientTokenPagingRequest for the available filters.
  rpc Paging(ClientTokenPagingRequest) returns (ClientTokenPagingResponse) {
    option (google.api.http) = {
      get: "/api/ai-proxy/clients/{clientId}/tokens"
    };
  }
}

// ClientToken is the token resource returned by ClientTokenService.
message ClientToken {
  // Record identifier.
  string id = 1;
  // Record timestamps.
  google.protobuf.Timestamp createdAt = 2;
  google.protobuf.Timestamp updatedAt = 3;
  google.protobuf.Timestamp deletedAt = 4;

  // Id of the client the token belongs to; exactly 36 characters.
  string clientId = 5 [(validate.rules).string = {len: 36}];
  // User identifier supplied by the client; 1 to 191 characters.
  string userId = 6 [(validate.rules).string = {min_len: 1, max_len: 191}];
  // Token value; exactly 34 characters.
  string token = 7 [(validate.rules).string = {len: 34}];
  // Expiration time of the token.
  google.protobuf.Timestamp expireAt = 8;
  // Metadata attached to the token.
  Metadata metadata = 9;
}

// ClientTokenCreateRequest is the input of ClientTokenService.Create.
message ClientTokenCreateRequest {
  // Id of the owning client; exactly 36 characters.
  string clientId = 1 [(validate.rules).string = {len: 36}];
  // User identifier supplied by the client; 1 to 191 characters.
  string userId = 2 [(validate.rules).string = {min_len: 1, max_len: 191}];
  // Token lifetime in hours. Validation is skipped when the field is empty;
  // otherwise the value must be within 1..720 (720 hours = at most 30 days).
  uint64 expireInHours = 3 [(validate.rules).uint64 = {
    ignore_empty: true,
    gte: 1,
    lte: 720
  }];
  // Metadata to attach to the token.
  Metadata metadata = 4;

  // When true, get the token if it already exists.
  bool createOrGet = 5;
}

// ClientTokenGetRequest is the input of ClientTokenService.Get.
message ClientTokenGetRequest {
  // Id of the owning client; exactly 36 characters.
  string clientId = 1 [(validate.rules).string = {len: 36}];
  // Token value to look up; exactly 34 characters.
  string token = 2 [(validate.rules).string = {len: 34}];
}

// ClientTokenDeleteRequest is the input of ClientTokenService.Delete.
message ClientTokenDeleteRequest {
  // Id of the owning client; exactly 36 characters.
  string clientId = 1 [(validate.rules).string = {len: 36}];
  // Token value to delete; exactly 34 characters.
  string token = 2 [(validate.rules).string = {len: 34}];
}

// ClientTokenUpdateRequest is the input of ClientTokenService.Update.
message ClientTokenUpdateRequest {
  // Id of the owning client; exactly 36 characters.
  string clientId = 1 [(validate.rules).string = {len: 36}];
  // Token value to update; exactly 34 characters.
  string token = 2 [(validate.rules).string = {len: 34}];
  // Token lifetime in hours, within 0..720 (720 hours = at most 30 days).
  // NOTE(review): what a value of 0 means is decided by the handler, not
  // visible in this schema; confirm.
  uint64 expireInHours = 3 [(validate.rules).uint64 = {
    gte: 0,
    lte: 720
  }];
  // Metadata for the token.
  Metadata metadata = 4;
}

// ClientTokenPagingRequest is the input of ClientTokenService.Paging.
// Apart from clientId, every field skips validation when left empty.
message ClientTokenPagingRequest {
  // Id of the client; exactly 36 characters.
  string clientId = 1 [(validate.rules).string = {len: 36}];
  // Optional user id filter; at most 191 characters.
  string userId = 2 [(validate.rules).string = {
    ignore_empty: true,
    max_len: 191
  }];
  // Optional token filter; at most 34 characters.
  string token = 3 [(validate.rules).string = {
    ignore_empty: true,
    max_len: 34
  }];
  // Page number; at least 1 when set.
  uint64 pageNum = 4 [(validate.rules).uint64 = {
    ignore_empty: true,
    gte: 1
  }];
  // Page size; within 1..1000 when set.
  uint64 pageSize = 5 [(validate.rules).uint64 = {
    ignore_empty: true,
    gte: 1,
    lte: 1000
  }];
}

// ClientTokenPagingResponse is the output of ClientTokenService.Paging.
message ClientTokenPagingResponse {
  // Total count reported for the query.
  int64 total = 1;
  // Tokens on the returned page.
  repeated ClientToken list = 2;
}
2 changes: 0 additions & 2 deletions cmd/ai-proxy/bootstrap.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ grpc-server@ai:

erda.app.ai-proxy:
routesRef: conf/routes.yml
providersRef: conf/providers.yml
platformsRef: conf/erda-platforms.yml
logLevel: ${LOG_LEVEL:debug}
openOnErda: ${OPEN_ON_ERDA:true} # 是否将 API 通过 Erda Openapi 暴露出来

Expand Down
1 change: 1 addition & 0 deletions internal/apps/ai-proxy/dependent_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ import (
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/message-context"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/openai-director"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/prometheus-collector"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/rate-limit"
)
38 changes: 36 additions & 2 deletions internal/apps/ai-proxy/filters/audit/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ import (
"github.com/pkg/errors"

"github.com/erda-project/erda-infra/base/logs"
clientpb "github.com/erda-project/erda-proto-go/apps/aiproxy/client/pb"
clienttokenpb "github.com/erda-project/erda-proto-go/apps/aiproxy/client_token/pb"
modelpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model/pb"
modelproviderpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model_provider/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/models"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/metadata"
"github.com/erda-project/erda/internal/apps/ai-proxy/providers/dao"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
"github.com/erda-project/erda/pkg/http/httputil"
Expand Down Expand Up @@ -86,7 +89,8 @@ func (f *Audit) OnRequest(ctx context.Context, w http.ResponseWriter, infor reve
f.SetXRequestId,
f.SetRequestAt,
f.SetSource,
f.SetUserInfo,
f.SetUserInfoFromHeader,
f.SetUserInfoFromClientToken,
f.SetProvider,
f.SetModel,
f.SetOperationId,
Expand Down Expand Up @@ -225,7 +229,7 @@ func (f *Audit) SetSource(_ context.Context, header http.Header) error {
return nil
}

func (f *Audit) SetUserInfo(ctx context.Context, header http.Header) error {
func (f *Audit) SetUserInfoFromHeader(ctx context.Context, header http.Header) error {
f.Audit.Username = header.Get(vars.XAIProxyName)
if f.Audit.Username == "" {
f.Audit.Username = header.Get(vars.XAIProxyUsername)
Expand Down Expand Up @@ -253,6 +257,36 @@ func (f *Audit) SetUserInfo(ctx context.Context, header http.Header) error {
return nil
}

// SetUserInfoFromClientToken fills the audit record from the client token
// stored in the request-scoped map under vars.MapKeyClientToken.
//
// When the context carries no map, or the map holds no usable client token,
// the audit record is left untouched and nil is returned. Otherwise the user
// fields (DingTalk staff id, email, job number, username, phone number) are
// assigned unconditionally from the token metadata, replacing any values
// already present on f.Audit. Source and Model are only filled when still
// empty, from the client and model stored in the same map.
//
// The function always returns nil.
func (f *Audit) SetUserInfoFromClientToken(ctx context.Context) error {
	// Use comma-ok assertions throughout: a context without the map, or a map
	// entry of an unexpected type, must not panic on the request path.
	m, ok := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map)
	if !ok || m == nil {
		return nil
	}
	rawToken, ok := m.Load(vars.MapKeyClientToken{})
	if !ok {
		return nil
	}
	clientToken, ok := rawToken.(*clienttokenpb.ClientToken)
	if !ok || clientToken == nil {
		return nil
	}

	meta := metadata.FromProtobuf(clientToken.Metadata)
	// Metadata keys are looked up case-insensitively.
	metaCfg := metadata.Config{IgnoreCase: true}
	f.Audit.DingtalkStaffID = meta.MustGetValueByKey(vars.XAIProxyDingTalkStaffID, metaCfg)
	f.Audit.Email = meta.MustGetValueByKey(vars.XAIProxyEmail, metaCfg)
	f.Audit.JobNumber = meta.MustGetValueByKey(vars.XAIProxyJobNumber, metaCfg)
	f.Audit.Username = meta.MustGetValueByKey(vars.XAIProxyName, metaCfg)
	f.Audit.PhoneNumber = meta.MustGetValueByKey(vars.XAIProxyPhone, metaCfg)

	if f.Audit.Source == "" { // use token's client's name
		if rawClient, ok := m.Load(vars.MapKeyClient{}); ok {
			if client, ok := rawClient.(*clientpb.Client); ok && client != nil {
				f.Audit.Source = client.Name
			}
		}
	}
	if f.Audit.Model == "" {
		if rawModel, ok := m.Load(vars.MapKeyModel{}); ok {
			if model, ok := rawModel.(*modelpb.Model); ok && model != nil {
				f.Audit.Model = model.Name
			}
		}
	}
	return nil
}

func (f *Audit) SetProvider(ctx context.Context) error {
prov, ok := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Load(vars.MapKeyModelProvider{})
if !ok || prov == nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/apps/ai-proxy/filters/audit/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestAudit_SetUserInfo(t *testing.T) {

f, _ := audit.New(nil)
a := f.(*audit.Audit)
if err := a.SetUserInfo(context.Background(), header); err != nil {
if err := a.SetUserInfoFromHeader(context.Background(), header); err != nil {
t.Fatal(err)
}
if a.Audit.Username != m[vars.XAIProxyName] {
Expand Down
49 changes: 36 additions & 13 deletions internal/apps/ai-proxy/filters/context/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@ import (

"github.com/erda-project/erda-infra/base/logs"
clientpb "github.com/erda-project/erda-proto-go/apps/aiproxy/client/pb"
clienttokenpb "github.com/erda-project/erda-proto-go/apps/aiproxy/client_token/pb"
modelpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model/pb"
modelproviderpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model_provider/pb"
promptpb "github.com/erda-project/erda-proto-go/apps/aiproxy/prompt/pb"
sessionpb "github.com/erda-project/erda-proto-go/apps/aiproxy/session/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/client_token"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/metadata"
"github.com/erda-project/erda/internal/apps/ai-proxy/providers/dao"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
"github.com/erda-project/erda/pkg/http/httputil"
"github.com/erda-project/erda/pkg/reverseproxy"
"github.com/erda-project/erda/pkg/strutil"
)
Expand Down Expand Up @@ -62,24 +65,44 @@ func (f *Context) OnRequest(ctx context.Context, w http.ResponseWriter, infor re
)
// find client
var client *clientpb.Client
ak := vars.TrimBearer(infor.Header().Get("Authorization"))
ak := vars.TrimBearer(infor.Header().Get(httputil.HeaderKeyAuthorization))
if ak == "" {
http.Error(w, "Authorization is required", http.StatusUnauthorized)
return reverseproxy.Intercept, nil
}
// try to remove Bearer
ak = strings.TrimPrefix(ak, "Bearer ")
clientPagingResult, err := q.ClientClient().Paging(ctx, &clientpb.ClientPagingRequest{
AccessKeyIds: []string{ak},
PageNum: 1,
PageSize: 1,
})
if err != nil || clientPagingResult.Total < 1 {
l.Errorf("failed to get client, access_key_id: %s, err: %v", ak, err)
http.Error(w, "Authorization is invalid", http.StatusForbidden)
return reverseproxy.Intercept, err
if strings.HasPrefix(ak, client_token.TokenPrefix) {
tokenPagingResp, err := q.ClientTokenClient().Paging(ctx, &clienttokenpb.ClientTokenPagingRequest{
PageSize: 1,
PageNum: 1,
Token: ak,
})
if err != nil || tokenPagingResp.Total < 1 {
l.Errorf("failed to get client token, token: %s, err: %v", ak, err)
http.Error(w, "Authorization is invalid", http.StatusForbidden)
return reverseproxy.Intercept, err
}
token := tokenPagingResp.List[0]
clientResp, err := q.ClientClient().Get(ctx, &clientpb.ClientGetRequest{ClientId: token.ClientId})
if err != nil {
l.Errorf("failed to get client, id: %s, err: %v", tokenPagingResp.List[0].ClientId, err)
http.Error(w, "Authorization is invalid", http.StatusForbidden)
return reverseproxy.Intercept, err
}
client = clientResp
m.Store(vars.MapKeyClientToken{}, token)
} else {
clientPagingResult, err := q.ClientClient().Paging(ctx, &clientpb.ClientPagingRequest{
AccessKeyIds: []string{ak},
PageNum: 1,
PageSize: 1,
})
if err != nil || clientPagingResult.Total < 1 {
l.Errorf("failed to get client, access_key_id: %s, err: %v", ak, err)
http.Error(w, "Authorization is invalid", http.StatusForbidden)
return reverseproxy.Intercept, err
}
client = clientPagingResult.List[0]
}
client = clientPagingResult.List[0]

// find model
var model *modelpb.Model
Expand Down
Empty file.
Loading