feat(ai-bridge): support Github Models (#885)
fanweixiao authored Aug 16, 2024
1 parent 8eedd7d commit e54ea71
Showing 6 changed files with 173 additions and 5 deletions.
7 changes: 5 additions & 2 deletions cli/serve.go
@@ -33,6 +33,7 @@ import (
     "github.com/yomorun/yomo/pkg/bridge/ai/provider/cfazure"
     "github.com/yomorun/yomo/pkg/bridge/ai/provider/cfopenai"
     "github.com/yomorun/yomo/pkg/bridge/ai/provider/gemini"
+    "github.com/yomorun/yomo/pkg/bridge/ai/provider/githubmodels"
     "github.com/yomorun/yomo/pkg/bridge/ai/provider/ollama"
     "github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
 )
@@ -49,7 +50,7 @@ var serveCmd = &cobra.Command{
         }

         // log.InfoStatusEvent(os.Stdout, "")
-        ylog.Info("Starting Zipper...")
+        ylog.Info("Starting YoMo Zipper...")
         // config
         conf, err := pkgconfig.ParseConfigFile(config)
         if err != nil {
@@ -150,12 +151,14 @@ func registerAIProvider(aiConfig *ai.Config) error {
             providerpkg.RegisterProvider(ollama.NewProvider(provider["api_endpoint"]))
         case "gemini":
             providerpkg.RegisterProvider(gemini.NewProvider(provider["api_key"]))
+        case "githubmodels":
+            providerpkg.RegisterProvider(githubmodels.NewProvider(provider["api_key"], provider["model"]))
         default:
             log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
         }
     }

-    ylog.Info("registered AI providers", "len", len(providerpkg.ListProviders()))
+    ylog.Info("register LLM providers", "num", len(providerpkg.ListProviders()))
     return nil
 }
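For context, the new "githubmodels" case expects the provider's config entry to carry "api_key" and "model" keys. A minimal sketch of that wiring, with hypothetical values (the token placeholder and model name are illustrative, not from this commit):

// Hypothetical provider entry as parsed from the zipper config; the keys
// mirror the provider["api_key"] and provider["model"] lookups above.
provider := map[string]string{
    "api_key": "<github token>", // NewProvider falls back to GITHUB_TOKEN if empty
    "model":   "gpt-4o",         // any model from github.com/marketplace/models
}
providerpkg.RegisterProvider(githubmodels.NewProvider(provider["api_key"], provider["model"]))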

1 change: 0 additions & 1 deletion pkg/bridge/ai/ai.go
@@ -144,7 +144,6 @@ func ParseConfig(conf map[string]any) (config *Config, err error) {
     if config.Server.Addr == "" {
         config.Server.Addr = ":8000"
     }
-    ylog.Info("parse AI config success")
     return
 }

2 changes: 1 addition & 1 deletion pkg/bridge/ai/api_server.go
@@ -49,7 +49,7 @@ func Serve(config *Config, zipperListenAddr string, credential string, logger *s
         return err
     }

-    logger.Info("start bridge server", "addr", config.Server.Addr, "provider", provider.Name())
+    logger.Info("start AI Bridge service", "addr", config.Server.Addr, "provider", provider.Name())
     return srv.ServeAddr(config.Server.Addr)
 }

66 changes: 66 additions & 0 deletions pkg/bridge/ai/provider/githubmodels/provider.go
@@ -0,0 +1,66 @@
// Package githubmodels is the Github Models LLM provider, see https://github.com/marketplace/models
package githubmodels

import (
    "context"
    "os"

    // automatically load .env file
    _ "github.com/joho/godotenv/autoload"
    "github.com/sashabaranov/go-openai"
    "github.com/yomorun/yomo/core/metadata"

    provider "github.com/yomorun/yomo/pkg/bridge/ai/provider"
)

// Provider is the provider for Github Models
type Provider struct {
    // APIKey is the API key for Github Models
    APIKey string
    // Model is the model for Github Models, see https://github.com/marketplace/models
    // e.g. "Meta-Llama-3.1-405B-Instruct", "Mistral-large-2407", "gpt-4o"
    Model string
    client *openai.Client
}

// check if implements ai.Provider
var _ provider.LLMProvider = &Provider{}

// NewProvider creates a new Github Models provider; if apiKey is empty,
// it falls back to the GITHUB_TOKEN environment variable.
func NewProvider(apiKey string, model string) *Provider {
    if apiKey == "" {
        apiKey = os.Getenv("GITHUB_TOKEN")
    }

    config := openai.DefaultConfig(apiKey)
    config.BaseURL = "https://models.inference.ai.azure.com"

    return &Provider{
        APIKey: apiKey,
        Model:  model,
        client: openai.NewClientWithConfig(config),
    }
}

// Name returns the name of the provider
func (p *Provider) Name() string {
    return "githubmodels"
}

// GetChatCompletions implements ai.LLMProvider.
func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (openai.ChatCompletionResponse, error) {
    if p.Model != "" {
        req.Model = p.Model
    }

    return p.client.CreateChatCompletion(ctx, req)
}

// GetChatCompletionsStream implements ai.LLMProvider.
func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) {
    if p.Model != "" {
        req.Model = p.Model
    }

    return p.client.CreateChatCompletionStream(ctx, req)
}
100 changes: 100 additions & 0 deletions pkg/bridge/ai/provider/githubmodels/provider_test.go
@@ -0,0 +1,100 @@
package githubmodels

import (
    "context"
    "os"
    "testing"

    openai "github.com/sashabaranov/go-openai"
    "github.com/stretchr/testify/assert"
)

func TestGithubModelsProvider_Name(t *testing.T) {
    provider := &Provider{}

    name := provider.Name()

    assert.Equal(t, "githubmodels", name)
}

func TestNewProvider(t *testing.T) {
    t.Run("with parameters", func(t *testing.T) {
        provider := NewProvider("test_api_key", "test_model")

        assert.Equal(t, "test_api_key", provider.APIKey)
        assert.Equal(t, "test_model", provider.Model)
    })

    t.Run("with environment variables", func(t *testing.T) {
        os.Setenv("GITHUB_TOKEN", "env_api_key")

        provider := NewProvider("", "test_model")

        assert.Equal(t, "env_api_key", provider.APIKey)
        assert.Equal(t, "test_model", provider.Model)

        os.Unsetenv("GITHUB_TOKEN")
    })
}

func TestGithubModelsProvider_GetChatCompletions(t *testing.T) {
    config := openai.DefaultConfig("test_api_key")
    config.BaseURL = "https://models.inference.ai.azure.com"
    client := openai.NewClientWithConfig(config)

    provider := &Provider{
        APIKey: "test_api_key",
        Model:  "test_model",
        client: client,
    }

    msgs := []openai.ChatCompletionMessage{
        {
            Role:    "user",
            Content: "hello",
        },
        {
            Role:    "system",
            Content: "I'm a bot",
        },
    }
    req := openai.ChatCompletionRequest{
        Messages: msgs,
    }

    _, err := provider.GetChatCompletions(context.TODO(), req, nil)
    assert.Error(t, err)
    assert.Contains(t, err.Error(), "401")
    assert.Contains(t, err.Error(), "Bad credentials")
}

func TestGithubModelsProvider_GetChatCompletionsStream(t *testing.T) {
    config := openai.DefaultConfig("test_api_key")
    config.BaseURL = "https://models.inference.ai.azure.com"
    client := openai.NewClientWithConfig(config)

    provider := &Provider{
        APIKey: "test_api_key",
        Model:  "test_model",
        client: client,
    }

    msgs := []openai.ChatCompletionMessage{
        {
            Role:    "user",
            Content: "hello",
        },
        {
            Role:    "system",
            Content: "I'm a bot",
        },
    }
    req := openai.ChatCompletionRequest{
        Messages: msgs,
    }

    _, err := provider.GetChatCompletionsStream(context.TODO(), req, nil)
    assert.Error(t, err)
    assert.Contains(t, err.Error(), "401")
    assert.Contains(t, err.Error(), "Bad credentials")
}
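The streaming test above only asserts that bad credentials surface an error; it never drains the stream. Below is a hedged sketch of consuming the returned provider.ResponseRecver, assuming it exposes a Recv() method that mirrors openai.ChatCompletionStream.Recv and yields chunks until io.EOF; that signature is an assumption, since it is not shown in this diff:

package main

import (
    "context"
    "errors"
    "fmt"
    "io"
    "log"

    openai "github.com/sashabaranov/go-openai"
    "github.com/yomorun/yomo/pkg/bridge/ai/provider/githubmodels"
)

func main() {
    p := githubmodels.NewProvider("", "gpt-4o") // GITHUB_TOKEN fallback; model name is illustrative
    req := openai.ChatCompletionRequest{
        Messages: []openai.ChatCompletionMessage{
            {Role: openai.ChatMessageRoleUser, Content: "hello"},
        },
    }

    stream, err := p.GetChatCompletionsStream(context.Background(), req, nil)
    if err != nil {
        log.Fatal(err)
    }
    for {
        // Assumed: Recv() (openai.ChatCompletionStreamResponse, error)
        chunk, err := stream.Recv()
        if errors.Is(err, io.EOF) {
            break // server closed the stream: response complete
        }
        if err != nil {
            log.Fatal(err)
        }
        fmt.Print(chunk.Choices[0].Delta.Content)
    }
}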
2 changes: 1 addition & 1 deletion zipper_notwindows.go
@@ -21,7 +21,7 @@ import (
 func waitSignalForShutdownServer(server *core.Server) {
     c := make(chan os.Signal, 1)
     signal.Notify(c, syscall.SIGTERM, syscall.SIGUSR2, syscall.SIGUSR1, syscall.SIGINT)
-    ylog.Info("Listening SIGUSR1, SIGUSR2, SIGTERM/SIGINT...")
+    ylog.Info("listening SIGUSR1, SIGUSR2, SIGTERM/SIGINT...")
     for p1 := range c {
         ylog.Debug("Received signal", "signal", p1)
         if p1 == syscall.SIGTERM || p1 == syscall.SIGINT {