Implement conversation API endpoint #188
base: master
@@ -0,0 +1,158 @@
package main

import (
    "encoding/json"
    "errors"
    "fmt"
    "net/http"

    "github.com/mattermost/mattermost-plugin-ai/server/ai"

    "github.com/mattermost/mattermost/server/public/model"
    "github.com/mattermost/mattermost/server/public/pluginapi"

    "github.com/gin-gonic/gin"
)

type ConversationRequest struct {
    // The name of the bot that should handle the request.
    BotName string `json:"bot_name"`
    // Optional past conversation to be used as context.
    Thread []*model.Post `json:"thread"`
    // The post to be processed in this request.
    Request *model.Post `json:"request"`
    // Whether to use the system role to generate the prompt.
    UseSystemRole bool `json:"use_system_role"`
}

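// Example request body (illustrative only; the bot username and IDs below are
// placeholders, not values defined by this plugin):
//
//  {
//    "bot_name": "<bot-username>",
//    "thread": [],
//    "request": {
//      "channel_id": "<channel-id>",
//      "user_id": "<user-id>",
//      "message": "What was decided in yesterday's call?"
//    },
//    "use_system_role": false
//  }
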
func (p *Plugin) handlePostConversation(c *gin.Context) {
    userID := c.GetHeader("Mattermost-User-Id")

    // We only allow bots to use this API handler for the time being.
Review comment: If you are using the inter-plugin API to call this endpoint, we could restrict it to other plugins for now with the header check used in the GitHub plugin: https://github.com/mattermost/mattermost-plugin-github/blob/5aa2450cc65254054872fc3ca6917b1c9354d543/server/plugin/api.go#L265 (a commented sketch of this follows below).

Reply: Thanks, I'll look into it, although that would probably mean an additional round trip in my case (i.e. going through the plugin), since the bot is just a client running outside of Mattermost.
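    // Sketch of the review suggestion above, not part of this diff: on
    // inter-plugin requests the Mattermost server sets the Mattermost-Plugin-ID
    // header, so non-plugin callers could be rejected with something like:
    //
    //  if c.GetHeader("Mattermost-Plugin-ID") == "" {
    //      c.AbortWithError(http.StatusForbidden, errors.New("forbidden"))
    //      return
    //  }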
    if _, err := p.pluginAPI.Bot.Get(userID, false); errors.Is(err, pluginapi.ErrNotFound) {
        c.AbortWithError(http.StatusForbidden, errors.New("forbidden"))
        return
    } else if err != nil {
        c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to get bot: %w", err))
        return
    }

    var reqData ConversationRequest
    if err := json.NewDecoder(c.Request.Body).Decode(&reqData); err != nil {
        c.AbortWithError(http.StatusBadRequest, err)
        return
    }
    defer c.Request.Body.Close()

    // Validation
    if reqData.BotName == "" {
        c.AbortWithError(http.StatusBadRequest, errors.New("invalid empty bot"))
        return
    }

    bot := p.GetBotByUsername(reqData.BotName)
    if bot == nil {
        c.AbortWithError(http.StatusBadRequest, errors.New("invalid bot name"))
        return
    }

    post := reqData.Request

    if post == nil {
        c.AbortWithError(http.StatusBadRequest, errors.New("invalid request"))
        return
    }

    if post.Message == "" {
        c.AbortWithError(http.StatusBadRequest, errors.New("invalid empty message"))
        return
    }

    if post.ChannelId == "" {
        c.AbortWithError(http.StatusBadRequest, errors.New("invalid empty channel id"))
        return
    }

    channel, err := p.pluginAPI.Channel.Get(post.ChannelId)
    if errors.Is(err, pluginapi.ErrNotFound) {
        c.AbortWithError(http.StatusBadRequest, errors.New("channel not found"))
        return
    } else if err != nil {
        c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to get channel: %w", err))
        return
    }

    if post.UserId == "" {
        c.AbortWithError(http.StatusBadRequest, errors.New("invalid empty user id"))
        return
    }

    postingUser, err := p.pluginAPI.User.Get(post.UserId)
    if errors.Is(err, pluginapi.ErrNotFound) {
        c.AbortWithError(http.StatusBadRequest, errors.New("user not found"))
        return
    } else if err != nil {
        c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to get posting user: %w", err))
        return
    }

    // Don't respond to ourselves
    if p.IsAnyBot(post.UserId) {
        c.AbortWithError(http.StatusBadRequest, errors.New("not responding to ourselves"))
        return
    }

    list := &model.PostList{
        Order: make([]string, 0, len(reqData.Thread)+1),
        Posts: make(map[string]*model.Post, len(reqData.Thread)+1),
    }
    list.Order = append(list.Order, post.Id)
    list.Posts[post.Id] = post
    for i, post := range reqData.Thread {
        list.Order = append(list.Order, post.Id)
        list.Posts[post.Id] = reqData.Thread[i]
    }

    threadData, err := p.getMetadataForPosts(list)
    if err != nil {
        c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to get thread data: %w", err))
        return
    }

    prompt, err := p.prompts.ChatCompletion(ai.PromptDirectMessageQuestion, p.MakeConversationContext(bot, postingUser, channel, post))
    if err != nil {
        c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to generate prompt: %w", err))
        return
    }
    prompt.AppendConversation(p.ThreadToBotConversation(bot, threadData.Posts))

    // Overriding post role if requested.
    if reqData.UseSystemRole {
Review comment: Sorry, I think I was unclear about what this means. The system prompt is special instructions to the LLM that are weighted more highly than user instructions, and usually there is only one system prompt. You can look at OpenAI's explanation here: https://platform.openai.com/docs/guides/text-generation/chat-completions-api. What I really want is to be able to call this API from other plugins and customize the prompt that is provided on line 122, so that we can give the LLM custom instructions about how the rest of the conversation should be handled, like telling it that it's in a call and talking live, who the participants are, etc.

Reply: @crspeller Are you thinking of going as far as letting plugins pass/override a whole prompt template, or more like letting them choose which prompt to use and possibly augment existing ones?

Review comment: I am imagining them passing the whole prompt template. That way they can do whatever they want. Ideally they would be able to utilize the default personalities and whatnot as well. (A rough sketch of this idea follows at the end.)
        for i := range prompt.Posts {
            prompt.Posts[i].Role = ai.PostRoleSystem
        }
    }

    result, err := p.getLLM(bot.cfg).ChatCompletion(prompt)
    if err != nil {
        c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to process request: %w", err))
        return
    }

    for {
        select {
        case msg := <-result.Stream:
            if _, err := c.Writer.WriteString(msg); err != nil {
                c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error while writing result: %w", err))
            }
            // Flushing lets us stream partial results without requiring the client to wait for the full response.
            c.Writer.Flush()
Review comment on lines +148 to +149: For streaming to work correctly we'll need mattermost/mattermost@51f4271. (A client-side consumption sketch follows after the diff.)
        case err, ok := <-result.Err:
            if !ok {
                return
            }
            c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error while streaming result: %w", err))
            return
        }
    }
}
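To give a sense of how the streamed response is meant to be consumed, here is a minimal sketch of an external bot client, the scenario mentioned in the review discussion above. The route path, plugin ID, bot username, and token are placeholders (the diff does not show how the handler is registered), and the snippet only illustrates reading the chunked body incrementally, which is what the Flush calls above enable.

package main

import (
    "fmt"
    "io"
    "net/http"
    "strings"
)

func main() {
    // Placeholder payload, URL, and token: route registration and auth are not part of this diff.
    payload := `{"bot_name":"<bot-username>","request":{"channel_id":"<channel-id>","user_id":"<user-id>","message":"Hello"}}`
    req, err := http.NewRequest(http.MethodPost, "http://localhost:8065/plugins/PLUGIN_ID/CONVERSATION_ROUTE", strings.NewReader(payload))
    if err != nil {
        panic(err)
    }
    req.Header.Set("Authorization", "Bearer BOT_ACCESS_TOKEN")
    req.Header.Set("Content-Type", "application/json")

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    // Print partial results as the server flushes them, instead of waiting for the full response.
    buf := make([]byte, 4096)
    for {
        n, readErr := resp.Body.Read(buf)
        if n > 0 {
            fmt.Print(string(buf[:n]))
        }
        if readErr == io.EOF {
            break
        }
        if readErr != nil {
            panic(readErr)
        }
    }
}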
Review comment: I'm not overly worried about naming or scoping. Just started with this to get something working.
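Finally, a rough sketch of the prompt-customization idea from the review discussion above. The field name and wiring are hypothetical; the thread only agrees that calling plugins should be able to pass a whole prompt template, so this is one possible shape, not a settled API.

package conversationsketch

import "github.com/mattermost/mattermost/server/public/model"

// Hypothetical extension of ConversationRequest: a caller-supplied prompt
// template that would be used instead of the built-in direct-message prompt.
// Everything below is a sketch of the review discussion, not part of this PR.
type ConversationRequest struct {
    BotName string        `json:"bot_name"`
    Thread  []*model.Post `json:"thread"`
    Request *model.Post   `json:"request"`
    // Whether to use the system role to generate the prompt.
    UseSystemRole bool `json:"use_system_role"`
    // Raw prompt template provided by the calling plugin (hypothetical field).
    // When non-empty, the handler would render this template instead of
    // ai.PromptDirectMessageQuestion before appending the conversation.
    PromptTemplate string `json:"prompt_template,omitempty"`
}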