From e54ea71e7daa5aac8e688589ca631cb72cb98b47 Mon Sep 17 00:00:00 2001 From: "C.C." Date: Fri, 16 Aug 2024 11:32:13 +0800 Subject: [PATCH] feat(ai-bridge): support Github Models (#885) --- cli/serve.go | 7 +- pkg/bridge/ai/ai.go | 1 - pkg/bridge/ai/api_server.go | 2 +- .../ai/provider/githubmodels/provider.go | 66 ++++++++++++ .../ai/provider/githubmodels/provider_test.go | 100 ++++++++++++++++++ zipper_notwindows.go | 2 +- 6 files changed, 173 insertions(+), 5 deletions(-) create mode 100644 pkg/bridge/ai/provider/githubmodels/provider.go create mode 100644 pkg/bridge/ai/provider/githubmodels/provider_test.go diff --git a/cli/serve.go b/cli/serve.go index bd4751055..55402da91 100644 --- a/cli/serve.go +++ b/cli/serve.go @@ -33,6 +33,7 @@ import ( "github.com/yomorun/yomo/pkg/bridge/ai/provider/cfazure" "github.com/yomorun/yomo/pkg/bridge/ai/provider/cfopenai" "github.com/yomorun/yomo/pkg/bridge/ai/provider/gemini" + "github.com/yomorun/yomo/pkg/bridge/ai/provider/githubmodels" "github.com/yomorun/yomo/pkg/bridge/ai/provider/ollama" "github.com/yomorun/yomo/pkg/bridge/ai/provider/openai" ) @@ -49,7 +50,7 @@ var serveCmd = &cobra.Command{ } // log.InfoStatusEvent(os.Stdout, "") - ylog.Info("Starting Zipper...") + ylog.Info("Starting YoMo Zipper...") // config conf, err := pkgconfig.ParseConfigFile(config) if err != nil { @@ -150,12 +151,14 @@ func registerAIProvider(aiConfig *ai.Config) error { providerpkg.RegisterProvider(ollama.NewProvider(provider["api_endpoint"])) case "gemini": providerpkg.RegisterProvider(gemini.NewProvider(provider["api_key"])) + case "githubmodels": + providerpkg.RegisterProvider(githubmodels.NewProvider(provider["api_key"], provider["model"])) default: log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name) } } - ylog.Info("registered AI providers", "len", len(providerpkg.ListProviders())) + ylog.Info("register LLM providers", "num", len(providerpkg.ListProviders())) return nil } diff --git a/pkg/bridge/ai/ai.go 
b/pkg/bridge/ai/ai.go index d75392d40..eb38095c5 100644 --- a/pkg/bridge/ai/ai.go +++ b/pkg/bridge/ai/ai.go @@ -144,7 +144,6 @@ func ParseConfig(conf map[string]any) (config *Config, err error) { if config.Server.Addr == "" { config.Server.Addr = ":8000" } - ylog.Info("parse AI config success") return } diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/ai/api_server.go index e42817df7..696734feb 100644 --- a/pkg/bridge/ai/api_server.go +++ b/pkg/bridge/ai/api_server.go @@ -49,7 +49,7 @@ func Serve(config *Config, zipperListenAddr string, credential string, logger *s return err } - logger.Info("start bridge server", "addr", config.Server.Addr, "provider", provider.Name()) + logger.Info("start AI Bridge service", "addr", config.Server.Addr, "provider", provider.Name()) return srv.ServeAddr(config.Server.Addr) } diff --git a/pkg/bridge/ai/provider/githubmodels/provider.go b/pkg/bridge/ai/provider/githubmodels/provider.go new file mode 100644 index 000000000..143818d44 --- /dev/null +++ b/pkg/bridge/ai/provider/githubmodels/provider.go @@ -0,0 +1,66 @@ +// Package githubmodels is the Github Models llm provider, see https://github.com/marketplace/models +package githubmodels + +import ( + "context" + "os" + + // automatically load .env file + _ "github.com/joho/godotenv/autoload" + "github.com/sashabaranov/go-openai" + "github.com/yomorun/yomo/core/metadata" + + provider "github.com/yomorun/yomo/pkg/bridge/ai/provider" +) + +// Provider is the provider for Github Models +type Provider struct { + // APIKey is the API key for Github Models + APIKey string + // Model is the model for Github Models, see https://github.com/marketplace/models + // e.g. 
"Meta-Llama-3.1-405B-Instruct", "Mistral-large-2407", "gpt-4o"
+	Model string
+	client *openai.Client
+}
+
+// check if implements provider.LLMProvider
+var _ provider.LLMProvider = &Provider{}
+
+// NewProvider creates a new Provider for Github Models
+func NewProvider(apiKey string, model string) *Provider {
+	if apiKey == "" {
+		apiKey = os.Getenv("GITHUB_TOKEN")
+	}
+
+	config := openai.DefaultConfig(apiKey)
+	config.BaseURL = "https://models.inference.ai.azure.com"
+
+	return &Provider{
+		APIKey: apiKey,
+		Model:  model,
+		client: openai.NewClientWithConfig(config),
+	}
+}
+
+// Name returns the name of the provider
+func (p *Provider) Name() string {
+	return "githubmodels"
+}
+
+// GetChatCompletions implements ai.LLMProvider.
+func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (openai.ChatCompletionResponse, error) {
+	if p.Model != "" {
+		req.Model = p.Model
+	}
+
+	return p.client.CreateChatCompletion(ctx, req)
+}
+
+// GetChatCompletionsStream implements ai.LLMProvider. 
+func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) { + if p.Model != "" { + req.Model = p.Model + } + + return p.client.CreateChatCompletionStream(ctx, req) +} diff --git a/pkg/bridge/ai/provider/githubmodels/provider_test.go b/pkg/bridge/ai/provider/githubmodels/provider_test.go new file mode 100644 index 000000000..c8dce6b32 --- /dev/null +++ b/pkg/bridge/ai/provider/githubmodels/provider_test.go @@ -0,0 +1,100 @@ +package githubmodels + +import ( + "context" + "os" + "testing" + + openai "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" +) + +func TestGithubModelsProvider_Name(t *testing.T) { + provider := &Provider{} + + name := provider.Name() + + assert.Equal(t, "githubmodels", name) +} + +func TestNewProvider(t *testing.T) { + t.Run("with parameters", func(t *testing.T) { + provider := NewProvider("test_api_key", "test_model") + + assert.Equal(t, "test_api_key", provider.APIKey) + assert.Equal(t, "test_model", provider.Model) + }) + + t.Run("with environment variables", func(t *testing.T) { + os.Setenv("GITHUB_TOKEN", "env_api_key") + + provider := NewProvider("", "test_model") + + assert.Equal(t, "env_api_key", provider.APIKey) + assert.Equal(t, "test_model", provider.Model) + + os.Unsetenv("GITHUB_TOKEN") + }) +} + +func TestGithubModelsProvider_GetChatCompletions(t *testing.T) { + config := openai.DefaultConfig("test_api_key") + config.BaseURL = "https://models.inference.ai.azure.com" + client := openai.NewClientWithConfig(config) + + provider := &Provider{ + APIKey: "test_api_key", + Model: "test_model", + client: client, + } + + msgs := []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "hello", + }, + { + Role: "system", + Content: "I'm a bot", + }, + } + req := openai.ChatCompletionRequest{ + Messages: msgs, + } + + _, err := provider.GetChatCompletions(context.TODO(), req, nil) + assert.Error(t, err) + 
assert.Contains(t, err.Error(), "401") + assert.Contains(t, err.Error(), "Bad credentials") +} + +func TestGithubModelsProvider_GetChatCompletionsStream(t *testing.T) { + config := openai.DefaultConfig("test_api_key") + config.BaseURL = "https://models.inference.ai.azure.com" + client := openai.NewClientWithConfig(config) + + provider := &Provider{ + APIKey: "test_api_key", + Model: "test_model", + client: client, + } + + msgs := []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "hello", + }, + { + Role: "system", + Content: "I'm a bot", + }, + } + req := openai.ChatCompletionRequest{ + Messages: msgs, + } + + _, err := provider.GetChatCompletionsStream(context.TODO(), req, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "401") + assert.Contains(t, err.Error(), "Bad credentials") +} diff --git a/zipper_notwindows.go b/zipper_notwindows.go index 61a33c020..fd2134610 100644 --- a/zipper_notwindows.go +++ b/zipper_notwindows.go @@ -21,7 +21,7 @@ import ( func waitSignalForShutdownServer(server *core.Server) { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGTERM, syscall.SIGUSR2, syscall.SIGUSR1, syscall.SIGINT) - ylog.Info("Listening SIGUSR1, SIGUSR2, SIGTERM/SIGINT...") + ylog.Info("listening SIGUSR1, SIGUSR2, SIGTERM/SIGINT...") for p1 := range c { ylog.Debug("Received signal", "signal", p1) if p1 == syscall.SIGTERM || p1 == syscall.SIGINT {