Skip to content

Commit

Permalink
feat: support multiple token URLs
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed Dec 13, 2023
1 parent dfa2c0a commit fc01536
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 28 deletions.
44 changes: 27 additions & 17 deletions client_authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"fmt"
"net/http"
"net/url"
"strings"
"time"

"github.com/ory/x/errorsx"
Expand Down Expand Up @@ -149,7 +150,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
var jti string
if !claims.VerifyIssuer(clientID, true) {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client."))
} else if f.Config.GetTokenURL(ctx) == "" {
} else if len(f.Config.GetTokenURLs(ctx)) == 0 {
return nil, errorsx.WithStack(ErrMisconfiguration.WithHint("The authorization server's token endpoint URL has not been set."))
} else if sub, ok := claims["sub"].(string); !ok || sub != clientID {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client."))
Expand Down Expand Up @@ -180,22 +181,10 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
return nil, err
}

if auds, ok := claims["aud"].([]interface{}); !ok {
if !claims.VerifyAudience(f.Config.GetTokenURL(ctx), true) {
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("Claim 'audience' from 'client_assertion' must match the authorization server's token endpoint '%s'.", f.Config.GetTokenURL(ctx)))
}
} else {
var found bool
for _, aud := range auds {
if a, ok := aud.(string); ok && a == f.Config.GetTokenURL(ctx) {
found = true
break
}
}

if !found {
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("Claim 'audience' from 'client_assertion' must match the authorization server's token endpoint '%s'.", f.Config.GetTokenURL(ctx)))
}
if !audienceMatchesTokenURLs(claims, f.Config.GetTokenURLs(ctx)) {
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf(
"Claim 'audience' from 'client_assertion' must match the authorization server's token endpoint '%s'.",
strings.Join(f.Config.GetTokenURLs(ctx), "' or '")))
}

return client, nil
Expand Down Expand Up @@ -235,6 +224,27 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
return client, nil
}

func audienceMatchesTokenURLs(claims jwt.MapClaims, tokenURLs []string) bool {
for _, tokenURL := range tokenURLs {
if audienceMatchesTokenURL(claims, tokenURL) {
return true
}
}
return false
}

func audienceMatchesTokenURL(claims jwt.MapClaims, tokenURL string) bool {
if audiences, ok := claims["aud"].([]interface{}); ok {
for _, aud := range audiences {
if a, ok := aud.(string); ok && a == tokenURL {
return true
}
}
return false
}
return claims.VerifyAudience(tokenURL, true)
}

func (f *Fosite) checkClientSecret(ctx context.Context, client Client, clientSecret []byte) error {
var err error
err = f.Config.GetSecretsHasher(ctx).Compare(ctx, client.GetHashedSecret(), clientSecret)
Expand Down
4 changes: 2 additions & 2 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ type FormPostHTMLTemplateProvider interface {
}

type TokenURLProvider interface {
// GetTokenURL returns the token URL.
GetTokenURL(ctx context.Context) string
// GetTokenURLs returns the token URL.
GetTokenURLs(ctx context.Context) []string
}

// AuthorizeEndpointHandlersProvider returns the provider for configuring the authorize endpoint handlers.
Expand Down
4 changes: 2 additions & 2 deletions config_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ func (c *Config) GetSecretsHasher(ctx context.Context) Hasher {
return c.ClientSecretsHasher
}

func (c *Config) GetTokenURL(ctx context.Context) string {
return c.TokenURL
func (c *Config) GetTokenURLs(ctx context.Context) []string {
return []string{c.TokenURL}
}

func (c *Config) GetFormPostHTMLTemplate(ctx context.Context) *template.Template {
Expand Down
18 changes: 13 additions & 5 deletions handler/rfc7523/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package rfc7523

import (
"context"
"strings"
"time"

"github.com/ory/fosite/handler/oauth2"
Expand Down Expand Up @@ -228,13 +229,11 @@ func (c *Handler) validateTokenClaims(ctx context.Context, claims jwt.Claims, ke
)
}

if !claims.Audience.Contains(c.Config.GetTokenURL(ctx)) {
if !audienceMatchesTokenURLs(claims, c.Config.GetTokenURLs(ctx)) {
return errorsx.WithStack(fosite.ErrInvalidGrant.
WithHintf(
"The JWT in \"assertion\" request parameter MUST contain an \"aud\" (audience) claim containing a value \"%s\" that identifies the authorization server as an intended audience.",
c.Config.GetTokenURL(ctx),
),
)
`The JWT in "assertion" request parameter MUST contain an "aud" (audience) claim containing a value "%s" that identifies the authorization server as an intended audience.`,
strings.Join(c.Config.GetTokenURLs(ctx), `" or "`)))
}

if claims.Expiry == nil {
Expand Down Expand Up @@ -299,6 +298,15 @@ func (c *Handler) validateTokenClaims(ctx context.Context, claims jwt.Claims, ke
return nil
}

func audienceMatchesTokenURLs(claims jwt.Claims, tokenURLs []string) bool {
for _, tokenURL := range tokenURLs {
if claims.Audience.Contains(tokenURL) {
return true
}
}
return false
}

type extendedSession interface {
Session
fosite.Session
Expand Down
5 changes: 3 additions & 2 deletions handler/rfc7523/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
mrand "math/rand"
"net/url"
"strconv"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -331,8 +332,8 @@ func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestNotValidAudienceInAsserti
s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because of invalid audience claim in assertion")
s.Equal(
fmt.Sprintf(
"The JWT in \"assertion\" request parameter MUST contain an \"aud\" (audience) claim containing a value \"%s\" that identifies the authorization server as an intended audience.",
s.handler.Config.GetTokenURL(ctx),
`The JWT in "assertion" request parameter MUST contain an "aud" (audience) claim containing a value "%s" that identifies the authorization server as an intended audience.`,
strings.Join(s.handler.Config.GetTokenURLs(ctx), `" or "`),
),
fosite.ErrorToRFC6749Error(err).HintField,
)
Expand Down

0 comments on commit fc01536

Please sign in to comment.