Skip to content

Commit

Permalink
update main.go ai/*.go
Browse files Browse the repository at this point in the history
  • Loading branch information
kechigon committed Nov 6, 2024
1 parent bcaac66 commit 44e5550
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 77 deletions.
130 changes: 57 additions & 73 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"log"
"os"
"strings"
"fmt"

"github.com/3-shake/alert-menta/internal/ai"
"github.com/3-shake/alert-menta/internal/github"
Expand All @@ -24,21 +25,53 @@ type Config struct {
}

func main() {
cfg := parseFlags()
cfg := &Config{}
flag.StringVar(&cfg.repo, "repo", "", "Repository name")
flag.StringVar(&cfg.owner, "owner", "", "Repository owner")
flag.IntVar(&cfg.issueNumber, "issue", 0, "Issue number")
flag.StringVar(&cfg.intent, "intent", "", "Question or intent for the 'ask' command")
flag.StringVar(&cfg.command, "command", "", "Commands to be executed by AI. Commands defined in the configuration file are available.")
flag.StringVar(&cfg.configFile, "config", "", "Configuration file")
flag.StringVar(&cfg.ghToken, "github-token", "", "GitHub token")
flag.StringVar(&cfg.oaiKey, "api-key", "", "OpenAI api key")
flag.Parse()

if cfg.repo == "" || cfg.owner == "" || cfg.issueNumber == 0 || cfg.ghToken == "" || cfg.command == "" || cfg.configFile == "" {
flag.PrintDefaults()
os.Exit(1)
}

logger := initLogger()
logger := log.New(
os.Stdout, "[alert-menta main] ",
log.Ldate|log.Ltime|log.Llongfile|log.Lmsgprefix,
)

loadedConfig := loadConfiguration(cfg.configFile, logger)
loadedcfg, err := utils.NewConfig(cfg.configFile)
if err != nil {
logger.Fatalf("Error loading config: %v", err)
}

validateCommand(cfg.command, loadedConfig, logger)
err = validateCommand(cfg.command, loadedcfg, logger)
if err != nil {
logger.Fatalf("Error validating command: %v", err)
}

issue := getGitHubIssue(cfg.owner, cfg.repo, cfg.issueNumber, cfg.ghToken)
issue := github.NewIssue(cfg.owner, cfg.repo, cfg.issueNumber, cfg.ghToken)

userPrompt := constructUserPrompt(issue, loadedConfig, logger)
userPrompt, err := constructUserPrompt(issue, loadedcfg, logger)
if err != nil {
logger.Fatalf("Erro constructing userPrompt: %v", err)
}

prompt := constructPrompt(cfg.command, cfg.intent, userPrompt, loadedConfig, logger)
prompt, err := constructPrompt(cfg.command, cfg.intent, userPrompt, loadedcfg, logger)
if err != nil {
logger.Fatalf("Error constructing prompt: %v", err)
}

aic := getAIClient(cfg.oaiKey, loadedConfig, logger)
aic, err := getAIClient(cfg.oaiKey, loadedcfg, logger)
if err != nil {
logger.Fatalf("Error geting AI client: %v", err)
}

comment, err := aic.GetResponse(prompt)
if err != nil {
Expand All @@ -51,76 +84,28 @@ func main() {
}
}

// Parse command-line flags
func parseFlags() *Config {
repo := flag.String("repo", "", "Repository name")
owner := flag.String("owner", "", "Repository owner")
issueNumber := flag.Int("issue", 0, "Issue number")
intent := flag.String("intent", "", "Question or intent for the 'ask' command")
command := flag.String("command", "", "Commands to be executed by AI.")
configFile := flag.String("config", "", "Configuration file")
ghToken := flag.String("github-token", "", "GitHub token")
oaiKey := flag.String("api-key", "", "OpenAI api key")
flag.Parse()
if *repo == "" || *owner == "" || *issueNumber == 0 || *ghToken == "" || *command == "" || *configFile == "" {
flag.PrintDefaults()
os.Exit(1)
}
return &Config{
repo: *repo,
owner: *owner,
issueNumber: *issueNumber,
intent: *intent,
command: *command,
configFile: *configFile,
ghToken: *ghToken,
oaiKey: *oaiKey,
}
}

// Initialize a logger
func initLogger() *log.Logger {
return log.New(
os.Stdout, "[alert-menta main] ",
log.Ldate|log.Ltime|log.Llongfile|log.Lmsgprefix,
)
}

// Load and validate configuration
func loadConfiguration(configFile string, logger *log.Logger) *utils.Config {
cfg, err := utils.NewConfig(configFile)
if err != nil {
logger.Fatalf("Error loading config: %v", err)
}
return cfg
}

// Validate the provided command
func validateCommand(command string, cfg *utils.Config, logger *log.Logger) {
func validateCommand(command string, cfg *utils.Config, logger *log.Logger) error {
if _, ok := cfg.Ai.Commands[command]; !ok {
allowedCommands := make([]string, 0, len(cfg.Ai.Commands))
for cmd := range cfg.Ai.Commands {
allowedCommands = append(allowedCommands, cmd)
}
logger.Fatalf("Invalid command: %s. Allowed commands are %s.", command, strings.Join(allowedCommands, ", "))
return fmt.Errorf("Invalid command: %s. Allowed commands are %s", command, strings.Join(allowedCommands, ", "))
}
}

// Get GitHub issue instance
func getGitHubIssue(owner, repo string, issueNumber int, ghToken string) *github.GitHubIssue {
return github.NewIssue(owner, repo, issueNumber, ghToken)
return nil
}

// Construct user prompt from issue
func constructUserPrompt(issue *github.GitHubIssue, cfg *utils.Config, logger *log.Logger) string {
func constructUserPrompt(issue *github.GitHubIssue, cfg *utils.Config, logger *log.Logger) (string, error) {
title, err := issue.GetTitle()
if err != nil {
logger.Fatalf("Error getting Title: %v", err)
return "", fmt.Errorf("Error getting Title: %w", err)
}

body, err := issue.GetBody()
if err != nil {
logger.Fatalf("Error getting Body: %v", err)
return "", fmt.Errorf("Error getting Body: %w", err)
}

var userPrompt strings.Builder
Expand All @@ -129,7 +114,7 @@ func constructUserPrompt(issue *github.GitHubIssue, cfg *utils.Config, logger *l

comments, err := issue.GetComments()
if err != nil {
logger.Fatalf("Error getting comments: %v", err)
return "", fmt.Errorf("Error getting comments: %w", err)
}
for _, v := range comments {
if *v.User.Login == "github-actions[bot]" {
Expand All @@ -140,40 +125,39 @@ func constructUserPrompt(issue *github.GitHubIssue, cfg *utils.Config, logger *l
}
userPrompt.WriteString(*v.User.Login + ":" + *v.Body + "\n")
}
return userPrompt.String()
return userPrompt.String(), nil
}

// Construct AI prompt
func constructPrompt(command, intent, userPrompt string, cfg *utils.Config, logger *log.Logger) ai.Prompt {
func constructPrompt(command, intent, userPrompt string, cfg *utils.Config, logger *log.Logger) (*ai.Prompt, error){
var systemPrompt string
if command == "ask" {
if intent == "" {
logger.Fatalf("Error: intent is required for 'ask' command")
return nil, fmt.Errorf("Error: intent is required for 'ask' command")
}
systemPrompt = cfg.Ai.Commands[command].System_prompt + intent + "\n"
} else {
systemPrompt = cfg.Ai.Commands[command].System_prompt
}
logger.Println("\x1b[34mPrompt: |\n", systemPrompt, userPrompt, "\x1b[0m")
return ai.Prompt{UserPrompt: userPrompt, SystemPrompt: systemPrompt}
return &ai.Prompt{UserPrompt: userPrompt, SystemPrompt: systemPrompt}, nil
}

// Initialize AI client
func getAIClient(oaiKey string, cfg *utils.Config, logger *log.Logger) ai.Ai {
func getAIClient(oaiKey string, cfg *utils.Config, logger *log.Logger) (ai.Ai, error) {
switch cfg.Ai.Provider {
case "openai":
if oaiKey == "" {
logger.Fatalf("Error: Please provide your Open AI API key.")
return nil, fmt.Errorf("Error: Please provide your Open AI API key")
}
logger.Println("Using OpenAI API")
logger.Println("OpenAI model:", cfg.Ai.OpenAI.Model)
return ai.NewOpenAIClient(oaiKey, cfg.Ai.OpenAI.Model)
return ai.NewOpenAIClient(oaiKey, cfg.Ai.OpenAI.Model), nil
case "vertexai":
logger.Println("Using VertexAI API")
logger.Println("VertexAI model:", cfg.Ai.VertexAI.Model)
return ai.NewVertexAIClient(cfg.Ai.VertexAI.Project, cfg.Ai.VertexAI.Region, cfg.Ai.VertexAI.Model)
return ai.NewVertexAIClient(cfg.Ai.VertexAI.Project, cfg.Ai.VertexAI.Region, cfg.Ai.VertexAI.Model), nil
default:
logger.Fatalf("Error: Invalid provider")
return nil
return nil, fmt.Errorf("Error: Invalid provider")
}
}
2 changes: 1 addition & 1 deletion internal/ai/ai.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package ai

type Ai interface {
GetResponse(prompt Prompt) (string, error)
GetResponse(prompt *Prompt) (string, error)
}

type Prompt struct {
Expand Down
2 changes: 1 addition & 1 deletion internal/ai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type OpenAI struct {
model string
}

func (ai *OpenAI) GetResponse(prompt Prompt) (string, error) {
func (ai *OpenAI) GetResponse(prompt *Prompt) (string, error) {
// Create a new OpenAI client
keyCredential := azcore.NewKeyCredential(ai.apiKey)
client, _ := azopenai.NewClientForOpenAI("https://api.openai.com/v1/", keyCredential, nil)
Expand Down
4 changes: 2 additions & 2 deletions internal/ai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ type VertexAI struct {
model string
}

func (ai *VertexAI) GetResponse(prompt Prompt) (string, error) {
func (ai *VertexAI) GetResponse(prompt *Prompt) (string, error) {
model := ai.client.GenerativeModel(ai.model)
//Temperature recommended by LLM
model.SetTemperature(0.5)

resp, err := model.GenerateContent(ai.context, genai.Text(prompt.SystemPrompt+prompt.UserPrompt))
resp, err := model.GenerateContent(ai.context, genai.Text(prompt.SystemPrompt + prompt.UserPrompt))
if err != nil {
log.Fatal(err)
return "", err
Expand Down

0 comments on commit 44e5550

Please sign in to comment.