Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat custom user #1978

Merged
merged 9 commits into from
Dec 5, 2024
7 changes: 3 additions & 4 deletions backend/flow_api/flow/shared/hook_issue_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,13 @@ func (h IssueSession) Execute(c flowpilot.HookExecutionContext) error {
return errors.New("user_id not found in stash")
}

emails, err := deps.Persister.GetEmailPersisterWithConnection(deps.Tx).FindByUserId(userId)
userModel, err := deps.Persister.GetUserPersisterWithConnection(deps.Tx).Get(userId)
if err != nil {
return fmt.Errorf("failed to fetch emails from db: %w", err)
return fmt.Errorf("failed to fetch user from db: %w", err)
}

var emailDTO *dto.EmailJwt

if email := emails.GetPrimary(); email != nil {
if email := userModel.Emails.GetPrimary(); email != nil {
emailDTO = dto.JwtFromEmailModel(email)
}

Expand Down
86 changes: 51 additions & 35 deletions backend/flow_api/services/webauthn.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package services

import (
"encoding/base64"
"errors"
"fmt"
"github.com/go-webauthn/webauthn/protocol"
Expand Down Expand Up @@ -178,45 +177,38 @@ func (s *webauthnService) VerifyAssertionResponse(p VerifyAssertionResponseParam
return nil, fmt.Errorf("%s: %w", err, ErrInvalidWebauthnCredential)
}

sessionDataModel, err := s.persister.GetWebauthnSessionDataPersister().Get(p.SessionDataID)
sessionDataModel, err := s.persister.GetWebauthnSessionDataPersisterWithConnection(p.Tx).Get(p.SessionDataID)
if err != nil {
return nil, fmt.Errorf("failed to get session data from db: %w", err)
}

var userID uuid.UUID
if p.IsMFA {
userID = sessionDataModel.UserId
} else {
userID, err = uuid.FromBytes(credentialAssertionData.Response.UserHandle)
if err != nil {
return nil, fmt.Errorf("failed to parse user id from user handle: %w", err)
}
}

userModel, err := s.persister.GetUserPersister().Get(userID)
credentialModel, err := s.persister.GetWebauthnCredentialPersister().Get(credentialAssertionData.ID)
if err != nil {
return nil, fmt.Errorf("failed to fetch user from db: %w", err)
return nil, fmt.Errorf("failed to get webauthncredential from db: %w", err)
}

if userModel == nil {
return nil, fmt.Errorf("%s: %w", err, ErrInvalidWebauthnCredential)
if credentialModel == nil {
return nil, ErrInvalidWebauthnCredential
}

cred := userModel.GetWebauthnCredentialById(credentialAssertionData.ID)
if cred != nil && (!p.IsMFA && cred.MFAOnly) {
if !p.IsMFA && credentialModel.MFAOnly {
return nil, ErrInvalidWebauthnCredentialMFAOnly
}

webAuthnUser, userModel, err := s.GetWebAuthnUser(p.Tx, *credentialModel)
if err != nil {
return nil, err
}

discoverableUserHandler := func(rawID, userHandle []byte) (webauthn.User, error) {
return userModel, nil
return webAuthnUser, nil
}

sessionData := sessionDataModel.ToSessionData()
var credential *webauthn.Credential
if p.IsMFA {
credential, err = s.cfg.Webauthn.Handler.ValidateLogin(userModel, *sessionData, credentialAssertionData)
_, err = s.cfg.Webauthn.Handler.ValidateLogin(webAuthnUser, *sessionData, credentialAssertionData)
} else {
credential, err = s.cfg.Webauthn.Handler.ValidateDiscoverableLogin(
_, err = s.cfg.Webauthn.Handler.ValidateDiscoverableLogin(
discoverableUserHandler,
*sessionData,
credentialAssertionData,
Expand All @@ -226,19 +218,16 @@ func (s *webauthnService) VerifyAssertionResponse(p VerifyAssertionResponseParam
return nil, fmt.Errorf("%s: %w", err, ErrInvalidWebauthnCredential)
}

encodedCredentialId := base64.RawURLEncoding.EncodeToString(credential.ID)
if credentialModel := userModel.GetWebauthnCredentialById(encodedCredentialId); credentialModel != nil {
now := time.Now().UTC()
flags := credentialAssertionData.Response.AuthenticatorData.Flags
now := time.Now().UTC()
flags := credentialAssertionData.Response.AuthenticatorData.Flags

credentialModel.LastUsedAt = &now
credentialModel.BackupState = flags.HasBackupState()
credentialModel.BackupEligible = flags.HasBackupEligible()
credentialModel.LastUsedAt = &now
credentialModel.BackupState = flags.HasBackupState()
credentialModel.BackupEligible = flags.HasBackupEligible()

err = s.persister.GetWebauthnCredentialPersisterWithConnection(p.Tx).Update(*credentialModel)
if err != nil {
return nil, fmt.Errorf("failed to update webauthn credential: %w", err)
}
err = s.persister.GetWebauthnCredentialPersisterWithConnection(p.Tx).Update(*credentialModel)
if err != nil {
return nil, fmt.Errorf("failed to update webauthn credential: %w", err)
}

err = s.persister.GetWebauthnSessionDataPersisterWithConnection(p.Tx).Delete(*sessionDataModel)
Expand Down Expand Up @@ -279,11 +268,10 @@ func (s *webauthnService) generateCreationOptions(p GenerateCreationOptionsParam

err = s.persister.GetWebauthnSessionDataPersisterWithConnection(p.Tx).Create(*sessionDataModel)
if err != nil {
return nil, nil, fmt.Errorf("failed to store session data to the db: %W", err)
return nil, nil, fmt.Errorf("failed to store session data to the db: %w", err)
}

return sessionDataModel, options, nil

}

func (s *webauthnService) GenerateCreationOptionsSecurityKey(p GenerateCreationOptionsParams) (*models.WebauthnSessionData, *protocol.CredentialCreation, error) {
Expand Down Expand Up @@ -354,3 +342,31 @@ func (s *webauthnService) VerifyAttestationResponse(p VerifyAttestationResponseP

return credential, nil
}

func (s *webauthnService) GetWebAuthnUser(tx *pop.Connection, credential models.WebauthnCredential) (webauthn.User, *models.User, error) {
user, err := s.persister.GetUserPersisterWithConnection(tx).Get(credential.UserId)
if err != nil {
return nil, nil, fmt.Errorf("failed to fetch user from db: %w", err)
}
if user == nil {
return nil, nil, ErrInvalidWebauthnCredential
}

if credential.UserHandle != nil {
return &webauthnUserWithCustomUserHandle{
CustomUserHandle: []byte(credential.UserHandle.Handle),
User: *user,
}, user, nil
}

return user, user, err
}

type webauthnUserWithCustomUserHandle struct {
models.User
CustomUserHandle []byte
}

func (u *webauthnUserWithCustomUserHandle) WebAuthnID() []byte {
return u.CustomUserHandle
}
2 changes: 1 addition & 1 deletion backend/handler/webauthn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ var userId = "ec4ef049-5b88-4321-a173-21b0eff06a04"
type sessionManager struct {
}

func (s sessionManager) GenerateJWT(_ uuid.UUID, _ *dto.EmailJwt) (string, jwt.Token, error) {
func (s sessionManager) GenerateJWT(_ uuid.UUID, _ *dto.EmailJwt, _ ...session.JWTOptions) (string, jwt.Token, error) {
return userId, nil, nil
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
drop_foreign_key("webauthn_credentials", "webauthn_credential_user_handle_fkey", {"if_exists": false})
drop_column("webauthn_credentials", "user_handle_id")
drop_table("webauthn_credential_user_handles")


Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
create_table("webauthn_credential_user_handles") {
t.Column("id", "uuid", {primary: true})
t.Column("user_id", "uuid", {"null": false})
t.Column("handle", "string", {"null": false, "unique": true})
t.Timestamps()
t.Index(["id", "user_id"], {"unique": true})
t.ForeignKey("user_id", {"users": ["id"]}, {"on_delete": "cascade", "on_update": "cascade"})
}

add_column("webauthn_credentials", "user_handle_id", "uuid", { "null": true })
add_foreign_key("webauthn_credentials", "user_handle_id", {"webauthn_credential_user_handles": ["id"]}, {
"on_delete": "set null",
"on_update": "cascade",
})

sql("ALTER TABLE webauthn_credentials ADD CONSTRAINT webauthn_credential_user_handle_fkey FOREIGN KEY (user_handle_id, user_id) REFERENCES webauthn_credential_user_handles(id, user_id) ON DELETE NO ACTION ON UPDATE CASCADE;")
30 changes: 16 additions & 14 deletions backend/persistence/models/webauthn_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,22 @@ import (

// WebauthnCredential is used by pop to map your webauthn_credentials database table to your go code.
type WebauthnCredential struct {
ID string `db:"id" json:"id"`
Name *string `db:"name" json:"name"`
UserId uuid.UUID `db:"user_id" json:"user_id"`
PublicKey string `db:"public_key" json:"public_key"`
AttestationType string `db:"attestation_type" json:"attestation_type"`
AAGUID uuid.UUID `db:"aaguid" json:"aaguid"`
SignCount int `db:"sign_count" json:"sign_count"`
LastUsedAt *time.Time `db:"last_used_at" json:"last_used_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Transports Transports `has_many:"webauthn_credential_transports" json:"transports"`
BackupEligible bool `db:"backup_eligible" json:"backup_eligible"`
BackupState bool `db:"backup_state" json:"backup_state"`
MFAOnly bool `db:"mfa_only" json:"mfa_only"`
ID string `db:"id" json:"id"`
Name *string `db:"name" json:"name"`
UserId uuid.UUID `db:"user_id" json:"user_id"`
PublicKey string `db:"public_key" json:"public_key"`
AttestationType string `db:"attestation_type" json:"attestation_type"`
AAGUID uuid.UUID `db:"aaguid" json:"aaguid"`
SignCount int `db:"sign_count" json:"sign_count"`
LastUsedAt *time.Time `db:"last_used_at" json:"last_used_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Transports Transports `has_many:"webauthn_credential_transports" json:"transports"`
BackupEligible bool `db:"backup_eligible" json:"backup_eligible"`
BackupState bool `db:"backup_state" json:"backup_state"`
MFAOnly bool `db:"mfa_only" json:"mfa_only"`
UserHandleID *uuid.UUID `db:"user_handle_id" json:"-"`
UserHandle *WebauthnCredentialUserHandle `belongs_to:"webauthn_credential_user_handle" fk_id:"webauthn_credential_user_handle_fkey" json:"user_handle,omitempty"`
}

type WebauthnCredentials []WebauthnCredential
Expand Down
28 changes: 28 additions & 0 deletions backend/persistence/models/webauthn_credential_user_handle.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package models

import (
"github.com/gobuffalo/pop/v6"
"github.com/gobuffalo/validate/v3"
"github.com/gobuffalo/validate/v3/validators"
"github.com/gofrs/uuid"
"time"
)

type WebauthnCredentialUserHandle struct {
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Handle string `db:"handle" json:"handle"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}

// Validate gets run every time you call a "pop.Validate*" (pop.ValidateAndSave, pop.ValidateAndCreate, pop.ValidateAndUpdate) method.
func (userHandle *WebauthnCredentialUserHandle) Validate(tx *pop.Connection) (*validate.Errors, error) {
return validate.Validate(
&validators.UUIDIsPresent{Name: "ID", Field: userHandle.ID},
&validators.UUIDIsPresent{Name: "UserId", Field: userHandle.UserID},
&validators.StringIsPresent{Name: "handle", Field: userHandle.Handle},
&validators.TimeIsPresent{Name: "CreatedAt", Field: userHandle.CreatedAt},
&validators.TimeIsPresent{Name: "UpdatedAt", Field: userHandle.UpdatedAt},
), nil
}
2 changes: 1 addition & 1 deletion backend/persistence/webauthn_credential_persister.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func NewWebauthnCredentialPersister(db *pop.Connection) WebauthnCredentialPersis

func (p *webauthnCredentialPersister) Get(id string) (*models.WebauthnCredential, error) {
credential := models.WebauthnCredential{}
err := p.db.Find(&credential, id)
err := p.db.Eager().Find(&credential, id)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
Expand Down
16 changes: 14 additions & 2 deletions backend/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

type Manager interface {
GenerateJWT(userId uuid.UUID, userDto *dto.EmailJwt) (string, jwt.Token, error)
GenerateJWT(userId uuid.UUID, userDto *dto.EmailJwt, opts ...JWTOptions) (string, jwt.Token, error)
Verify(string) (jwt.Token, error)
GenerateCookie(token string) (*http.Cookie, error)
DeleteCookie() (*http.Cookie, error)
Expand Down Expand Up @@ -90,7 +90,7 @@ func NewManager(jwkManager hankoJwk.Manager, config config.Config) (Manager, err
}

// GenerateJWT creates a new session JWT for the given user
func (m *manager) GenerateJWT(userId uuid.UUID, email *dto.EmailJwt) (string, jwt.Token, error) {
func (m *manager) GenerateJWT(userId uuid.UUID, email *dto.EmailJwt, opts ...JWTOptions) (string, jwt.Token, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it make sense to set the email parameter via a JWTOption too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that would make sense. This is just leftover from the previous implementation where a user would have a custom userID. And I forgot that its still there. Maybe we can do it in a separate PR?

sessionID, err := uuid.NewV4()
if err != nil {
return "", nil, err
Expand All @@ -109,6 +109,10 @@ func (m *manager) GenerateJWT(userId uuid.UUID, email *dto.EmailJwt) (string, jw
_ = token.Set("email", &email)
}

for _, opt := range opts {
opt(token)
}

if m.issuer != "" {
_ = token.Set(jwt.IssuerKey, m.issuer)
}
Expand Down Expand Up @@ -158,3 +162,11 @@ func (m *manager) DeleteCookie() (*http.Cookie, error) {
MaxAge: -1,
}, nil
}

type JWTOptions func(token jwt.Token)

func WithValue(key string, value interface{}) JWTOptions {
return func(jwt jwt.Token) {
_ = jwt.Set(key, value)
}
}
Loading