-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(ai-bridge): support Github Models (#885)
- Loading branch information
1 parent
8eedd7d
commit e54ea71
Showing
6 changed files
with
173 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// Package githubmodels is the Github Models llm provider, see https://github.com/marketplace/models | ||
package githubmodels | ||
|
||
import ( | ||
"context" | ||
"os" | ||
|
||
// automatically load .env file | ||
_ "github.com/joho/godotenv/autoload" | ||
"github.com/sashabaranov/go-openai" | ||
"github.com/yomorun/yomo/core/metadata" | ||
|
||
provider "github.com/yomorun/yomo/pkg/bridge/ai/provider" | ||
) | ||
|
||
// Provider is the provider for Github Models | ||
type Provider struct { | ||
// APIKey is the API key for Github Models | ||
APIKey string | ||
// Model is the model for Github Models, see https://github.com/marketplace/models | ||
// e.g. "Meta-Llama-3.1-405B-Instruct", "Mistral-large-2407", "gpt-4o" | ||
Model string | ||
client *openai.Client | ||
} | ||
|
||
// check if implements ai.Provider | ||
var _ provider.LLMProvider = &Provider{} | ||
|
||
// NewProvider creates a new OpenAIProvider | ||
func NewProvider(apiKey string, model string) *Provider { | ||
if apiKey == "" { | ||
apiKey = os.Getenv("GITHUB_TOKEN") | ||
} | ||
|
||
config := openai.DefaultConfig(apiKey) | ||
config.BaseURL = "https://models.inference.ai.azure.com" | ||
|
||
return &Provider{ | ||
APIKey: apiKey, | ||
Model: model, | ||
client: openai.NewClientWithConfig(config), | ||
} | ||
} | ||
|
||
// Name returns the name of the provider | ||
func (p *Provider) Name() string { | ||
return "githubmodels" | ||
} | ||
|
||
// GetChatCompletions implements ai.LLMProvider. | ||
func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (openai.ChatCompletionResponse, error) { | ||
if p.Model != "" { | ||
req.Model = p.Model | ||
} | ||
|
||
return p.client.CreateChatCompletion(ctx, req) | ||
} | ||
|
||
// GetChatCompletionsStream implements ai.LLMProvider. | ||
func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) { | ||
if p.Model != "" { | ||
req.Model = p.Model | ||
} | ||
|
||
return p.client.CreateChatCompletionStream(ctx, req) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
package githubmodels | ||
|
||
import ( | ||
"context" | ||
"os" | ||
"testing" | ||
|
||
openai "github.com/sashabaranov/go-openai" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestGithubModelsProvider_Name(t *testing.T) { | ||
provider := &Provider{} | ||
|
||
name := provider.Name() | ||
|
||
assert.Equal(t, "githubmodels", name) | ||
} | ||
|
||
func TestNewProvider(t *testing.T) { | ||
t.Run("with parameters", func(t *testing.T) { | ||
provider := NewProvider("test_api_key", "test_model") | ||
|
||
assert.Equal(t, "test_api_key", provider.APIKey) | ||
assert.Equal(t, "test_model", provider.Model) | ||
}) | ||
|
||
t.Run("with environment variables", func(t *testing.T) { | ||
os.Setenv("GITHUB_TOKEN", "env_api_key") | ||
|
||
provider := NewProvider("", "test_model") | ||
|
||
assert.Equal(t, "env_api_key", provider.APIKey) | ||
assert.Equal(t, "test_model", provider.Model) | ||
|
||
os.Unsetenv("GITHUB_TOKEN") | ||
}) | ||
} | ||
|
||
func TestGithubModelsProvider_GetChatCompletions(t *testing.T) { | ||
config := openai.DefaultConfig("test_api_key") | ||
config.BaseURL = "https://models.inference.ai.azure.com" | ||
client := openai.NewClientWithConfig(config) | ||
|
||
provider := &Provider{ | ||
APIKey: "test_api_key", | ||
Model: "test_model", | ||
client: client, | ||
} | ||
|
||
msgs := []openai.ChatCompletionMessage{ | ||
{ | ||
Role: "user", | ||
Content: "hello", | ||
}, | ||
{ | ||
Role: "system", | ||
Content: "I'm a bot", | ||
}, | ||
} | ||
req := openai.ChatCompletionRequest{ | ||
Messages: msgs, | ||
} | ||
|
||
_, err := provider.GetChatCompletions(context.TODO(), req, nil) | ||
assert.Error(t, err) | ||
assert.Contains(t, err.Error(), "401") | ||
assert.Contains(t, err.Error(), "Bad credentials") | ||
} | ||
|
||
func TestGithubModelsProvider_GetChatCompletionsStream(t *testing.T) { | ||
config := openai.DefaultConfig("test_api_key") | ||
config.BaseURL = "https://models.inference.ai.azure.com" | ||
client := openai.NewClientWithConfig(config) | ||
|
||
provider := &Provider{ | ||
APIKey: "test_api_key", | ||
Model: "test_model", | ||
client: client, | ||
} | ||
|
||
msgs := []openai.ChatCompletionMessage{ | ||
{ | ||
Role: "user", | ||
Content: "hello", | ||
}, | ||
{ | ||
Role: "system", | ||
Content: "I'm a bot", | ||
}, | ||
} | ||
req := openai.ChatCompletionRequest{ | ||
Messages: msgs, | ||
} | ||
|
||
_, err := provider.GetChatCompletionsStream(context.TODO(), req, nil) | ||
assert.Error(t, err) | ||
assert.Contains(t, err.Error(), "401") | ||
assert.Contains(t, err.Error(), "Bad credentials") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters