From ee7e5e8fe0dde61d2051867a80c41ac5cc4e11ed Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 21 Jul 2023 09:28:46 +1000 Subject: [PATCH] feat: filtering mode --- handler/oauth2/flow_refresh.go | 11 +- integration/refresh_token_grant_test.go | 398 +++++++++++++++--------- 2 files changed, 265 insertions(+), 144 deletions(-) diff --git a/handler/oauth2/flow_refresh.go b/handler/oauth2/flow_refresh.go index 659c4ea34..95cef6078 100644 --- a/handler/oauth2/flow_refresh.go +++ b/handler/oauth2/flow_refresh.go @@ -29,6 +29,11 @@ type RefreshTokenGrantHandler struct { fosite.AudienceStrategyProvider fosite.RefreshTokenScopesProvider } + + // IgnoreRequestedScopeNotInOriginalGrant determines the action to take when the requested scopes in the refresh + // flow were not originally granted. If false which is the default the handler will automatically return an error. + // If true the handler will filter out / ignore the scopes which were not originally granted. + IgnoreRequestedScopeNotInOriginalGrant bool } // HandleTokenEndpointRequest implements https://tools.ietf.org/html/rfc6749#section-6 @@ -105,7 +110,11 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex for _, scope := range request.GetRequestedScopes() { // Addresses point 2 of the text in RFC6749 Section 6. if !strategy(originalScopes, scope) { - return errorsx.WithStack(fosite.ErrInvalidScope.WithHintf("The requested scope '%s' was not originally granted by the resource owner.", scope)) + if c.IgnoreRequestedScopeNotInOriginalGrant { + continue + } else { + return errorsx.WithStack(fosite.ErrInvalidScope.WithHintf("The requested scope '%s' was not originally granted by the resource owner.", scope)) + } } if !strategy(request.GetClient().GetScopes(), scope) { diff --git a/integration/refresh_token_grant_test.go b/integration/refresh_token_grant_test.go index c91c8be61..44b09123d 100644 --- a/integration/refresh_token_grant_test.go +++ b/integration/refresh_token_grant_test.go @@ -6,6 +6,7 @@ package integration_test import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "net/url" @@ -19,6 +20,7 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/compose" + hoauth2 "github.com/ory/fosite/handler/oauth2" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/internal/gen" "github.com/ory/fosite/token/jwt" @@ -267,84 +269,11 @@ func TestRefreshTokenFlow(t *testing.T) { } func TestRefreshTokenFlowScopeParameter(t *testing.T) { - ctx := context.Background() - - session := &defaultSession{ - DefaultSession: &openid.DefaultSession{ - Claims: &jwt.IDTokenClaims{ - Subject: "peter", - }, - Headers: &jwt.Headers{}, - Subject: "peter", - Username: "peteru", - }, - } - fc := new(fosite.Config) - fc.GlobalSecret = []byte("some-secret-thats-random-some-secret-thats-random-") - f := compose.ComposeAllEnabled(fc, fositeStore, gen.MustRSAKey()) - ts := mockServer(t, f, session) - defer ts.Close() - - fc.ScopeStrategy = fosite.ExactScopeStrategy - - client := newOAuth2Client(ts) - client.Scopes = []string{"openid", "offline", "offline_access", "foo", "bar"} - client.ClientID = "grant-all-requested-scopes-client" - - state := "1234567890" - - testRefreshingClient := &fosite.DefaultClient{ - ID: "grant-all-requested-scopes-client", - Secret: []byte(`$2a$10$IxMdI6d.LIRZPpSfEwNoeu4rY3FhDREsxFJXikcgdRRAStxUlsuEO`), // = "foobar" - RedirectURIs: []string{ts.URL + "/callback"}, - ResponseTypes: []string{"code"}, - GrantTypes: []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"}, - Scopes: []string{"openid", "offline_access", "offline", "foo", "bar", "baz"}, - Audience: []string{"https://www.ory.sh/api"}, - } - - fositeStore.Clients["grant-all-requested-scopes-client"] = testRefreshingClient - - s := compose.NewOAuth2HMACStrategy(fc) - - originalScopes := fosite.Arguments{"openid", "offline", "offline_access", "foo", "bar"} - - testCases := []struct { + type testCase struct { name string scopes fosite.Arguments expected fosite.Arguments err string - }{ - { - "ShouldGrantOriginalScopesWhenOmitted", - nil, - originalScopes, - "", - }, - { - "ShouldNarrowScopesWhenIncluded", - fosite.Arguments{"openid", "offline_access", "foo"}, - fosite.Arguments{"openid", "offline_access", "foo"}, - "", - }, - { - "ShouldGrantOriginalScopesWhenOmittedAfterNarrowing", - nil, - originalScopes, - "", - }, - { - "ShouldGrantOriginalScopesExplicitlyRequested", - originalScopes, - originalScopes, - "", - }, - { - "ShouldErrorWhenBroadeningScopesAllowedByClientButNotOriginallyGranted", - fosite.Arguments{"openid", "offline", "offline_access", "foo", "bar", "baz"}, - nil, - "The requested scope is invalid, unknown, or malformed. The requested scope 'baz' was not originally granted by the resource owner.", - }, } type step struct { @@ -352,98 +281,281 @@ func TestRefreshTokenFlowScopeParameter(t *testing.T) { SessionAT, SessionRT fosite.Requester } - entries := make([]step, len(testCases)+1) - - resp, err := http.Get(client.AuthCodeURL(state)) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) + originalScopes := fosite.Arguments{"openid", "offline", "offline_access", "foo", "bar"} - entries[0].OAuth2, err = client.Exchange(ctx, resp.Request.URL.Query().Get("code"), oauth2.SetAuthURLParam("client_id", client.ClientID)) + scenarios := []struct { + name string + ignore bool + checkTime bool + testCases []testCase + }{ + { + "ShouldPassRFC", + false, + true, + []testCase{ + { + "ShouldGrantOriginalScopesWhenOmitted", + nil, + originalScopes, + "", + }, + { + "ShouldNarrowScopesWhenIncluded", + fosite.Arguments{"openid", "offline_access", "foo"}, + fosite.Arguments{"openid", "offline_access", "foo"}, + "", + }, + { + "ShouldGrantOriginalScopesWhenOmittedAfterNarrowing", + nil, + originalScopes, + "", + }, + { + "ShouldGrantOriginalScopesExplicitlyRequested", + originalScopes, + originalScopes, + "", + }, + { + "ShouldErrorWhenBroadeningScopesAllowedByClientButNotOriginallyGranted", + fosite.Arguments{"openid", "offline", "offline_access", "foo", "bar", "baz"}, + nil, + "The requested scope is invalid, unknown, or malformed. The requested scope 'baz' was not originally granted by the resource owner.", + }, + }, + }, + { + "ShouldPassIgnoreFilter", + true, + false, + []testCase{ + { + "ShouldGrantOriginalScopesWhenOmitted", + nil, + originalScopes, + "", + }, + { + "ShouldNarrowScopesWhenIncluded", + fosite.Arguments{"openid", "offline_access", "foo"}, + fosite.Arguments{"openid", "offline_access", "foo"}, + "", + }, + { + "ShouldGrantOriginalScopesWhenOmittedAfterNarrowing", + nil, + originalScopes, + "", + }, + { + "ShouldGrantOriginalScopesExplicitlyRequested", + originalScopes, + originalScopes, + "", + }, + { + "ShouldErrorWhenBroadeningScopesAllowedByClientButNotOriginallyGranted", + fosite.Arguments{"openid", "offline", "offline_access", "foo", "bar", "baz"}, + fosite.Arguments{"openid", "offline", "offline_access", "foo", "bar"}, + "", + }, + }, + }, + } - require.NoError(t, err) - require.NotEmpty(t, entries[0].OAuth2.AccessToken) - require.NotEmpty(t, entries[0].OAuth2.RefreshToken) + state := "1234567890" - assert.Equal(t, strings.Join(originalScopes, " "), entries[0].OAuth2.Extra("scope")) + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + ctx := context.Background() + + session := &defaultSession{ + DefaultSession: &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Subject: "peter", + }, + Headers: &jwt.Headers{}, + Subject: "peter", + Username: "peteru", + }, + } - entries[0].SessionAT, err = fositeStore.GetAccessTokenSession(ctx, s.AccessTokenSignature(ctx, entries[0].OAuth2.AccessToken), nil) - require.NoError(t, err) + fc := new(fosite.Config) + fc.GlobalSecret = []byte("some-secret-thats-random-some-secret-thats-random-") + fc.ScopeStrategy = fosite.ExactScopeStrategy - entries[0].SessionRT, err = fositeStore.GetRefreshTokenSession(ctx, s.RefreshTokenSignature(ctx, entries[0].OAuth2.RefreshToken), nil) - require.NoError(t, err) + s := compose.NewOAuth2HMACStrategy(fc) - assert.ElementsMatch(t, entries[0].SessionAT.GetRequestedScopes(), originalScopes) - assert.ElementsMatch(t, entries[0].SessionRT.GetRequestedScopes(), originalScopes) - assert.ElementsMatch(t, entries[0].SessionAT.GetGrantedScopes(), originalScopes) - assert.ElementsMatch(t, entries[0].SessionRT.GetGrantedScopes(), originalScopes) - assert.Equal(t, strings.Join(originalScopes, " "), entries[0].OAuth2.Extra("scope")) + var f fosite.OAuth2Provider - for i, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - time.Sleep(time.Second) + if scenario.ignore { + fmt.Println("Ignore Mode") + keyGetter := func(context.Context) (interface{}, error) { + return gen.MustRSAKey(), nil + } - idx := i + 1 + // OAuth2RefreshTokenGrantFactory creates an OAuth2 refresh grant handler and registers + // an access token, refresh token and authorize code validator.nmj + factoryRefresh := func(config fosite.Configurator, storage interface{}, strategy interface{}) interface{} { + return &hoauth2.RefreshTokenGrantHandler{ + AccessTokenStrategy: strategy.(hoauth2.AccessTokenStrategy), + RefreshTokenStrategy: strategy.(hoauth2.RefreshTokenStrategy), + TokenRevocationStorage: storage.(hoauth2.TokenRevocationStorage), + Config: config, + IgnoreRequestedScopeNotInOriginalGrant: true, + } + } - opts := []oauth2.AuthCodeOption{ - oauth2.SetAuthURLParam("refresh_token", entries[i].OAuth2.RefreshToken), - oauth2.SetAuthURLParam("grant_type", "refresh_token"), + f = compose.Compose( + fc, + fositeStore, + &compose.CommonStrategy{ + CoreStrategy: compose.NewOAuth2HMACStrategy(fc), + OpenIDConnectTokenStrategy: compose.NewOpenIDConnectStrategy(keyGetter, fc), + Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter}, + }, + compose.OAuth2AuthorizeExplicitFactory, + compose.OAuth2AuthorizeImplicitFactory, + compose.OAuth2ClientCredentialsGrantFactory, + factoryRefresh, + compose.OAuth2ResourceOwnerPasswordCredentialsFactory, + compose.RFC7523AssertionGrantFactory, + + compose.OpenIDConnectExplicitFactory, + compose.OpenIDConnectImplicitFactory, + compose.OpenIDConnectHybridFactory, + compose.OpenIDConnectRefreshFactory, + + compose.OAuth2TokenIntrospectionFactory, + compose.OAuth2TokenRevocationFactory, + + compose.OAuth2PKCEFactory, + compose.PushedAuthorizeHandlerFactory, + ) + } else { + fmt.Println("Error Mode") + f = compose.ComposeAllEnabled(fc, fositeStore, gen.MustRSAKey()) } - if len(tc.scopes) != 0 { - opts = append(opts, oauth2.SetAuthURLParam("scope", strings.Join(tc.scopes, " ")), oauth2.SetAuthURLParam("client_id", client.ClientID)) + ts := mockServer(t, f, session) + defer ts.Close() + + client := newOAuth2Client(ts) + client.Scopes = []string{"openid", "offline", "offline_access", "foo", "bar"} + client.ClientID = "grant-all-requested-scopes-client" + + testRefreshingClient := &fosite.DefaultClient{ + ID: "grant-all-requested-scopes-client", + Secret: []byte(`$2a$10$IxMdI6d.LIRZPpSfEwNoeu4rY3FhDREsxFJXikcgdRRAStxUlsuEO`), // = "foobar" + RedirectURIs: []string{ts.URL + "/callback"}, + ResponseTypes: []string{"code"}, + GrantTypes: []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"}, + Scopes: []string{"openid", "offline_access", "offline", "foo", "bar", "baz"}, + Audience: []string{"https://www.ory.sh/api"}, } - entries[idx].OAuth2, err = client.Exchange(ctx, "", opts...) - if len(tc.err) != 0 { - require.Error(t, err) - require.Nil(t, entries[idx].OAuth2) - require.Contains(t, err.Error(), tc.err) + fositeStore.Clients["grant-all-requested-scopes-client"] = testRefreshingClient - return - } + entries := make([]step, len(scenario.testCases)+1) + resp, err := http.Get(client.AuthCodeURL(state)) require.NoError(t, err) - require.NotEmpty(t, entries[idx].OAuth2.AccessToken) - require.NotEmpty(t, entries[idx].OAuth2.RefreshToken) + require.Equal(t, http.StatusOK, resp.StatusCode) - entries[idx].SessionAT, err = fositeStore.GetAccessTokenSession(ctx, s.AccessTokenSignature(ctx, entries[idx].OAuth2.AccessToken), nil) - require.NoError(t, err) + entries[0].OAuth2, err = client.Exchange(ctx, resp.Request.URL.Query().Get("code"), oauth2.SetAuthURLParam("client_id", client.ClientID)) - entries[idx].SessionRT, err = fositeStore.GetRefreshTokenSession(ctx, s.RefreshTokenSignature(ctx, entries[idx].OAuth2.RefreshToken), nil) require.NoError(t, err) + require.NotEmpty(t, entries[0].OAuth2.AccessToken) + require.NotEmpty(t, entries[0].OAuth2.RefreshToken) - if len(tc.scopes) != 0 { - assert.ElementsMatch(t, entries[idx].SessionAT.GetRequestedScopes(), tc.scopes) - assert.Equal(t, strings.Join(tc.expected, " "), entries[idx].OAuth2.Extra("scope")) - } else { - assert.ElementsMatch(t, entries[idx].SessionAT.GetRequestedScopes(), originalScopes) - assert.Equal(t, strings.Join(originalScopes, " "), entries[idx].OAuth2.Extra("scope")) - } - assert.ElementsMatch(t, entries[idx].SessionAT.GetGrantedScopes(), tc.expected) - assert.ElementsMatch(t, entries[idx].SessionRT.GetRequestedScopes(), originalScopes) - assert.ElementsMatch(t, entries[idx].SessionRT.GetGrantedScopes(), originalScopes) - - var ( - j int - entry step - ) + assert.Equal(t, strings.Join(originalScopes, " "), entries[0].OAuth2.Extra("scope")) - assert.Equal(t, entries[idx].SessionAT.GetID(), entries[idx].SessionRT.GetID()) - - for j, entry = range entries { - if j == idx { - break - } + entries[0].SessionAT, err = fositeStore.GetAccessTokenSession(ctx, s.AccessTokenSignature(ctx, entries[0].OAuth2.AccessToken), nil) + require.NoError(t, err) - assert.Equal(t, entries[idx].SessionAT.GetID(), entry.SessionAT.GetID()) - assert.Equal(t, entries[idx].SessionAT.GetID(), entry.SessionRT.GetID()) - assert.Equal(t, entries[idx].SessionRT.GetID(), entry.SessionAT.GetID()) - assert.Equal(t, entries[idx].SessionRT.GetID(), entry.SessionRT.GetID()) + entries[0].SessionRT, err = fositeStore.GetRefreshTokenSession(ctx, s.RefreshTokenSignature(ctx, entries[0].OAuth2.RefreshToken), nil) + require.NoError(t, err) - assert.Greater(t, entries[idx].SessionAT.GetSession().GetExpiresAt(fosite.AccessToken).Unix(), entry.SessionAT.GetSession().GetExpiresAt(fosite.AccessToken).Unix()) - assert.Greater(t, entries[idx].SessionRT.GetSession().GetExpiresAt(fosite.RefreshToken).Unix(), entry.SessionRT.GetSession().GetExpiresAt(fosite.RefreshToken).Unix()) - assert.Greater(t, entries[idx].SessionAT.GetRequestedAt().Unix(), entry.SessionAT.GetRequestedAt().Unix()) - assert.Greater(t, entries[idx].SessionRT.GetRequestedAt().Unix(), entry.SessionRT.GetRequestedAt().Unix()) + assert.ElementsMatch(t, entries[0].SessionAT.GetRequestedScopes(), originalScopes) + assert.ElementsMatch(t, entries[0].SessionRT.GetRequestedScopes(), originalScopes) + assert.ElementsMatch(t, entries[0].SessionAT.GetGrantedScopes(), originalScopes) + assert.ElementsMatch(t, entries[0].SessionRT.GetGrantedScopes(), originalScopes) + assert.Equal(t, strings.Join(originalScopes, " "), entries[0].OAuth2.Extra("scope")) + + for i, tc := range scenario.testCases { + t.Run(tc.name, func(t *testing.T) { + if scenario.checkTime { + time.Sleep(time.Second) + } + + idx := i + 1 + + opts := []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("refresh_token", entries[i].OAuth2.RefreshToken), + oauth2.SetAuthURLParam("grant_type", "refresh_token"), + } + + if len(tc.scopes) != 0 { + opts = append(opts, oauth2.SetAuthURLParam("scope", strings.Join(tc.scopes, " ")), oauth2.SetAuthURLParam("client_id", client.ClientID)) + } + + entries[idx].OAuth2, err = client.Exchange(ctx, "", opts...) + if len(tc.err) != 0 { + require.Error(t, err) + require.Nil(t, entries[idx].OAuth2) + require.Contains(t, err.Error(), tc.err) + + return + } + + require.NoError(t, err) + require.NotEmpty(t, entries[idx].OAuth2.AccessToken) + require.NotEmpty(t, entries[idx].OAuth2.RefreshToken) + + entries[idx].SessionAT, err = fositeStore.GetAccessTokenSession(ctx, s.AccessTokenSignature(ctx, entries[idx].OAuth2.AccessToken), nil) + require.NoError(t, err) + + entries[idx].SessionRT, err = fositeStore.GetRefreshTokenSession(ctx, s.RefreshTokenSignature(ctx, entries[idx].OAuth2.RefreshToken), nil) + require.NoError(t, err) + + if len(tc.scopes) != 0 { + assert.ElementsMatch(t, entries[idx].SessionAT.GetRequestedScopes(), tc.scopes) + assert.Equal(t, strings.Join(tc.expected, " "), entries[idx].OAuth2.Extra("scope")) + } else { + assert.ElementsMatch(t, entries[idx].SessionAT.GetRequestedScopes(), originalScopes) + assert.Equal(t, strings.Join(originalScopes, " "), entries[idx].OAuth2.Extra("scope")) + } + assert.ElementsMatch(t, entries[idx].SessionAT.GetGrantedScopes(), tc.expected) + assert.ElementsMatch(t, entries[idx].SessionRT.GetRequestedScopes(), originalScopes) + assert.ElementsMatch(t, entries[idx].SessionRT.GetGrantedScopes(), originalScopes) + + var ( + j int + entry step + ) + + assert.Equal(t, entries[idx].SessionAT.GetID(), entries[idx].SessionRT.GetID()) + + for j, entry = range entries { + if j == idx { + break + } + + assert.Equal(t, entries[idx].SessionAT.GetID(), entry.SessionAT.GetID()) + assert.Equal(t, entries[idx].SessionAT.GetID(), entry.SessionRT.GetID()) + assert.Equal(t, entries[idx].SessionRT.GetID(), entry.SessionAT.GetID()) + assert.Equal(t, entries[idx].SessionRT.GetID(), entry.SessionRT.GetID()) + + if scenario.checkTime { + assert.Greater(t, entries[idx].SessionAT.GetSession().GetExpiresAt(fosite.AccessToken).Unix(), entry.SessionAT.GetSession().GetExpiresAt(fosite.AccessToken).Unix()) + assert.Greater(t, entries[idx].SessionRT.GetSession().GetExpiresAt(fosite.RefreshToken).Unix(), entry.SessionRT.GetSession().GetExpiresAt(fosite.RefreshToken).Unix()) + assert.Greater(t, entries[idx].SessionAT.GetRequestedAt().Unix(), entry.SessionAT.GetRequestedAt().Unix()) + assert.Greater(t, entries[idx].SessionRT.GetRequestedAt().Unix(), entry.SessionRT.GetRequestedAt().Unix()) + } + } + }) } }) }