Skip to content

Commit

Permalink
feat(ai-proxy): support STT (Speach to text) (#6185)
Browse files Browse the repository at this point in the history
* add audio transcription API in routes.yml

* support proxy audio API

* check&modify multipart body

* polish and abstract gen curl logic for multipart-form

* polish code

* fix goimports
  • Loading branch information
sfwn authored Dec 21, 2023
1 parent 142504b commit 4be7de9
Show file tree
Hide file tree
Showing 17 changed files with 567 additions and 46 deletions.
5 changes: 5 additions & 0 deletions api/proto/apps/aiproxy/audit/audit.proto
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ message AuditUpdateRequestAfterContextParsed {
string identityJobNumber = 12;
string username = 13;
string identityPhoneNumber = 14;

// audio info parsed from multipart/form-data request body
string audioFileName = 15;
string audioFileSize = 16;
string audioFileHeaders = 17;
}

message AuditUpdateRequestAfterLLMDirectorInvoke {
Expand Down
34 changes: 33 additions & 1 deletion cmd/ai-proxy/conf/routes.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ routes:
- name: context
- name: context-chat
config:
#sysMsg: "background: Your name is Erda AI Assistant. You are trained by Erda."
#sysMsg: "background: Your name is Erda AI Assistant. You are trained by Erda."
- name: audit-before-llm-director
- name: azure-director
config:
Expand Down Expand Up @@ -67,3 +67,35 @@ routes:
- AddModelInRequestBody
- name: audit-after-llm-director
- name: finalize
- path: /v1/audio/transcriptions
method: POST
router: null
filters:
- name: initialize
- name: log-http
- name: rate-limit
- name: context
- name: context-audio
config:
maxAudioSize: 25MB
supportedAudioFileTypes: [ "flac", "mp3", "mp4", "mpeg", "mpga", "m4a", "ogg", "wav", "webm" ]
defaultOpenAIAudioModel: whisper-1
- name: audit-before-llm-director
- name: azure-director
config:
directors:
- TransAuthorization
- SetModelAPIVersionIfNotSpecified
- DefaultQueries("api-version=2023-09-01-preview")
- RewriteScheme
- RewriteHost
- RewritePath("/openai/deployments/${ provider.metadata.deployment_id }/audio/transcriptions")
- name: openai-director
config:
directors:
- TransAuthorization
- RewriteScheme
- RewriteHost
- AddModelInRequestBody
- name: audit-after-llm-director
- name: finalize
49 changes: 49 additions & 0 deletions internal/apps/ai-proxy/common/ctxhelper/audio.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) 2021 Terminus, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ctxhelper

import (
"context"
"net/textproto"
"sync"

"github.com/pyroscope-io/pyroscope/pkg/util/bytesize"

"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
"github.com/erda-project/erda/pkg/reverseproxy"
)

type AudioInfo struct {
FileName string `json:"fileName"`
FileSize bytesize.ByteSize `json:"fileSize"`
FileHeaders textproto.MIMEHeader `json:"fileHeaders"`
}

func GetAudioInfo(ctx context.Context) (*AudioInfo, bool) {
value, ok := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map).Load(vars.MapKeyAudioInfo{})
if !ok || value == nil {
return nil, false
}
info, ok := value.(*AudioInfo)
if !ok {
return nil, false
}
return info, true
}

func PutAudioInfo(ctx context.Context, info AudioInfo) {
m := ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map)
m.Store(vars.MapKeyAudioInfo{}, &info)
}
1 change: 1 addition & 0 deletions internal/apps/ai-proxy/dependent_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/bailian-director"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/body-size-limit"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/context"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/context-audio"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/context-chat"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/context-embedding"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/erda-auth"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ func (f *Filter) OnActualRequest(ctx context.Context, infor reverseproxy.HttpInf
dbClient := dao.AuditClient()
_, err := dbClient.UpdateAfterLLMDirectorInvoke(ctx, &updateReq)
if err != nil {
// log it
l := ctxhelper.GetLogger(ctx)
l.Errorf("failed to update audit after llm director invoke, audit id: %s, err: %v", auditRecID, err)
}
return
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ func (f *Filter) OnResponseEOFImmutable(ctx context.Context, infor reverseproxy.
dbClient := dao.AuditClient()
_, err = dbClient.UpdateAfterLLMResponse(ctx, &updateReq)
if err != nil {
// log it
l := ctxhelper.GetLogger(ctx)
l.Errorf("failed to update audit after llm response, audit id: %s, err: %v", auditRecID, err)
}
return nil
}
74 changes: 50 additions & 24 deletions internal/apps/ai-proxy/filters/audit-before-llm-director/req.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/erda-project/erda-infra/providers/component-protocol/utils/cputil"
"github.com/erda-project/erda-proto-go/apps/aiproxy/audit/pb"
modelpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/common/ctxhelper"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/metadata"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
Expand Down Expand Up @@ -69,9 +70,12 @@ func (f *Filter) OnOriginalRequest(ctx context.Context, infor reverseproxy.HttpI
// insert audit into db
newAudit, err := ctxhelper.MustGetDBClient(ctx).AuditClient().CreateWhenReceived(ctx, &createReq)
if err != nil {
// log it
l := ctxhelper.GetLogger(ctx)
l.Errorf("failed to create audit: %v", err)
}
if newAudit != nil {
ctxhelper.PutAuditID(ctx, newAudit.Id)
}
ctxhelper.PutAuditID(ctx, newAudit.Id)
return
}

Expand All @@ -85,11 +89,6 @@ func (f *Filter) OnRequestBeforeLLMDirector(ctx context.Context, w http.Response
return
}

var openaiReq openai.ChatCompletionRequest
if err := json.NewDecoder(infor.BodyBuffer()).Decode(&openaiReq); err != nil {
return reverseproxy.Continue, err
}

var updateReq pb.AuditUpdateRequestAfterContextParsed
updateReq.AuditId = auditRecID
// prompt
Expand All @@ -113,24 +112,49 @@ func (f *Filter) OnRequestBeforeLLMDirector(ctx context.Context, w http.Response
// operation id
updateReq.OperationId = infor.Method() + " " + infor.URL().Path

// metadata
updateReq.RequestFunctionCallName = func() string {
// switch type
switch openaiReq.FunctionCall.(type) {
case string:
return openaiReq.FunctionCall.(string)
case map[string]interface{}:
var reqFuncCall openai.FunctionCall
cputil.MustObjJSONTransfer(openaiReq.FunctionCall, &reqFuncCall)
return reqFuncCall.Name
case openai.FunctionCall:
return openaiReq.FunctionCall.(openai.FunctionCall).Name
case nil:
return ""
default:
return fmt.Sprintf("%v", openaiReq.FunctionCall)
// metadata, routing by model type
switch model.Type {
case modelpb.ModelType_text_generation:
var openaiReq openai.ChatCompletionRequest
if err := json.NewDecoder(infor.BodyBuffer()).Decode(&openaiReq); err != nil {
goto Next
}
}()
updateReq.RequestFunctionCallName = func() string {
// switch type
switch openaiReq.FunctionCall.(type) {
case string:
return openaiReq.FunctionCall.(string)
case map[string]interface{}:
var reqFuncCall openai.FunctionCall
cputil.MustObjJSONTransfer(openaiReq.FunctionCall, &reqFuncCall)
return reqFuncCall.Name
case openai.FunctionCall:
return openaiReq.FunctionCall.(openai.FunctionCall).Name
case nil:
return ""
default:
return fmt.Sprintf("%v", openaiReq.FunctionCall)
}
}()
case modelpb.ModelType_audio:
audioInfo, ok := ctxhelper.GetAudioInfo(ctx)
if !ok {
goto Next
}
updateReq.AudioFileName = audioInfo.FileName
updateReq.AudioFileSize = audioInfo.FileSize.String()
updateReq.AudioFileHeaders = func() string {
b, err := json.Marshal(audioInfo.FileHeaders)
if err != nil {
return err.Error()
}
return string(b)
}()
default:
// do nothing
}

Next:

// set from client token
setUserInfoFromClientToken(ctx, infor, &updateReq)
Expand All @@ -139,6 +163,8 @@ func (f *Filter) OnRequestBeforeLLMDirector(ctx context.Context, w http.Response
_, err = ctxhelper.MustGetDBClient(ctx).AuditClient().UpdateAfterContextParsed(ctx, &updateReq)
if err != nil {
// log it
l := ctxhelper.GetLogger(ctx)
l.Errorf("failed to update audit: %v", err)
}
return reverseproxy.Continue, nil
}
Expand Down
27 changes: 22 additions & 5 deletions internal/apps/ai-proxy/filters/audit-before-llm-director/resp.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/sashabaranov/go-openai"

"github.com/erda-project/erda-proto-go/apps/aiproxy/audit/pb"
modelpb "github.com/erda-project/erda-proto-go/apps/aiproxy/model/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/common/ctxhelper"
"github.com/erda-project/erda/pkg/reverseproxy"
)
Expand All @@ -38,11 +39,26 @@ func (f *Filter) OnResponseEOFImmutable(ctx context.Context, infor reverseproxy.
}
respBuffer := ctxhelper.GetLLMDirectorActualResponseBuffer(ctx)
var completion, responseFunctionCallName string
if ctxhelper.GetIsStream(ctx) {
completion, responseFunctionCallName = ExtractEventStreamCompletionAndFcName(respBuffer.String())
} else {
completion, responseFunctionCallName = ExtractApplicationJsonCompletionAndFcName(respBuffer.String())

// routing by model type
model, _ := ctxhelper.GetModel(ctx)
switch model.Type {
case modelpb.ModelType_text_generation:
if ctxhelper.GetIsStream(ctx) {
completion, responseFunctionCallName = ExtractEventStreamCompletionAndFcName(respBuffer.String())
} else {
completion, responseFunctionCallName = ExtractApplicationJsonCompletionAndFcName(respBuffer.String())
}
case modelpb.ModelType_audio:
var openaiAudioResp openai.AudioResponse
respBufferStr := respBuffer.String()
if err := json.NewDecoder(respBuffer).Decode(&openaiAudioResp); err == nil {
completion = openaiAudioResp.Text
} else {
completion = respBufferStr
}
}

// collect actual llm response info
updateReq := pb.AuditUpdateRequestAfterLLMDirectorResponse{
AuditId: auditRecID,
Expand All @@ -63,7 +79,8 @@ func (f *Filter) OnResponseEOFImmutable(ctx context.Context, infor reverseproxy.
dbClient := dao.AuditClient()
_, err = dbClient.UpdateAfterLLMDirectorResponse(ctx, &updateReq)
if err != nil {
// log it
l := ctxhelper.GetLogger(ctx)
l.Errorf("failed to update audit on response EOF, audit id: %s, err: %v", auditRecID, err)
}
return nil
}
Expand Down
Loading

0 comments on commit 4be7de9

Please sign in to comment.