Skip to content

Commit

Permalink
feat: refactor proxy handling and update Vertex AI model mappings and…
Browse files Browse the repository at this point in the history
… added streaming support.
  • Loading branch information
Gyarbij committed Dec 20, 2024
1 parent 1bde53e commit 6eb5fb3
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 93 deletions.
18 changes: 7 additions & 11 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,31 +331,27 @@ func handleProxy(c *gin.Context) {
return
}

var server http.Handler

// Choose the proxy based on ProxyMode or specific environment variables
switch ProxyMode {
case "azure":
server = azure.NewOpenAIReverseProxy()
server := azure.NewOpenAIReverseProxy()
server.ServeHTTP(c.Writer, c.Request)
case "google":
google.HandleGoogleAIProxy(c)
return // Add this return statement
case "vertex":
server = vertex.NewVertexAIReverseProxy()
vertex.HandleVertexAIProxy(c) // Call HandleVertexAIProxy directly
default:
// Default to Azure if not specified, but only if the endpoint is set
if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" {
server = azure.NewOpenAIReverseProxy()
server := azure.NewOpenAIReverseProxy()
server.ServeHTTP(c.Writer, c.Request)
} else {
// If no endpoint is configured, default to OpenAI
server = openai.NewOpenAIReverseProxy()
server := openai.NewOpenAIReverseProxy()
server.ServeHTTP(c.Writer, c.Request)
}
}

if ProxyMode != "google" {
server.ServeHTTP(c.Writer, c.Request)
}

if c.Writer.Header().Get("Content-Type") == "text/event-stream" {
if _, err := c.Writer.Write([]byte("\n")); err != nil {
log.Printf("rewrite response error: %v", err)
Expand Down
244 changes: 162 additions & 82 deletions pkg/vertex/proxy.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package vertex

import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"os/exec"
"os"
"strings"

"github.com/gin-gonic/gin"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)

var (
Expand All @@ -18,10 +22,10 @@ var (
VertexAIAPIVersion = "v1"
VertexAILocation = "us-central1"
VertexAIModelMapper = map[string]string{
"chat-bison": "chat-bison@001",
"text-bison": "text-bison@001",
"embedding-gecko": "textembedding-gecko@001",
"embedding-gecko-multilingual": "textembedding-gecko-multilingual@001",
"chat-bison": "chat-bison@002",
"text-bison": "text-bison@002",
"embedding-gecko": "textembedding-gecko@003",
"embedding-gecko-multilingual": "textembedding-gecko-multilingual@003",
}
)

Expand All @@ -38,89 +42,166 @@ func Init(projectID string) {
log.Printf("Vertex AI initialized with Project ID: %s", projectID)
}

func NewVertexAIReverseProxy() *httputil.ReverseProxy {
config := &VertexAIConfig{
ProjectID: VertexAIProjectID,
Endpoint: VertexAIEndpoint,
APIVersion: VertexAIAPIVersion,
Location: VertexAILocation,
ModelMapper: VertexAIModelMapper,
func HandleVertexAIProxy(c *gin.Context) {
if VertexAIProjectID == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Vertex AI Project ID not set"})
return
}

return newVertexAIReverseProxy(config)
}
ctx := context.Background()

func newVertexAIReverseProxy(config *VertexAIConfig) *httputil.ReverseProxy {
director := func(req *http.Request) {
originalURL := req.URL.String()
model := getModelFromRequest(req)
// Use the GOOGLE_APPLICATION_CREDENTIALS environment variable to set the credentials
creds := option.WithCredentialsFile(os.Getenv("GOOGLE_APPLICATION_CREDENTIALS"))
client, err := genai.NewClient(ctx, creds)
if err != nil {
log.Printf("Error creating Vertex AI client: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create Vertex AI client"})
return
}
defer client.Close()

// Map the model name if necessary
if mappedModel, ok := config.ModelMapper[strings.ToLower(model)]; ok {
model = mappedModel
}
modelName := getModelFromRequestBody(c.Request)
if mappedModel, ok := VertexAIModelMapper[strings.ToLower(modelName)]; ok {
modelName = mappedModel
}

// Construct the new URL
targetURL := fmt.Sprintf("https://%s/%s/projects/%s/locations/%s/publishers/google/models/%s:predict", config.Endpoint, config.APIVersion, config.ProjectID, config.Location, model)
target, err := url.Parse(targetURL)
if err != nil {
log.Printf("Error parsing target URL: %v", err)
return
model := client.GenerativeModel(modelName)

// Handle chat/completions
if strings.HasSuffix(c.Request.URL.Path, "/chat/completions") {
handleChatCompletion(c, model)
} else {
c.JSON(http.StatusNotFound, gin.H{"error": "Invalid endpoint for Vertex AI"})
}
}

// getModelFromRequestBody returns the value of the top-level "model" string
// field of the JSON request body, or "" when it cannot be determined.
//
// The body is read in full and then replaced with an in-memory copy of the
// bytes that were read, so later consumers of req.Body can read it again.
// An empty string is returned when req or req.Body is nil, when reading the
// body fails, when the body is not a JSON object, or when "model" is missing
// or is not a string.
func getModelFromRequestBody(req *http.Request) string {
	// io.ReadAll would panic on a nil reader; there is nothing to inspect.
	if req == nil || req.Body == nil {
		return ""
	}

	body, err := io.ReadAll(req.Body)
	// Restore whatever was read, even on failure, so the body is never left
	// drained for the handlers that run after this function.
	req.Body = io.NopCloser(strings.NewReader(string(body)))
	if err != nil {
		log.Printf("Error reading request body: %v", err)
		return ""
	}

	var data map[string]interface{}
	if err := json.Unmarshal(body, &data); err != nil {
		return ""
	}

	// A failed assertion (missing key or non-string value) yields "".
	model, _ := data["model"].(string)
	return model
}

func handleChatCompletion(c *gin.Context, model *genai.GenerativeModel) {
var req struct {
Messages []struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"messages"`
Stream *bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
}

if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}

cs := model.StartChat()
cs.History = []*genai.Content{}

// Set the target
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path = target.Path
for _, msg := range req.Messages {
cs.History = append(cs.History, &genai.Content{
Parts: []genai.Part{
genai.Text(msg.Content),
},
Role: msg.Role,
})
}

// Set Authorization header using Google Application Default Credentials (ADC)
token, err := getAccessToken()
// Set advanced parameters if provided
if req.Temperature != nil {
model.SetTemperature(float32(*req.Temperature))
}
if req.TopP != nil {
model.SetTopP(float32(*req.TopP))
}
if req.TopK != nil {
model.SetTopK(int32(*req.TopK))
}

// Handle streaming if requested
if req.Stream != nil && *req.Stream {
iter := cs.SendMessageStream(context.Background(), genai.Text(req.Messages[len(req.Messages)-1].Content))
c.Stream(func(w io.Writer) bool {
resp, err := iter.Next()
if err == iterator.Done {
return false
}
if err != nil {
log.Printf("Error generating content: %v", err)
c.SSEvent("error", "Failed to generate content")
return false
}

// Convert each response to OpenAI format and send as SSE
openaiResp := convertToOpenAIResponseStream(resp)
c.SSEvent("message", openaiResp)
return true
})
} else {
// Use SendMessage for a single response
resp, err := cs.SendMessage(context.Background(), genai.Text(req.Messages[len(req.Messages)-1].Content))
if err != nil {
log.Printf("Error getting access token: %v", err)
log.Printf("Error generating content: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate content"})
return
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

log.Printf("proxying request %s -> %s", originalURL, req.URL.String())
// Convert the response to OpenAI format
openaiResp := convertToOpenAIResponse(resp)
c.JSON(http.StatusOK, openaiResp)
}

return &httputil.ReverseProxy{Director: director}
}

func getModelFromRequest(req *http.Request) string {
// Check the URL path for the model
parts := strings.Split(req.URL.Path, "/")
for i, part := range parts {
if part == "models" && i+1 < len(parts) {
return parts[i+1]
// Helper function to convert a single response to OpenAI format (for streaming)
func convertToOpenAIResponseStream(resp *genai.GenerateContentResponse) map[string]interface{} {
var parts []string
for _, candidate := range resp.Candidates {
for _, part := range candidate.Content.Parts {
parts = append(parts, fmt.Sprintf("%v", part))
}
}

// If not found in the path, try to get it from the request body
if req.Body != nil {
body, _ := io.ReadAll(req.Body)
req.Body = io.NopCloser(strings.NewReader(string(body))) // Restore the body
var data map[string]interface{}
if err := json.Unmarshal(body, &data); err == nil {
if model, ok := data["model"].(string); ok {
return model
}
}
return map[string]interface{}{
"object": "chat.completion.chunk",
"choices": []map[string]interface{}{
{
"index": 0,
"delta": map[string]interface{}{
"role": "assistant",
"content": strings.Join(parts, ""),
},
"finish_reason": "stop",
},
},
}

return ""
}

func getAccessToken() (string, error) {
// Use Application Default Credentials (ADC) to get an access token
// Ensure that your environment is set up with ADC, e.g., by running:
// gcloud auth application-default login
// Or by setting the GOOGLE_APPLICATION_CREDENTIALS environment variable
output, err := exec.Command("gcloud", "auth", "print-access-token").Output()
if err != nil {
return "", fmt.Errorf("failed to get access token: %v", err)
// Helper function to convert a single response to OpenAI format (for non-streaming)
func convertToOpenAIResponse(resp *genai.GenerateContentResponse) map[string]interface{} {
var choices []map[string]interface{}
for _, candidate := range resp.Candidates {
choices = append(choices, map[string]interface{}{
"index": candidate.Index,
"message": map[string]interface{}{
"role": "model",
"content": fmt.Sprintf("%v", candidate.Content.Parts),
},
})
}

return map[string]interface{}{
"object": "chat.completion",
"choices": choices,
}
return strings.TrimSpace(string(output)), nil
}

type Model struct {
Expand Down Expand Up @@ -156,24 +237,25 @@ func FetchVertexAIModels() ([]Model, error) {
return nil, fmt.Errorf("Vertex AI Project ID not set")
}

token, err := getAccessToken()
ctx := context.Background()
creds := option.WithCredentialsFile(os.Getenv("GOOGLE_APPLICATION_CREDENTIALS"))
client, err := genai.NewClient(ctx, creds)
if err != nil {
return nil, fmt.Errorf("failed to get access token: %v", err)
return nil, fmt.Errorf("failed to create Vertex AI client: %v", err)
}
defer client.Close()

url := fmt.Sprintf("https://%s/%s/projects/%s/locations/%s/publishers/google/models", VertexAIEndpoint, VertexAIAPIVersion, VertexAIProjectID, VertexAILocation)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}

req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

client := &http.Client{}
resp, err := client.Do(req)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}

defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
Expand All @@ -186,7 +268,6 @@ func FetchVertexAIModels() ([]Model, error) {
Name string `json:"name"`
DisplayName string `json:"displayName"`
Description string `json:"description"`
// Add other relevant fields if needed
} `json:"models"`
}

Expand All @@ -196,22 +277,21 @@ func FetchVertexAIModels() ([]Model, error) {

var models []Model
for _, m := range vertexModels.Models {
// Extract model ID from the name field (e.g., "publishers/google/models/chat-bison")
parts := strings.Split(m.Name, "/")
modelID := parts[len(parts)-1]

models = append(models, Model{
ID: modelID,
Object: "model",
ID: modelID,
Object: "model",
Name: m.Name,
Description: m.Description,
LifecycleStatus: "active", // You might need to adjust this based on actual Vertex AI model data
Status: "ready", // You might need to adjust this based on actual Vertex AI model data
Capabilities: Capabilities{
Completion: true,
ChatCompletion: strings.Contains(modelID, "chat"),
Embeddings: strings.Contains(modelID, "embedding"),
},
LifecycleStatus: "active",
Status: "ready",
Name: m.Name,
Description: m.Description,
})
}

Expand Down

0 comments on commit 6eb5fb3

Please sign in to comment.