-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add discord bot and discord oauth
- Loading branch information
Showing
8 changed files
with
469 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
/.env | ||
|
||
/keys | ||
/tmp | ||
/tlskeys | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
package main | ||
|
||
import ( | ||
"database/sql" | ||
"fmt" | ||
"log" | ||
"net/url" | ||
|
||
"github.com/bwmarrin/discordgo" | ||
) | ||
|
||
type DiscordBot struct { | ||
Token string | ||
GuildId string | ||
AdminRoleId string | ||
StudentRoleId string | ||
ClientId string | ||
ClientSecret string | ||
Db *sql.DB | ||
|
||
Session *discordgo.Session | ||
} | ||
|
||
var ( | ||
commands = []*discordgo.ApplicationCommand{ | ||
// All commands and options must have a description | ||
// Commands/options without description will fail the registration | ||
// of the command. | ||
{ | ||
Name: "about", | ||
Description: "Lookup information about a discord user who has linked their OSU account", | ||
Options: []*discordgo.ApplicationCommandOption{ | ||
{ | ||
Type: discordgo.ApplicationCommandOptionUser, | ||
Name: "user", | ||
Description: "User", | ||
Required: true, | ||
}, | ||
}, | ||
}, | ||
} | ||
|
||
commandHandlers = map[string]func(b *DiscordBot, i *discordgo.InteractionCreate){ | ||
"about": func(b *DiscordBot, i *discordgo.InteractionCreate) { | ||
if !b.requireAdmin(i) { | ||
return | ||
} | ||
options := i.ApplicationCommandData().Options | ||
user := options[0].UserValue(nil) | ||
discordId := user.ID | ||
|
||
row := b.Db.QueryRow(`SELECT name_num, display_name, last_signin, student, alum, employee, faculty FROM users WHERE discord_id = ?`, discordId) | ||
var ( | ||
nameNum string | ||
displayName string | ||
lastLogin int | ||
student bool | ||
alum bool | ||
employee bool | ||
faculty bool | ||
) | ||
err := row.Scan(&nameNum, &displayName, &lastLogin, &student, &alum, &employee, &faculty) | ||
if err != nil { | ||
log.Println("/about command: discordId =", discordId, err) | ||
_ = b.Session.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ | ||
Type: discordgo.InteractionResponseChannelMessageWithSource, | ||
Data: &discordgo.InteractionResponseData{ | ||
Content: "User has not linked their OSU account", | ||
}, | ||
}) | ||
return | ||
} | ||
content := fmt.Sprintf("**[%s (%s)](<https://www.osu.edu/search/?query=%s>)**\nLast login: <t:%d:f>\n", | ||
displayName, | ||
nameNum, | ||
url.QueryEscape(nameNum), | ||
lastLogin, | ||
) | ||
|
||
sep := "" | ||
if student { | ||
content += sep + "Student" | ||
sep = ", " | ||
} | ||
if alum { | ||
content += sep + "Alum" | ||
sep = ", " | ||
} | ||
if employee { | ||
content += sep + "Employee" | ||
sep = ", " | ||
} | ||
if faculty { | ||
content += sep + "Faculty" | ||
} | ||
|
||
_ = b.Session.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ | ||
Type: discordgo.InteractionResponseChannelMessageWithSource, | ||
Data: &discordgo.InteractionResponseData{ | ||
Content: content, | ||
}, | ||
}) | ||
}, | ||
} | ||
) | ||
|
||
func (b *DiscordBot) isAdmin(m *discordgo.Member) bool { | ||
for _, role := range m.Roles { | ||
if role == b.AdminRoleId { | ||
return true | ||
} | ||
} | ||
return false | ||
} | ||
|
||
func (b *DiscordBot) requireAdmin(i *discordgo.InteractionCreate) bool { | ||
if !b.isAdmin(i.Member) { | ||
_ = b.Session.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ | ||
Type: discordgo.InteractionResponseChannelMessageWithSource, | ||
Data: &discordgo.InteractionResponseData{ | ||
Content: "This command requires admin", | ||
}, | ||
}) | ||
return false | ||
} | ||
|
||
return true | ||
} | ||
|
||
func (b *DiscordBot) Connect() { | ||
if b.Token == "" { | ||
log.Fatalln("Missing token") | ||
} | ||
if b.AdminRoleId == "" { | ||
log.Fatalln("Missing admin role id") | ||
} | ||
if b.GuildId == "" { | ||
log.Fatalln("Missing guild id") | ||
} | ||
if b.ClientId == "" { | ||
log.Fatalln("Missing client id") | ||
} | ||
if b.ClientSecret == "" { | ||
log.Fatalln("Missing client secret") | ||
} | ||
|
||
s, err := discordgo.New("Bot " + b.Token) | ||
if err != nil { | ||
log.Fatalln("Failed to connect", err) | ||
} | ||
|
||
b.Session = s | ||
|
||
s.AddHandler(func(s *discordgo.Session, r *discordgo.Ready) { | ||
log.Println("Logged in as", r.User.String()) | ||
}) | ||
|
||
s.AddHandler(func(s *discordgo.Session, i *discordgo.InteractionCreate) { | ||
if h, ok := commandHandlers[i.ApplicationCommandData().Name]; ok { | ||
h(b, i) | ||
} | ||
}) | ||
|
||
err = s.Open() | ||
if err != nil { | ||
log.Fatalln("Failed to open session", err) | ||
} | ||
|
||
registeredCommands := make([]*discordgo.ApplicationCommand, len(commands)) | ||
for i, v := range commands { | ||
cmd, err := s.ApplicationCommandCreate(s.State.User.ID, b.GuildId, v) | ||
if err != nil { | ||
log.Panicf("Cannot create '%v' command: %v", v.Name, err) | ||
} | ||
registeredCommands[i] = cmd | ||
} | ||
} | ||
|
||
func (b *DiscordBot) GiveStudentRole(discordId string) error { | ||
return b.Session.GuildMemberRoleAdd(b.GuildId, discordId, b.StudentRoleId) | ||
} | ||
|
||
func (b *DiscordBot) AddStudentToGuild(discordId string, accessToken string) error { | ||
return b.Session.GuildMemberAdd(b.GuildId, discordId, &discordgo.GuildMemberAddParams{ | ||
AccessToken: accessToken, | ||
Roles: []string{b.StudentRoleId}, | ||
}) | ||
} | ||
|
||
func (b *DiscordBot) RemoveStudentRole(discordId string) error { | ||
return b.Session.GuildMemberRoleRemove(b.GuildId, discordId, b.StudentRoleId) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
package main | ||
|
||
import ( | ||
"bytes" | ||
"crypto/rand" | ||
"database/sql" | ||
"encoding/base64" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"log" | ||
"net/http" | ||
"net/url" | ||
"time" | ||
) | ||
|
||
const OAUTH_STATE_COOKIE = "oauthstate" | ||
|
||
func (r *Router) DiscordSignin(w http.ResponseWriter, req *http.Request) { | ||
state := generateStateOauthCookie(w) | ||
redirectUri := fmt.Sprintf("%s/discord/callback", r.rootURL) | ||
url := fmt.Sprintf("https://discord.com/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%v&scope=identify+guilds.join&state=%v", r.bot.ClientId, url.QueryEscape(redirectUri), state) | ||
http.Redirect(w, req, url, http.StatusTemporaryRedirect) | ||
} | ||
|
||
func (r *Router) DiscordCallback(w http.ResponseWriter, req *http.Request) { | ||
userId, _ := getUserIDFromContext(req.Context()) | ||
|
||
stateCookie, err := req.Cookie(OAUTH_STATE_COOKIE) | ||
if err != nil { | ||
log.Println("Discord callback: Missing oauth state cookie. User id =", userId) | ||
http.Error(w, "Missing oauthstate cookie", http.StatusBadRequest) | ||
return | ||
} | ||
stateParam := req.URL.Query().Get("state") | ||
if stateParam == "" { | ||
log.Println("Discord callback: Missing state url parameter. User id =", userId) | ||
http.Error(w, "Missing state url parameter", http.StatusBadRequest) | ||
return | ||
} | ||
if stateCookie.Value != stateParam { | ||
log.Println("Discord callback: State cookie and state parameter don't match. State cookie =", stateCookie.Value, ", state param =", stateParam, "User id =", userId) | ||
http.Error(w, "State cookie and state parameter don't match", http.StatusBadRequest) | ||
return | ||
} | ||
|
||
code := req.URL.Query().Get("code") | ||
authToken, err := getDiscordAuthToken(r.rootURL, r.bot, code) | ||
if err != nil { | ||
log.Println("Discord callback: Error getting discord auth token:", err, "User id =", userId) | ||
http.Error(w, "Error getting discord auth token", http.StatusForbidden) | ||
return | ||
} | ||
|
||
discordUser, err := getDiscordUser(authToken) | ||
if err != nil { | ||
log.Println("Discord callback: Error getting discord user:", err, "User id =", userId) | ||
http.Error(w, "Error getting user information", http.StatusForbidden) | ||
return | ||
} | ||
|
||
tx, err := r.db.Begin() | ||
if err != nil { | ||
log.Println("Discord callback: Failed to start transaction", err, "User id =", userId) | ||
http.Error(w, "Failed to get user", http.StatusForbidden) | ||
return | ||
} | ||
row := tx.QueryRow("SELECT discord_id FROM users WHERE idm_id = ?", userId) | ||
|
||
var oldDiscordId sql.NullString | ||
err = row.Scan(&oldDiscordId) | ||
if err != nil { | ||
log.Println("Discord callback: failed to get old discord id:", err) | ||
http.Error(w, "Failed to get user", http.StatusInternalServerError) | ||
_ = tx.Rollback() | ||
return | ||
} | ||
|
||
_, err = r.db.Exec("UPDATE users SET discord_id = ? WHERE idm_id = ?", discordUser.ID, userId) | ||
if err != nil { | ||
log.Println("Discord callback: failed to update user:", err) | ||
http.Error(w, "Failed to update discord", http.StatusInternalServerError) | ||
_ = tx.Rollback() | ||
return | ||
} | ||
|
||
err = tx.Commit() | ||
if err != nil { | ||
log.Println("Discord callback: failed to commit transcation:", err) | ||
http.Error(w, "Failed to update discord", http.StatusInternalServerError) | ||
_ = tx.Rollback() | ||
return | ||
} | ||
|
||
if oldDiscordId.Valid { | ||
_ = r.bot.RemoveStudentRole(oldDiscordId.String) | ||
} | ||
_ = r.bot.AddStudentToGuild(discordUser.ID, authToken) | ||
_ = r.bot.GiveStudentRole(discordUser.ID) | ||
|
||
http.Redirect(w, req, "/", http.StatusTemporaryRedirect) | ||
} | ||
|
||
func generateStateOauthCookie(w http.ResponseWriter) string { | ||
var expiration = time.Now().Add(2 * time.Hour) | ||
b := make([]byte, 16) | ||
_, _ = rand.Read(b) | ||
state := base64.URLEncoding.EncodeToString(b) | ||
cookie := http.Cookie{Name: OAUTH_STATE_COOKIE, Value: state, Expires: expiration} | ||
http.SetCookie(w, &cookie) | ||
return state | ||
} | ||
|
||
func getDiscordAuthToken(rootURL *url.URL, b *DiscordBot, code string) (string, error) { | ||
redirectUri := fmt.Sprintf("%s/discord/callback", rootURL) | ||
|
||
data := url.Values{ | ||
"client_id": {b.ClientId}, | ||
"client_secret": {b.ClientSecret}, | ||
"grant_type": {"authorization_code"}, | ||
"code": {code}, | ||
"redirect_uri": {redirectUri}, | ||
} | ||
|
||
req, err := http.NewRequest("POST", "https://discord.com/api/oauth2/token", bytes.NewBufferString(data.Encode())) | ||
if err != nil { | ||
return "", err | ||
} | ||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||
|
||
client := &http.Client{} | ||
resp, err := client.Do(req) | ||
if err != nil { | ||
return "", err | ||
} | ||
defer resp.Body.Close() | ||
|
||
if resp.StatusCode != http.StatusOK { | ||
return "", fmt.Errorf("discord token endpoint responded with %d", resp.StatusCode) | ||
} | ||
|
||
var result map[string]interface{} | ||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { | ||
return "", err | ||
} | ||
|
||
accessToken, ok := result["access_token"].(string) | ||
if !ok { | ||
return "", fmt.Errorf("discord token endpoint response did not have access_token") | ||
} | ||
|
||
return accessToken, nil | ||
} | ||
|
||
type DiscordUser struct { | ||
Avatar string `json:"avatar"` | ||
Discriminator string `json:"discriminator"` | ||
Email string `json:"email"` | ||
Flags int `json:"flags"` | ||
ID string `json:"id"` | ||
Username string `json:"username"` | ||
} | ||
|
||
func getDiscordUser(authToken string) (DiscordUser, error) { | ||
client := &http.Client{} | ||
req, err := http.NewRequest("GET", "https://discord.com/api/users/@me", nil) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken)) | ||
|
||
resp, err := client.Do(req) | ||
if err != nil { | ||
return DiscordUser{}, err | ||
} | ||
defer resp.Body.Close() | ||
|
||
if resp.StatusCode != http.StatusOK { | ||
return DiscordUser{}, fmt.Errorf("failed to get user info") | ||
} | ||
|
||
body, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
return DiscordUser{}, err | ||
} | ||
|
||
var user DiscordUser | ||
if err := json.Unmarshal(body, &user); err != nil { | ||
return DiscordUser{}, err | ||
} | ||
|
||
return user, nil | ||
} |
Oops, something went wrong.