Skip to content

Commit

Permalink
Use rollback user creation if incrementing invite usage count didn't …
Browse files Browse the repository at this point in the history
…work
  • Loading branch information
foodelevator committed Oct 31, 2024
1 parent b0a8d03 commit 94fea3b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
29 changes: 24 additions & 5 deletions pkg/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"embed"

"github.com/datasektionen/logout/pkg/config"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/pressly/goose/v3"
Expand Down Expand Up @@ -43,14 +44,32 @@ func ConnectAndMigrate(ctx context.Context) (*Queries, error) {
}

func (q *Queries) Tx(ctx context.Context, f func(db *Queries) error) error {
pool := q.db.(*pgxpool.Pool)
tx, err := pool.Begin(ctx)
txq, err := q.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
if err := f(q.WithTx(tx)); err != nil {
defer txq.Rollback(ctx)
if err := f(txq); err != nil {
return err
}
return tx.Commit(ctx)
return txq.Commit(ctx)
}

func (q *Queries) Begin(ctx context.Context) (*Queries, error) {
// q.db is either *pgxpool.Pool or pgx.Tx. Both have this method
tx, err := q.db.(interface {
Begin(ctx context.Context) (pgx.Tx, error)
}).Begin(ctx)
if err != nil {
return nil, err
}
return q.WithTx(tx), nil
}

func (q *Queries) Commit(ctx context.Context) error {
return q.db.(pgx.Tx).Commit(ctx)
}

func (q *Queries) Rollback(ctx context.Context) error {
return q.db.(pgx.Tx).Rollback(ctx)
}
12 changes: 10 additions & 2 deletions services/user/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,12 @@ func (s *service) FinishInvite(w http.ResponseWriter, r *http.Request, kthid str
slog.Error("Could not find user in ldap", "kthid", kthid, "invite id", id)
return true, errors.New("Could not find user in ldap")
}
if err := s.db.CreateUser(r.Context(), database.CreateUserParams{
tx, err := s.db.Begin(r.Context())
if err != nil {
return true, err
}
defer tx.Rollback(r.Context())
if err := tx.CreateUser(r.Context(), database.CreateUserParams{
Kthid: kthid,
UgKthid: person.UGKTHID,
Email: kthid + "@kth.se",
Expand All @@ -141,7 +146,10 @@ func (s *service) FinishInvite(w http.ResponseWriter, r *http.Request, kthid str
}); err != nil {
return true, err
}
if err := s.db.IncrementInviteUses(r.Context(), id); err != nil {
if err := tx.IncrementInviteUses(r.Context(), id); err != nil {
return true, err
}
if err := tx.Commit(r.Context()); err != nil {
return true, err
}
http.SetCookie(w, &http.Cookie{Name: "invite", MaxAge: -1})
Expand Down

0 comments on commit 94fea3b

Please sign in to comment.