Skip to content

Commit

Permalink
discord oauth and bot (#9)
Browse files Browse the repository at this point in the history
Add discord bot and discord oauth
  • Loading branch information
jm8 authored Oct 26, 2024
1 parent 43fe41e commit 960d643
Show file tree
Hide file tree
Showing 8 changed files with 469 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
/.env

/keys
/tmp
/tlskeys
Expand Down
192 changes: 192 additions & 0 deletions discord.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
package main

import (
"database/sql"
"fmt"
"log"
"net/url"

"github.com/bwmarrin/discordgo"
)

type DiscordBot struct {
Token string
GuildId string
AdminRoleId string
StudentRoleId string
ClientId string
ClientSecret string
Db *sql.DB

Session *discordgo.Session
}

var (
commands = []*discordgo.ApplicationCommand{
// All commands and options must have a description
// Commands/options without description will fail the registration
// of the command.
{
Name: "about",
Description: "Lookup information about a discord user who has linked their OSU account",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionUser,
Name: "user",
Description: "User",
Required: true,
},
},
},
}

commandHandlers = map[string]func(b *DiscordBot, i *discordgo.InteractionCreate){
"about": func(b *DiscordBot, i *discordgo.InteractionCreate) {
if !b.requireAdmin(i) {
return
}
options := i.ApplicationCommandData().Options
user := options[0].UserValue(nil)
discordId := user.ID

row := b.Db.QueryRow(`SELECT name_num, display_name, last_signin, student, alum, employee, faculty FROM users WHERE discord_id = ?`, discordId)
var (
nameNum string
displayName string
lastLogin int
student bool
alum bool
employee bool
faculty bool
)
err := row.Scan(&nameNum, &displayName, &lastLogin, &student, &alum, &employee, &faculty)
if err != nil {
log.Println("/about command: discordId =", discordId, err)
_ = b.Session.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: "User has not linked their OSU account",
},
})
return
}
content := fmt.Sprintf("**[%s (%s)](<https://www.osu.edu/search/?query=%s>)**\nLast login: <t:%d:f>\n",
displayName,
nameNum,
url.QueryEscape(nameNum),
lastLogin,
)

sep := ""
if student {
content += sep + "Student"
sep = ", "
}
if alum {
content += sep + "Alum"
sep = ", "
}
if employee {
content += sep + "Employee"
sep = ", "
}
if faculty {
content += sep + "Faculty"
}

_ = b.Session.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: content,
},
})
},
}
)

func (b *DiscordBot) isAdmin(m *discordgo.Member) bool {
for _, role := range m.Roles {
if role == b.AdminRoleId {
return true
}
}
return false
}

func (b *DiscordBot) requireAdmin(i *discordgo.InteractionCreate) bool {
if !b.isAdmin(i.Member) {
_ = b.Session.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: "This command requires admin",
},
})
return false
}

return true
}

func (b *DiscordBot) Connect() {
if b.Token == "" {
log.Fatalln("Missing token")
}
if b.AdminRoleId == "" {
log.Fatalln("Missing admin role id")
}
if b.GuildId == "" {
log.Fatalln("Missing guild id")
}
if b.ClientId == "" {
log.Fatalln("Missing client id")
}
if b.ClientSecret == "" {
log.Fatalln("Missing client secret")
}

s, err := discordgo.New("Bot " + b.Token)
if err != nil {
log.Fatalln("Failed to connect", err)
}

b.Session = s

s.AddHandler(func(s *discordgo.Session, r *discordgo.Ready) {
log.Println("Logged in as", r.User.String())
})

s.AddHandler(func(s *discordgo.Session, i *discordgo.InteractionCreate) {
if h, ok := commandHandlers[i.ApplicationCommandData().Name]; ok {
h(b, i)
}
})

err = s.Open()
if err != nil {
log.Fatalln("Failed to open session", err)
}

registeredCommands := make([]*discordgo.ApplicationCommand, len(commands))
for i, v := range commands {
cmd, err := s.ApplicationCommandCreate(s.State.User.ID, b.GuildId, v)
if err != nil {
log.Panicf("Cannot create '%v' command: %v", v.Name, err)
}
registeredCommands[i] = cmd
}
}

func (b *DiscordBot) GiveStudentRole(discordId string) error {
return b.Session.GuildMemberRoleAdd(b.GuildId, discordId, b.StudentRoleId)
}

func (b *DiscordBot) AddStudentToGuild(discordId string, accessToken string) error {
return b.Session.GuildMemberAdd(b.GuildId, discordId, &discordgo.GuildMemberAddParams{
AccessToken: accessToken,
Roles: []string{b.StudentRoleId},
})
}

func (b *DiscordBot) RemoveStudentRole(discordId string) error {
return b.Session.GuildMemberRoleRemove(b.GuildId, discordId, b.StudentRoleId)
}
194 changes: 194 additions & 0 deletions discord_oauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package main

import (
"bytes"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"time"
)

const OAUTH_STATE_COOKIE = "oauthstate"

func (r *Router) DiscordSignin(w http.ResponseWriter, req *http.Request) {
state := generateStateOauthCookie(w)
redirectUri := fmt.Sprintf("%s/discord/callback", r.rootURL)
url := fmt.Sprintf("https://discord.com/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%v&scope=identify+guilds.join&state=%v", r.bot.ClientId, url.QueryEscape(redirectUri), state)
http.Redirect(w, req, url, http.StatusTemporaryRedirect)
}

func (r *Router) DiscordCallback(w http.ResponseWriter, req *http.Request) {
userId, _ := getUserIDFromContext(req.Context())

stateCookie, err := req.Cookie(OAUTH_STATE_COOKIE)
if err != nil {
log.Println("Discord callback: Missing oauth state cookie. User id =", userId)
http.Error(w, "Missing oauthstate cookie", http.StatusBadRequest)
return
}
stateParam := req.URL.Query().Get("state")
if stateParam == "" {
log.Println("Discord callback: Missing state url parameter. User id =", userId)
http.Error(w, "Missing state url parameter", http.StatusBadRequest)
return
}
if stateCookie.Value != stateParam {
log.Println("Discord callback: State cookie and state parameter don't match. State cookie =", stateCookie.Value, ", state param =", stateParam, "User id =", userId)
http.Error(w, "State cookie and state parameter don't match", http.StatusBadRequest)
return
}

code := req.URL.Query().Get("code")
authToken, err := getDiscordAuthToken(r.rootURL, r.bot, code)
if err != nil {
log.Println("Discord callback: Error getting discord auth token:", err, "User id =", userId)
http.Error(w, "Error getting discord auth token", http.StatusForbidden)
return
}

discordUser, err := getDiscordUser(authToken)
if err != nil {
log.Println("Discord callback: Error getting discord user:", err, "User id =", userId)
http.Error(w, "Error getting user information", http.StatusForbidden)
return
}

tx, err := r.db.Begin()
if err != nil {
log.Println("Discord callback: Failed to start transaction", err, "User id =", userId)
http.Error(w, "Failed to get user", http.StatusForbidden)
return
}
row := tx.QueryRow("SELECT discord_id FROM users WHERE idm_id = ?", userId)

var oldDiscordId sql.NullString
err = row.Scan(&oldDiscordId)
if err != nil {
log.Println("Discord callback: failed to get old discord id:", err)
http.Error(w, "Failed to get user", http.StatusInternalServerError)
_ = tx.Rollback()
return
}

_, err = r.db.Exec("UPDATE users SET discord_id = ? WHERE idm_id = ?", discordUser.ID, userId)
if err != nil {
log.Println("Discord callback: failed to update user:", err)
http.Error(w, "Failed to update discord", http.StatusInternalServerError)
_ = tx.Rollback()
return
}

err = tx.Commit()
if err != nil {
log.Println("Discord callback: failed to commit transcation:", err)
http.Error(w, "Failed to update discord", http.StatusInternalServerError)
_ = tx.Rollback()
return
}

if oldDiscordId.Valid {
_ = r.bot.RemoveStudentRole(oldDiscordId.String)
}
_ = r.bot.AddStudentToGuild(discordUser.ID, authToken)
_ = r.bot.GiveStudentRole(discordUser.ID)

http.Redirect(w, req, "/", http.StatusTemporaryRedirect)
}

func generateStateOauthCookie(w http.ResponseWriter) string {
var expiration = time.Now().Add(2 * time.Hour)
b := make([]byte, 16)
_, _ = rand.Read(b)
state := base64.URLEncoding.EncodeToString(b)
cookie := http.Cookie{Name: OAUTH_STATE_COOKIE, Value: state, Expires: expiration}
http.SetCookie(w, &cookie)
return state
}

func getDiscordAuthToken(rootURL *url.URL, b *DiscordBot, code string) (string, error) {
redirectUri := fmt.Sprintf("%s/discord/callback", rootURL)

data := url.Values{
"client_id": {b.ClientId},
"client_secret": {b.ClientSecret},
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": {redirectUri},
}

req, err := http.NewRequest("POST", "https://discord.com/api/oauth2/token", bytes.NewBufferString(data.Encode()))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("discord token endpoint responded with %d", resp.StatusCode)
}

var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err
}

accessToken, ok := result["access_token"].(string)
if !ok {
return "", fmt.Errorf("discord token endpoint response did not have access_token")
}

return accessToken, nil
}

type DiscordUser struct {
Avatar string `json:"avatar"`
Discriminator string `json:"discriminator"`
Email string `json:"email"`
Flags int `json:"flags"`
ID string `json:"id"`
Username string `json:"username"`
}

func getDiscordUser(authToken string) (DiscordUser, error) {
client := &http.Client{}
req, err := http.NewRequest("GET", "https://discord.com/api/users/@me", nil)
if err != nil {
log.Fatal(err)
}

req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken))

resp, err := client.Do(req)
if err != nil {
return DiscordUser{}, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return DiscordUser{}, fmt.Errorf("failed to get user info")
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return DiscordUser{}, err
}

var user DiscordUser
if err := json.Unmarshal(body, &user); err != nil {
return DiscordUser{}, err
}

return user, nil
}
Loading

0 comments on commit 960d643

Please sign in to comment.