From acce230d607bb829b7564d5062d0aec86a986833 Mon Sep 17 00:00:00 2001 From: Freddie Rice Date: Mon, 15 Jan 2024 15:07:30 -0600 Subject: [PATCH] update jwt package from v3 to v5 --- _example/main.go | 4 +- v2/go.mod | 1 + v2/go.sum | 2 + v2/middleware/auth.go | 2 +- v2/middleware/auth_test.go | 5 +-- v2/provider/apple.go | 26 +++++------ v2/provider/apple_pubkeys.go | 2 +- v2/provider/apple_pubkeys_test.go | 2 +- v2/provider/apple_test.go | 4 +- v2/provider/custom_server_test.go | 6 +-- v2/provider/direct.go | 8 ++-- v2/provider/direct_test.go | 4 +- v2/provider/oauth1.go | 16 +++---- v2/provider/oauth1_test.go | 2 +- v2/provider/oauth2.go | 16 +++---- v2/provider/oauth2_test.go | 2 +- v2/provider/telegram.go | 12 ++--- v2/provider/verify.go | 14 +++--- v2/provider/verify_test.go | 12 ++--- v2/token/jwt.go | 75 +++++++++++++++++++++---------- v2/token/jwt_test.go | 51 +++++++++++---------- 21 files changed, 151 insertions(+), 115 deletions(-) diff --git a/_example/main.go b/_example/main.go index 2a2e42dc..3dd315c8 100644 --- a/_example/main.go +++ b/_example/main.go @@ -22,7 +22,7 @@ import ( log "github.com/go-pkgz/lgr" "github.com/go-pkgz/rest" "github.com/go-pkgz/rest/logger" - "github.com/golang-jwt/jwt" + oldjwt "github.com/golang-jwt/jwt" "golang.org/x/oauth2" "github.com/go-pkgz/auth" @@ -295,7 +295,7 @@ func initGoauth2Srv() *goauth2.Server { manager.MustTokenStorage(store.NewMemoryTokenStore()) // generate jwt access token - manager.MapAccessGenerate(generates.NewJWTAccessGenerate("custom", []byte("00000000"), jwt.SigningMethodHS512)) + manager.MapAccessGenerate(generates.NewJWTAccessGenerate("custom", []byte("00000000"), oldjwt.SigningMethodHS512)) // client memory store clientStore := store.NewClientStore() diff --git a/v2/go.mod b/v2/go.mod index a15c84ce..77554540 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -9,6 +9,7 @@ require ( github.com/go-pkgz/repeater v1.1.3 github.com/go-pkgz/rest v1.19.0 github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/rrivera/identicon v0.0.0-20240116195454-d5ba35832c0d github.com/stretchr/testify v1.9.0 go.etcd.io/bbolt v1.3.9 diff --git a/v2/go.sum b/v2/go.sum index 4d432554..29b1ba54 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -31,6 +31,8 @@ github.com/go-session/session v3.1.2+incompatible/go.mod h1:8B3iivBQjrz/JtC68Np2 github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= diff --git a/v2/middleware/auth.go b/v2/middleware/auth.go index c62d1686..6e8c5ed8 100644 --- a/v2/middleware/auth.go +++ b/v2/middleware/auth.go @@ -177,7 +177,7 @@ func (a *Authenticator) refreshExpiredToken(w http.ResponseWriter, claims token. } } - claims.ExpiresAt = 0 // this will cause now+duration for refreshed token + claims.ExpiresAt = nil // this will cause now+duration for refreshed token c, err := a.JWTService.Set(w, claims) // Set changes token if err != nil { return token.Claims{}, err diff --git a/v2/middleware/auth_test.go b/v2/middleware/auth_test.go index 518380c1..d71e0c9f 100644 --- a/v2/middleware/auth_test.go +++ b/v2/middleware/auth_test.go @@ -166,9 +166,8 @@ func TestAuthJWTRefresh(t *testing.T) { claims, err := a.JWTService.Parse(resp.Cookies()[0].Value) assert.NoError(t, err) - ts := time.Unix(claims.ExpiresAt, 0) - assert.True(t, ts.After(time.Now()), "expiration in the future") - log.Print(time.Unix(claims.ExpiresAt, 0)) + assert.True(t, claims.ExpiresAt.After(time.Now()), "expiration in the future") + log.Print(claims.ExpiresAt) } func TestAuthJWTRefreshConcurrentWithCache(t *testing.T) { diff --git a/v2/provider/apple.go b/v2/provider/apple.go index f83168ab..4a2fbcd0 100644 --- a/v2/provider/apple.go +++ b/v2/provider/apple.go @@ -24,7 +24,7 @@ import ( "golang.org/x/oauth2" "github.com/go-pkgz/rest" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "github.com/go-pkgz/auth/v2/logger" "github.com/go-pkgz/auth/v2/token" @@ -261,11 +261,11 @@ func (ah *AppleHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { From: r.URL.Query().Get("from"), }, SessionOnly: r.URL.Query().Get("session") != "" && r.URL.Query().Get("session") != "0", - StandardClaims: jwt.StandardClaims{ - Id: cid, - Audience: r.URL.Query().Get("site"), - ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), - NotBefore: time.Now().Add(-1 * time.Minute).Unix(), + RegisteredClaims: jwt.RegisteredClaims{ + ID: cid, + Audience: jwt.ClaimStrings{r.URL.Query().Get("site")}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(30 * time.Minute)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), }, } @@ -370,9 +370,9 @@ func (ah AppleHandler) AuthHandler(w http.ResponseWriter, r *http.Request) { claims := token.Claims{ User: &u, - StandardClaims: jwt.StandardClaims{ + RegisteredClaims: jwt.RegisteredClaims{ Issuer: ah.Issuer, - Id: cid, + ID: cid, Audience: oauthClaims.Audience, }, SessionOnly: false, @@ -467,13 +467,13 @@ func (ah *AppleHandler) createClientSecret() (string, error) { } // Create a claims now := time.Now() - exp := now.Add(time.Minute * 30).Unix() // default value + exp := now.Add(time.Minute * 30) // default value - claims := &jwt.StandardClaims{ + claims := &jwt.RegisteredClaims{ Issuer: ah.conf.TeamID, - IssuedAt: now.Unix(), - ExpiresAt: exp, - Audience: "https://appleid.apple.com", + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(exp), + Audience: []string{"https://appleid.apple.com"}, Subject: ah.conf.ClientID, } diff --git a/v2/provider/apple_pubkeys.go b/v2/provider/apple_pubkeys.go index ce0ccde0..c1975a2a 100644 --- a/v2/provider/apple_pubkeys.go +++ b/v2/provider/apple_pubkeys.go @@ -16,7 +16,7 @@ import ( "net/http" "time" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" ) // appleKeysURL is the endpoint URL for fetch Appleā€™s public key diff --git a/v2/provider/apple_pubkeys_test.go b/v2/provider/apple_pubkeys_test.go index 3b6e810d..3003f0c1 100644 --- a/v2/provider/apple_pubkeys_test.go +++ b/v2/provider/apple_pubkeys_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/v2/provider/apple_test.go b/v2/provider/apple_test.go index 1b47deb1..19d10892 100644 --- a/v2/provider/apple_test.go +++ b/v2/provider/apple_test.go @@ -19,7 +19,7 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -303,7 +303,7 @@ func TestAppleHandler_LoginHandler(t *testing.T) { require.NoError(t, err) t.Log(claims) assert.Equal(t, "go-pkgz/auth", claims.Issuer) - assert.Equal(t, "remark", claims.Audience) + assert.Equal(t, "remark", claims.Audience[0]) } diff --git a/v2/provider/custom_server_test.go b/v2/provider/custom_server_test.go index 945e7728..bcb31265 100644 --- a/v2/provider/custom_server_test.go +++ b/v2/provider/custom_server_test.go @@ -18,7 +18,7 @@ import ( "github.com/go-oauth2/oauth2/v4/models" goauth2 "github.com/go-oauth2/oauth2/v4/server" "github.com/go-oauth2/oauth2/v4/store" - "github.com/golang-jwt/jwt" + oldjwt "github.com/golang-jwt/jwt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -80,7 +80,7 @@ func TestCustomProvider(t *testing.T) { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } - assert.Equal(t, 2, len(resp.Cookies())) + require.Equal(t, 2, len(resp.Cookies())) assert.Equal(t, "JWT", resp.Cookies()[0].Name) assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") assert.Equal(t, 2678400, resp.Cookies()[0].MaxAge) @@ -192,7 +192,7 @@ func initGoauth2Srv(t *testing.T) *goauth2.Server { manager.MustTokenStorage(store.NewMemoryTokenStore()) // generate jwt access token - manager.MapAccessGenerate(generates.NewJWTAccessGenerate("", []byte("00000000"), jwt.SigningMethodHS512)) + manager.MapAccessGenerate(generates.NewJWTAccessGenerate("", []byte("00000000"), oldjwt.SigningMethodHS512)) // client memory store clientStore := store.NewClientStore() diff --git a/v2/provider/direct.go b/v2/provider/direct.go index 19a965e2..7d940177 100644 --- a/v2/provider/direct.go +++ b/v2/provider/direct.go @@ -9,7 +9,7 @@ import ( "time" "github.com/go-pkgz/rest" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "github.com/go-pkgz/auth/v2/logger" "github.com/go-pkgz/auth/v2/token" @@ -120,10 +120,10 @@ func (p DirectHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { claims := token.Claims{ User: &u, - StandardClaims: jwt.StandardClaims{ - Id: cid, + RegisteredClaims: jwt.RegisteredClaims{ + ID: cid, Issuer: p.Issuer, - Audience: creds.Audience, + Audience: []string{creds.Audience}, }, SessionOnly: sessOnly, } diff --git a/v2/provider/direct_test.go b/v2/provider/direct_test.go index 0e8fcdf1..03199b22 100644 --- a/v2/provider/direct_test.go +++ b/v2/provider/direct_test.go @@ -90,9 +90,9 @@ func TestDirect_LoginHandler(t *testing.T) { claims, err := d.TokenService.Parse(c.Value) require.NoError(t, err) t.Logf("%+v", claims) - assert.Equal(t, "xyz123", claims.Audience) + assert.Equal(t, "xyz123", claims.Audience[0]) assert.Equal(t, "iss-test", claims.Issuer) - assert.True(t, claims.ExpiresAt > time.Now().Unix()) + assert.True(t, claims.ExpiresAt.After(time.Now())) assert.Equal(t, "myuser", claims.User.Name) }) } diff --git a/v2/provider/oauth1.go b/v2/provider/oauth1.go index b00e1e1c..6a51b0da 100644 --- a/v2/provider/oauth1.go +++ b/v2/provider/oauth1.go @@ -10,7 +10,7 @@ import ( "github.com/dghubble/oauth1" "github.com/go-pkgz/rest" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "github.com/go-pkgz/auth/v2/logger" "github.com/go-pkgz/auth/v2/token" @@ -55,11 +55,11 @@ func (h Oauth1Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { From: r.URL.Query().Get("from"), }, SessionOnly: r.URL.Query().Get("session") != "" && r.URL.Query().Get("session") != "0", - StandardClaims: jwt.StandardClaims{ - Id: cid, - Audience: r.URL.Query().Get("site"), - ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), - NotBefore: time.Now().Add(-1 * time.Minute).Unix(), + RegisteredClaims: jwt.RegisteredClaims{ + ID: cid, + Audience: []string{r.URL.Query().Get("site")}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(30 * time.Minute)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), }, } @@ -140,9 +140,9 @@ func (h Oauth1Handler) AuthHandler(w http.ResponseWriter, r *http.Request) { } claims := token.Claims{ User: &u, - StandardClaims: jwt.StandardClaims{ + RegisteredClaims: jwt.RegisteredClaims{ Issuer: h.Issuer, - Id: cid, + ID: cid, Audience: oauthClaims.Audience, }, SessionOnly: oauthClaims.SessionOnly, diff --git a/v2/provider/oauth1_test.go b/v2/provider/oauth1_test.go index 2565869f..920c467f 100644 --- a/v2/provider/oauth1_test.go +++ b/v2/provider/oauth1_test.go @@ -62,7 +62,7 @@ func TestOauth1Login(t *testing.T) { require.NoError(t, err) t.Log(claims) assert.Equal(t, "remark42", claims.Issuer) - assert.Equal(t, "remark", claims.Audience) + assert.Equal(t, "remark", claims.Audience[0]) // check admin user resp, err = client.Get(fmt.Sprintf("http://localhost:%d/login?site=remark", loginPort)) diff --git a/v2/provider/oauth2.go b/v2/provider/oauth2.go index 23cb605c..a985cd98 100644 --- a/v2/provider/oauth2.go +++ b/v2/provider/oauth2.go @@ -10,7 +10,7 @@ import ( "time" "github.com/go-pkgz/rest" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "golang.org/x/oauth2" "github.com/go-pkgz/auth/v2/logger" @@ -111,11 +111,11 @@ func (p Oauth2Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { From: r.URL.Query().Get("from"), }, SessionOnly: r.URL.Query().Get("session") != "" && r.URL.Query().Get("session") != "0", - StandardClaims: jwt.StandardClaims{ - Id: cid, - Audience: aud, - ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), - NotBefore: time.Now().Add(-1 * time.Minute).Unix(), + RegisteredClaims: jwt.RegisteredClaims{ + ID: cid, + Audience: []string{aud}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(30 * time.Minute)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), }, NoAva: r.URL.Query().Get("noava") == "1", } @@ -208,9 +208,9 @@ func (p Oauth2Handler) AuthHandler(w http.ResponseWriter, r *http.Request) { } claims := token.Claims{ User: &u, - StandardClaims: jwt.StandardClaims{ + RegisteredClaims: jwt.RegisteredClaims{ Issuer: p.Issuer, - Id: cid, + ID: cid, Audience: oauthClaims.Audience, }, SessionOnly: oauthClaims.SessionOnly, diff --git a/v2/provider/oauth2_test.go b/v2/provider/oauth2_test.go index 5448564b..ce890daa 100644 --- a/v2/provider/oauth2_test.go +++ b/v2/provider/oauth2_test.go @@ -73,7 +73,7 @@ func TestOauth2Login(t *testing.T) { require.NoError(t, err) t.Log(claims) assert.Equal(t, "remark42", claims.Issuer) - assert.Equal(t, "remark", claims.Audience) + assert.Equal(t, "remark", claims.Audience[0]) // check admin user resp, err = client.Get("http://localhost:8981/login?site=remark") diff --git a/v2/provider/telegram.go b/v2/provider/telegram.go index 8ecc2ef9..906fe59f 100644 --- a/v2/provider/telegram.go +++ b/v2/provider/telegram.go @@ -17,7 +17,7 @@ import ( "github.com/go-pkgz/repeater" "github.com/go-pkgz/rest" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "github.com/go-pkgz/auth/v2/logger" authtoken "github.com/go-pkgz/auth/v2/token" @@ -302,12 +302,12 @@ func (th *TelegramHandler) LoginHandler(w http.ResponseWriter, r *http.Request) claims := authtoken.Claims{ User: &u, - StandardClaims: jwt.StandardClaims{ - Audience: r.URL.Query().Get("site"), - Id: queryToken, + RegisteredClaims: jwt.RegisteredClaims{ + Audience: []string{r.URL.Query().Get("site")}, + ID: queryToken, Issuer: th.ProviderName, - ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), - NotBefore: time.Now().Add(-1 * time.Minute).Unix(), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(30 * time.Minute)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), }, SessionOnly: false, // TODO review? } diff --git a/v2/provider/verify.go b/v2/provider/verify.go index a1376af3..0fc9ca6e 100644 --- a/v2/provider/verify.go +++ b/v2/provider/verify.go @@ -10,7 +10,7 @@ import ( "time" "github.com/go-pkgz/rest" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "github.com/go-pkgz/auth/v2/avatar" "github.com/go-pkgz/auth/v2/logger" @@ -111,8 +111,8 @@ func (e VerifyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { claims := token.Claims{ User: &u, - StandardClaims: jwt.StandardClaims{ - Id: cid, + RegisteredClaims: jwt.RegisteredClaims{ + ID: cid, Issuer: e.Issuer, Audience: confClaims.Audience, }, @@ -146,10 +146,10 @@ func (e VerifyHandler) sendConfirmation(w http.ResponseWriter, r *http.Request) ID: user + "::" + address, }, SessionOnly: r.URL.Query().Get("session") != "" && r.URL.Query().Get("session") != "0", - StandardClaims: jwt.StandardClaims{ - Audience: site, - ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), - NotBefore: time.Now().Add(-1 * time.Minute).Unix(), + RegisteredClaims: jwt.RegisteredClaims{ + Audience: []string{site}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(30 * time.Minute)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), Issuer: e.Issuer, }, } diff --git a/v2/provider/verify_test.go b/v2/provider/verify_test.go index 2c298ed6..276600d4 100644 --- a/v2/provider/verify_test.go +++ b/v2/provider/verify_test.go @@ -54,8 +54,8 @@ func TestVerifyHandler_LoginSendConfirm(t *testing.T) { assert.NoError(t, err) t.Logf("%s %+v", tknStr, tkn) assert.Equal(t, "test123::blah@user.com", tkn.Handshake.ID) - assert.Equal(t, "remark42", tkn.Audience) - assert.True(t, tkn.ExpiresAt > tkn.NotBefore) + assert.Equal(t, "remark42", tkn.Audience[0]) + assert.True(t, tkn.ExpiresAt.After(tkn.NotBefore.Time)) assert.Equal(t, "test", e.Name()) } @@ -93,8 +93,8 @@ func TestVerifyHandler_LoginSendConfirmEscapesBadInput(t *testing.T) { t.Logf("%s %+v", tknStr, tkn) // not escaped in these fields as they are not rendered as HTML assert.Equal(t, badData+"::blah@user.com", tkn.Handshake.ID) - assert.Equal(t, badData, tkn.Audience) - assert.True(t, tkn.ExpiresAt > tkn.NotBefore) + assert.Equal(t, badData, tkn.Audience[0]) + assert.True(t, tkn.ExpiresAt.After(tkn.NotBefore.Time)) assert.Equal(t, "test", e.Name()) } @@ -125,9 +125,9 @@ func TestVerifyHandler_LoginAcceptConfirm(t *testing.T) { claims, err := e.TokenService.Parse(c.Value) require.NoError(t, err) t.Logf("%+v", claims) - assert.Equal(t, "remark42", claims.Audience) + assert.Equal(t, "remark42", claims.Audience[0]) assert.Equal(t, "iss-test", claims.Issuer) - assert.True(t, claims.ExpiresAt > time.Now().Unix()) + assert.True(t, claims.ExpiresAt.After(time.Now())) assert.Equal(t, "test123", claims.User.Name) assert.Equal(t, true, claims.SessionOnly) } diff --git a/v2/token/jwt.go b/v2/token/jwt.go index c899b070..1f601526 100644 --- a/v2/token/jwt.go +++ b/v2/token/jwt.go @@ -3,12 +3,14 @@ package token import ( "encoding/json" + "errors" "fmt" "net/http" "strings" + "sync" "time" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" ) // Service wraps jwt operations @@ -19,7 +21,7 @@ type Service struct { // Claims stores user info for token and state & from from login type Claims struct { - jwt.StandardClaims + jwt.RegisteredClaims User *User `json:"user,omitempty"` // user info SessionOnly bool `json:"sess_only,omitempty"` Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake @@ -74,6 +76,11 @@ type Opts struct { // NewService makes JWT service func NewService(opts Opts) *Service { + var once sync.Once + once.Do(func() { + jwt.MarshalSingleStringAsArray = false + }) + res := Service{Opts: opts} setDefault := func(fld *string, def string) { @@ -121,7 +128,7 @@ func (j *Service) Token(claims Claims) (string, error) { return "", fmt.Errorf("aud rejected: %w", err) } - secret, err := j.SecretReader.Get(claims.Audience) // get secret via consumer defined SecretReader + secret, err := j.SecretReader.Get(claims.Audience[0]) // get secret via consumer defined SecretReader if err != nil { return "", fmt.Errorf("can't get secret: %w", err) } @@ -135,7 +142,7 @@ func (j *Service) Token(claims Claims) (string, error) { // Parse token string and verify. Not checking for expiration func (j *Service) Parse(tokenString string) (Claims, error) { - parser := jwt.Parser{SkipClaimsValidation: true} // allow parsing of expired tokens + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) if j.SecretReader == nil { return Claims{}, fmt.Errorf("secret reader not defined") @@ -179,7 +186,7 @@ func (j *Service) Parse(tokenString string) (Claims, error) { // aud pre-parse token and extracts aud from the claim // important! this step ignores token verification, should not be used for any validations func (j *Service) aud(tokenString string) (string, error) { - parser := jwt.Parser{} + parser := jwt.NewParser() token, _, err := parser.ParseUnverified(tokenString, &Claims{}) if err != nil { return "", fmt.Errorf("can't pre-parse token: %w", err) @@ -188,34 +195,44 @@ func (j *Service) aud(tokenString string) (string, error) { if !ok { return "", fmt.Errorf("invalid token") } - if strings.TrimSpace(claims.Audience) == "" { + + if len(claims.Audience) == 0 { + return "", fmt.Errorf("empty aud") + } + aud := claims.Audience[0] + + if strings.TrimSpace(aud) == "" { return "", fmt.Errorf("empty aud") } - return claims.Audience, nil + return aud, nil } func (j *Service) validate(claims *Claims) error { - cerr := claims.Valid() + validator := jwt.NewValidator() + err := validator.Validate(claims) - if cerr == nil { + if err == nil { return nil } - if e, ok := cerr.(*jwt.ValidationError); ok { - if e.Errors == jwt.ValidationErrorExpired { - return nil // allow expired tokens + // Ignore "ErrTokenExpired" if it is the only error. + if errors.Is(err, jwt.ErrTokenExpired) { + if uw, ok := err.(interface{ Unwrap() []error }); ok && len(uw.Unwrap()) == 1 { + return nil } } - return cerr + return err } // Set creates token cookie with xsrf cookie and put it to ResponseWriter // accepts claims and sets expiration if none defined. permanent flag means long-living cookie, // false makes it session only. func (j *Service) Set(w http.ResponseWriter, claims Claims) (Claims, error) { - if claims.ExpiresAt == 0 { - claims.ExpiresAt = time.Now().Add(j.TokenDuration).Unix() + nowUnix := time.Now().Unix() + + if claims.ExpiresAt == nil || claims.ExpiresAt.Time.Unix() == 0 { + claims.ExpiresAt = jwt.NewNumericDate(time.Unix(nowUnix, 0).Add(j.TokenDuration)) } if claims.Issuer == "" { @@ -223,7 +240,7 @@ func (j *Service) Set(w http.ResponseWriter, claims Claims) (Claims, error) { } if !j.DisableIAT { - claims.IssuedAt = time.Now().Unix() + claims.IssuedAt = jwt.NewNumericDate(time.Unix(nowUnix, 0)) } tokenString, err := j.Token(claims) @@ -245,7 +262,7 @@ func (j *Service) Set(w http.ResponseWriter, claims Claims) (Claims, error) { MaxAge: cookieExpiration, Secure: j.SecureCookies, SameSite: j.SameSite} http.SetCookie(w, &jwtCookie) - xsrfCookie := http.Cookie{Name: j.XSRFCookieName, Value: claims.Id, HttpOnly: false, Path: "/", Domain: j.JWTCookieDomain, + xsrfCookie := http.Cookie{Name: j.XSRFCookieName, Value: claims.ID, HttpOnly: false, Path: "/", Domain: j.JWTCookieDomain, MaxAge: cookieExpiration, Secure: j.SecureCookies, SameSite: j.SameSite} http.SetCookie(w, &xsrfCookie) @@ -286,7 +303,10 @@ func (j *Service) Get(r *http.Request) (Claims, string, error) { // promote claim's aud to User.Audience if claims.User != nil { - claims.User.Audience = claims.Audience + if len(claims.Audience) != 1 { + return Claims{}, "", fmt.Errorf("aud is not of size 1") + } + claims.User.Audience = claims.Audience[0] } if !fromCookie && j.IsExpired(claims) { @@ -299,7 +319,7 @@ func (j *Service) Get(r *http.Request) (Claims, string, error) { if fromCookie && claims.User != nil { xsrf := r.Header.Get(j.XSRFHeaderKey) - if claims.Id != xsrf { + if claims.ID != xsrf { return Claims{}, "", fmt.Errorf("xsrf mismatch") } } @@ -309,7 +329,9 @@ func (j *Service) Get(r *http.Request) (Claims, string, error) { // IsExpired returns true if claims expired func (j *Service) IsExpired(claims Claims) bool { - return !claims.VerifyExpiresAt(time.Now().Unix(), true) + validator := jwt.NewValidator(jwt.WithExpirationRequired()) + err := validator.Validate(claims) + return errors.Is(err, jwt.ErrTokenExpired) } // Reset token's cookies @@ -327,25 +349,32 @@ func (j *Service) Reset(w http.ResponseWriter) { // checkAuds verifies if claims.Audience in the list of allowed by audReader func (j *Service) checkAuds(claims *Claims, audReader Audience) error { + // marshal the audience. if audReader == nil { // lack of any allowed means any return nil } + + if len(claims.Audience) == 0 { + return fmt.Errorf("no audience provided") + } + claimsAudience := claims.Audience[0] + auds, err := audReader.Get() if err != nil { return fmt.Errorf("failed to get auds: %w", err) } for _, a := range auds { - if strings.EqualFold(a, claims.Audience) { + if strings.EqualFold(a, claimsAudience) { return nil } } - return fmt.Errorf("aud %q not allowed", claims.Audience) + return fmt.Errorf("aud %q not allowed", claimsAudience) } func (c Claims) String() string { b, err := json.Marshal(c) if err != nil { - return fmt.Sprintf("%+v %+v", c.StandardClaims, c.User) + return fmt.Sprintf("%+v %+v", c.RegisteredClaims, c.User) } return string(b) } diff --git a/v2/token/jwt_test.go b/v2/token/jwt_test.go index e30ee137..6e6db14a 100644 --- a/v2/token/jwt_test.go +++ b/v2/token/jwt_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -17,9 +17,9 @@ import ( // ("secret" in most cases here, "xyz 12345" in makeTestAuth), and alter the fields you want to be changed. var ( - testJwtValid = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19LCJoYW5kc2hha2UiOnsic3RhdGUiOiIxMjM0NTYiLCJmcm9tIjoiZnJvbSIsImlkIjoibXlpZC0xMjM0NTYifX0._2X1cAEoxjLA7XuN8xW8V9r7rYfP_m9lSRz_9_UFzac" - testJwtValidNoHandshake = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19fQ.OWPdibrSSSHuOV3DzzLH5soO6kUcERELL7_GLf7Ja_E" - testJwtValidSess = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19LCJzZXNzX29ubHkiOnRydWV9.SjPlVgca_bijC2wbaite2_eNHk66VXgsxUKLy7eqlXM" + testJwtValid = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJyZW1hcms0MiIsImF1ZCI6InRlc3Rfc3lzIiwiZXhwIjoyNzg5MTkxODIyLCJuYmYiOjE1MjY4ODQyMjIsImp0aSI6InJhbmRvbSBpZCIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19LCJoYW5kc2hha2UiOnsic3RhdGUiOiIxMjM0NTYiLCJmcm9tIjoiZnJvbSIsImlkIjoibXlpZC0xMjM0NTYifX0.Ln7P2rEO-kWLN8AuKddWzjKC9l_kpw_yWfSO12MYo0o" + testJwtValidNoHandshake = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJyZW1hcms0MiIsImF1ZCI6InRlc3Rfc3lzIiwiZXhwIjoyNzg5MTkxODIyLCJuYmYiOjE1MjY4ODQyMjIsImp0aSI6InJhbmRvbSBpZCIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19fQ.D7fO3tzq3y-uSnh3Mae-Mqp8w9WdkH9s4zPTh44k8Gs" + testJwtValidSess = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJyZW1hcms0MiIsImF1ZCI6InRlc3Rfc3lzIiwiZXhwIjoyNzg5MTkxODIyLCJuYmYiOjE1MjY4ODQyMjIsImp0aSI6InJhbmRvbSBpZCIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19LCJzZXNzX29ubHkiOnRydWV9.RtQ6uBksqtMTd9GDLJen_eDUlLAYLh9uH0GBO_OIf4M" testJwtExpired = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE1MjY4ODc4MjIsImp0aSI6InJhbmRvbSBpZCIs" + "ImlzcyI6InJlbWFyazQyIiwibmJmIjoxNTI2ODg0MjIyLCJ1c2VyIjp7Im5hbWUiOiJuYW1lMSIsImlkIjoiaWQxIiwicGljdHVyZSI6IiI" + "sImFkbWluIjpmYWxzZX0sInN0YXRlIjoiMTIzNDU2IiwiZnJvbSI6ImZyb20ifQ.4_dCrY9ihyfZIedz-kZwBTxmxU1a52V7IqeJrOqTzE4" @@ -88,6 +88,11 @@ func TestJWT_Token(t *testing.T) { assert.NoError(t, err) assert.Equal(t, testJwtValid, res) + newClaims, _ := j.Parse(res) + assert.Equal(t, claims, newClaims) + fmt.Println(claims) + fmt.Println(newClaims) + j.SecretReader = nil _, err = j.Token(claims) assert.EqualError(t, err, "secret reader not defined") @@ -126,10 +131,10 @@ func TestJWT_Parse(t *testing.T) { assert.Error(t, err, "bad token") _, err = j.Parse(testJwtBadSign) - assert.EqualError(t, err, "can't parse token: signature is invalid") + assert.EqualError(t, err, "can't parse token: token signature is invalid: signature is invalid") _, err = j.Parse(testJwtNoneAlg) - assert.EqualError(t, err, "can't parse token: unexpected signing method: none") + assert.EqualError(t, err, "can't parse token: token is unverifiable: error while executing keyfunc: unexpected signing method: none") j = NewService(Opts{ SecretReader: SecretFunc(func(string) (string, error) { return "bad 12345", nil }), @@ -271,7 +276,7 @@ func TestJWT_SetProlonged(t *testing.T) { claims := testClaims claims.Handshake = nil - claims.ExpiresAt = 0 + claims.ExpiresAt = nil rr := httptest.NewRecorder() _, err := j.Set(rr, claims) @@ -282,7 +287,7 @@ func TestJWT_SetProlonged(t *testing.T) { cc, err := j.Parse(cookies[0].Value) assert.NoError(t, err) - assert.True(t, cc.ExpiresAt > time.Now().Unix()) + assert.True(t, cc.ExpiresAt.After(time.Now().UTC())) } func TestJWT_NoIssuer(t *testing.T) { @@ -345,7 +350,7 @@ func TestJWT_GetFromHeader(t *testing.T) { req.Header.Add(jwtCustomHeaderKey, "bad bad token") _, _, err = j.Get(req) require.NotNil(t, err) - assert.True(t, strings.Contains(err.Error(), "failed to get token: can't parse token: token contains an invalid number of segments"), err.Error()) + assert.True(t, strings.Contains(err.Error(), "failed to get token: can't parse token: token is malformed: token contains an invalid number of segments"), err.Error()) } func TestJWT_GetFromQuery(t *testing.T) { @@ -375,7 +380,7 @@ func TestJWT_GetFromQuery(t *testing.T) { req = httptest.NewRequest("GET", "/blah?token=blah", nil) _, _, err = j.Get(req) require.NotNil(t, err) - assert.True(t, strings.Contains(err.Error(), "failed to get token: can't parse token: token contains an invalid number of segments"), err.Error()) + assert.True(t, strings.Contains(err.Error(), "failed to get token: can't parse token: token is malformed: token contains an invalid number of segments"), err.Error()) } func TestJWT_GetFailed(t *testing.T) { @@ -467,7 +472,7 @@ func TestJWT_SetAndGetWithXsrfMismatch(t *testing.T) { req.Header.Add(xsrfCustomHeaderKey, "random id wrong") c, _, err := j.Get(req) require.NoError(t, err, "xsrf mismatch, but ignored") - claims.User.Audience = c.Audience // set aud to user because we don't do the normal Get call + claims.User.Audience = c.Audience[0] // set aud to user because we don't do the normal Get call assert.Equal(t, claims, c) } @@ -485,8 +490,8 @@ func TestJWT_SetAndGetWithCookiesExpired(t *testing.T) { }) claims := testClaims - claims.StandardClaims.ExpiresAt = time.Date(2018, 5, 21, 1, 35, 22, 0, time.Local).Unix() - claims.StandardClaims.NotBefore = time.Date(2018, 5, 21, 1, 30, 22, 0, time.Local).Unix() + claims.RegisteredClaims.ExpiresAt = jwt.NewNumericDate(time.Date(2018, 5, 21, 1, 35, 22, 0, time.Local)) + claims.RegisteredClaims.NotBefore = jwt.NewNumericDate(time.Date(2018, 5, 21, 1, 30, 22, 0, time.Local)) claims.SessionOnly = true ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -560,8 +565,8 @@ func TestAudience(t *testing.T) { }) c := Claims{ - StandardClaims: jwt.StandardClaims{ - Audience: "au1", + RegisteredClaims: jwt.RegisteredClaims{ + Audience: []string{"au1"}, Issuer: "test iss", }, } @@ -592,7 +597,7 @@ func TestAudReader(t *testing.T) { assert.EqualError(t, err, "empty aud") _, err = j.aud("blah bad bad") - assert.EqualError(t, err, "can't pre-parse token: token contains an invalid number of segments") + assert.EqualError(t, err, "can't pre-parse token: token is malformed: token contains an invalid number of segments") } func TestParseWithAud(t *testing.T) { @@ -608,19 +613,19 @@ func TestParseWithAud(t *testing.T) { claims, err = j.Parse(testJwtValidAud) assert.NoError(t, err) - assert.Equal(t, "test_aud_only", claims.Audience) + assert.Equal(t, "test_aud_only", claims.Audience[0]) claims, err = j.Parse(testJwtNonAudSign) - assert.EqualError(t, err, "can't parse token: signature is invalid") + assert.EqualError(t, err, "can't parse token: token signature is invalid: signature is invalid") } var testClaims = Claims{ - StandardClaims: jwt.StandardClaims{ - Id: "random id", + RegisteredClaims: jwt.RegisteredClaims{ + ID: "random id", Issuer: "remark42", - Audience: "test_sys", - ExpiresAt: time.Date(2058, 5, 21, 7, 30, 22, 0, time.UTC).Unix(), - NotBefore: time.Date(2018, 5, 21, 6, 30, 22, 0, time.UTC).Unix(), + Audience: []string{"test_sys"}, + ExpiresAt: jwt.NewNumericDate(time.Date(2058, 5, 21, 7, 30, 22, 0, time.UTC).Local()), + NotBefore: jwt.NewNumericDate(time.Date(2018, 5, 21, 6, 30, 22, 0, time.UTC).Local()), }, User: &User{