Skip to content

Commit

Permalink
structured ws message (might not work)
Browse files Browse the repository at this point in the history
  • Loading branch information
ruskaof committed Mar 24, 2024
1 parent 4824696 commit d611716
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 62 deletions.
7 changes: 5 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"dwelt/src/config"
"dwelt/src/handler"
"dwelt/src/model/dao"
"dwelt/src/service/usrserv"
"dwelt/src/ws/chat"
"flag"
"log/slog"
Expand All @@ -16,12 +17,14 @@ func main() {
slog.SetLogLoggerLevel(slog.LevelDebug)
flag.Parse()
config.InitCfg()
dao.InitDB()
db := dao.InitDB()

hub := chat.NewHub()
go hub.Run()

handler.InitHandlers(hub)
userService := usrserv.NewUserService(hub, db)
userController := handler.NewUserController(userService)
userController.InitHandlers(hub)

server := &http.Server{
Addr: *port,
Expand Down
32 changes: 32 additions & 0 deletions src/dto/dto.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package dto

import "encoding/json"

type UserInfo struct {
UserId int64 `json:"id"`
}
Expand All @@ -8,3 +10,33 @@ type UserResponse struct {
UserId int64 `json:"userId"`
Username string `json:"username"`
}

type WebSocketClientMessage struct {
ChatId int64 `json:"chatId"`
Message string `json:"message"`
}

type WebSocketServerMessage struct {
ChatId int64 `json:"chatId"`
Message string `json:"message"`
}

func SerializeWebSocketServerMessage(message WebSocketServerMessage) []byte {
res, _ := json.Marshal(message)
return res
}

func DeserializeWebSocketServerMessage(data []byte) (message WebSocketServerMessage, err error) {
err = json.Unmarshal(data, &message)
return
}

func SerializeWebSocketClientMessage(message WebSocketClientMessage) []byte {
res, _ := json.Marshal(message)
return res
}

func DeserializeWebSocketClientMessage(data []byte) (message WebSocketClientMessage, err error) {
err = json.Unmarshal(data, &message)
return
}
48 changes: 28 additions & 20 deletions src/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,46 @@ import (
"github.com/gorilla/mux"
)

const DEFAULT_LIMIT = 100
const defaultLimit = 100

func InitHandlers(hub *chat.Hub) {
type UserController struct {
userService *usrserv.UserService
}

func NewUserController(userService *usrserv.UserService) *UserController {
return &UserController{userService}
}

func (uc UserController) InitHandlers(hub *chat.Hub) {
router := mux.NewRouter()

authenticatedRouter := router.PathPrefix("/").Subrouter()
authenticatedRouter.Use(handlerAuthMiddleware)
authenticatedRouter.HandleFunc("/hello", handlerHelloWorld).Methods(http.MethodGet)
authenticatedRouter.HandleFunc("/ws", createHandlerWs(hub)).Methods(http.MethodGet)
authenticatedRouter.HandleFunc("/hello", uc.handlerHelloWorld).Methods(http.MethodGet)
authenticatedRouter.HandleFunc("/ws", uc.createHandlerWs(hub)).Methods(http.MethodGet)

usersRouter := authenticatedRouter.PathPrefix("/users").Subrouter()
usersRouter.HandleFunc("/search", handlerSearchUsers).Methods(http.MethodGet)
usersRouter.HandleFunc("/search", uc.handlerSearchUsers).Methods(http.MethodGet)

chatsRouter := authenticatedRouter.PathPrefix("/chats").Subrouter()
chatsRouter.HandleFunc("/direct/{directToUid}", handlerFindDirectChat).Methods(http.MethodGet)
chatsRouter.HandleFunc("/direct/{directToUid}", uc.handlerFindDirectChat).Methods(http.MethodGet)

noAuthRouter := router.PathPrefix("/").Subrouter()
noAuthRouter.HandleFunc("/register", handlerRegister).Methods(http.MethodPost)
noAuthRouter.HandleFunc("/login", handlerLogin).Methods(http.MethodGet)
noAuthRouter.HandleFunc("/info", handleApplicationInfoDashboard).Methods(http.MethodGet)
noAuthRouter.HandleFunc("/register", uc.handlerRegister).Methods(http.MethodPost)
noAuthRouter.HandleFunc("/login", uc.handlerLogin).Methods(http.MethodGet)
noAuthRouter.HandleFunc("/info", uc.handleApplicationInfoDashboard).Methods(http.MethodGet)

http.Handle("/", router)
}

func handlerLogin(w http.ResponseWriter, r *http.Request) {
func (uc UserController) handlerLogin(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok {
w.WriteHeader(http.StatusUnauthorized)
return
}

userId, valid, err := usrserv.ValidateUser(username, password)
userId, valid, err := uc.userService.ValidateUser(username, password)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
Expand All @@ -59,14 +67,14 @@ func handlerLogin(w http.ResponseWriter, r *http.Request) {
utils.WriteJson(w, dto.UserInfo{UserId: userId})
}

func handlerRegister(w http.ResponseWriter, r *http.Request) {
func (uc UserController) handlerRegister(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok {
w.WriteHeader(http.StatusUnauthorized)
return
}

userId, duplicate, err := usrserv.RegisterUser(username, password)
userId, duplicate, err := uc.userService.RegisterUser(username, password)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
Expand All @@ -82,25 +90,25 @@ func handlerRegister(w http.ResponseWriter, r *http.Request) {
utils.WriteJson(w, dto.UserInfo{UserId: userId})
}

func handlerSearchUsers(w http.ResponseWriter, r *http.Request) {
func (uc UserController) handlerSearchUsers(w http.ResponseWriter, r *http.Request) {
prefix := r.URL.Query().Get("prefix")
users, err := usrserv.SearchUsers(prefix, DEFAULT_LIMIT)
users, err := uc.userService.SearchUsers(prefix, defaultLimit)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
utils.WriteJson(w, users)
}

func handlerFindDirectChat(w http.ResponseWriter, r *http.Request) {
func (uc UserController) handlerFindDirectChat(w http.ResponseWriter, r *http.Request) {
requesterUid := retrieveUserId(r)
userId, err := strconv.ParseInt(mux.Vars(r)["directToUid"], 10, 64)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}

chatId, badUsers, err := usrserv.FindDirectChat(requesterUid, userId) // todo: show old messages
chatId, badUsers, err := uc.userService.FindDirectChat(requesterUid, userId) // todo: show old messages
if badUsers {
w.WriteHeader(http.StatusNotFound)
return
Expand All @@ -113,18 +121,18 @@ func handlerFindDirectChat(w http.ResponseWriter, r *http.Request) {
utils.WriteJson(w, chatId)
}

func handlerHelloWorld(w http.ResponseWriter, r *http.Request) { // todo remove
func (uc UserController) handlerHelloWorld(w http.ResponseWriter, r *http.Request) { // todo remove
w.WriteHeader(http.StatusOK)
userId, _ := r.Context().Value("userId").(int64)
utils.Must(w.Write([]byte("hello, " + strconv.FormatInt(userId, 10))))
}

func handleApplicationInfoDashboard(w http.ResponseWriter, _ *http.Request) {
func (uc UserController) handleApplicationInfoDashboard(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
utils.Must(w.Write([]byte("Workflow run number: " + strconv.Itoa(config.DweltCfg.WorkflowRunNumber))))
}

func createHandlerWs(hub *chat.Hub) http.HandlerFunc {
func (uc UserController) createHandlerWs(hub *chat.Hub) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
chat.ServeWs(hub, retrieveUserId(r), w, r)
}
Expand Down
6 changes: 2 additions & 4 deletions src/model/dao/dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ import (
"gorm.io/gorm/logger"
)

var Db *gorm.DB

func InitDB() {
func InitDB() *gorm.DB {
dsn := fmt.Sprintf(
"host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
config.DbCfg.Host,
Expand All @@ -22,7 +20,7 @@ func InitDB() {
config.DbCfg.Port,
)

Db = utils.Must(
return utils.Must(
gorm.Open(postgres.Open(dsn), &gorm.Config{
TranslateError: true,
// log every SQL command
Expand Down
85 changes: 66 additions & 19 deletions src/service/usrserv/usrserv.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,36 @@ package usrserv
import (
"crypto/sha512"
"dwelt/src/dto"
"dwelt/src/model/dao"
"dwelt/src/model/entity"
"dwelt/src/ws/chat"
"encoding/hex"
"errors"
"log/slog"

"gorm.io/gorm"
"log/slog"
)

type UserService struct {
wsHub *chat.Hub
db *gorm.DB
}

func NewUserService(wsHub *chat.Hub, db *gorm.DB) *UserService {
return &UserService{
wsHub: wsHub,
db: db,
}
}

func hashPassword(password string) string {
h := sha512.New()
h.Write([]byte(password))
return hex.EncodeToString(h.Sum(nil))
}

func ValidateUser(username string, password string) (userId int64, valid bool, err error) {
func (us *UserService) ValidateUser(username string, password string) (userId int64, valid bool, err error) {
var user entity.User

err = dao.Db.Where("username = ? AND password = ?", username, hashPassword(password)).First(&user).Error
err = us.db.Where("username = ? AND password = ?", username, hashPassword(password)).First(&user).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
slog.Debug("User not found", "username", username, "password", password)
err = nil
Expand All @@ -38,13 +49,13 @@ func ValidateUser(username string, password string) (userId int64, valid bool, e
return
}

func RegisterUser(username string, password string) (userId int64, duplicate bool, err error) {
func (us *UserService) RegisterUser(username string, password string) (userId int64, duplicate bool, err error) {
user := entity.User{
Username: username,
Password: hashPassword(password),
}

err = dao.Db.Create(&user).Error
err = us.db.Create(&user).Error
if errors.Is(err, gorm.ErrDuplicatedKey) {
duplicate = true
err = nil
Expand All @@ -59,9 +70,9 @@ func RegisterUser(username string, password string) (userId int64, duplicate boo
return
}

func SearchUsers(prefix string, limit int) (users []dto.UserResponse, err error) {
func (us *UserService) SearchUsers(prefix string, limit int) (users []dto.UserResponse, err error) {
var usersEntity []entity.User
err = dao.Db.Where("username LIKE ?", prefix+"%").Limit(limit).Find(&usersEntity).Error
err = us.db.Where("username LIKE ?", prefix+"%").Limit(limit).Find(&usersEntity).Error
if err != nil {
slog.Error(err.Error(), "method", "SearchUsers")
}
Expand All @@ -74,10 +85,10 @@ func SearchUsers(prefix string, limit int) (users []dto.UserResponse, err error)
return
}

func FindDirectChat(requesterUid int64, directToUid int64) (chatId int64, badUsers bool, err error) {
func (us *UserService) FindDirectChat(requesterUid int64, directToUid int64) (chatId int64, badUsers bool, err error) {
// check if both users exist
var count int64
err = dao.Db.Model(&entity.User{}).Where("id IN (?)", []int64{requesterUid, directToUid}).Count(&count).Error
err = us.db.Model(&entity.User{}).Where("id IN (?)", []int64{requesterUid, directToUid}).Count(&count).Error
if err != nil {
slog.Error(err.Error(), "method", "FindDirectChat")
return
Expand All @@ -87,13 +98,13 @@ func FindDirectChat(requesterUid int64, directToUid int64) (chatId int64, badUse
return
}

var chat entity.Chat
err = dao.Db.
var chatEntity entity.Chat
err = us.db.
Joins("JOIN users_chats uc1 ON chats.id = uc1.chat_id").
Joins("JOIN users_chats uc2 ON uc1.chat_id = uc2.chat_id").
Where("uc1.user_id = ?", requesterUid).
Where("uc2.user_id = ?", directToUid).
First(&chat).Error
First(&chatEntity).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
err = nil
}
Expand All @@ -102,24 +113,60 @@ func FindDirectChat(requesterUid int64, directToUid int64) (chatId int64, badUse
return
}

if chat.ID != 0 {
chatId = chat.ID
if chatEntity.ID != 0 {
chatId = chatEntity.ID
return
}

// create chat with associated users but don't create the users
chat = entity.Chat{
chatEntity = entity.Chat{
Users: []entity.User{
{ID: requesterUid},
{ID: directToUid},
},
}
err = dao.Db.Create(&chat).Error
err = us.db.Create(&chatEntity).Error
if err != nil {
slog.Error(err.Error(), "method", "CreateDirectChat")
return
}

chatId = chat.ID
chatId = chatEntity.ID
return
}

func (us *UserService) HandleMessage(userId int64, message dto.WebSocketClientMessage) {
// find chat
chatEntity := entity.Chat{
ID: message.ChatId,
}

err := us.db.Model(&entity.Chat{}).Preload("Users").First(&chatEntity, message.ChatId).Error
if err != nil {
slog.Error(err.Error(), "method", "HandleMessage")
return
}

// check if user is in chat
inChat := false
var otherUserIds []int64
for _, user := range chatEntity.Users {
if user.ID == userId {
inChat = true
} else {
otherUserIds = append(otherUserIds, user.ID)
}
}

if !inChat {
slog.Error("User is not in chat", "userId", userId, "chatId", message.ChatId)
return
}

serverMessage := dto.WebSocketServerMessage{
ChatId: message.ChatId,
Message: message.Message,
}

us.wsHub.SendToSelected(serverMessage, otherUserIds)
}
Loading

0 comments on commit d611716

Please sign in to comment.