From 178e39035ec55453219c501923c5db741d3caa7e Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Thu, 6 Oct 2022 15:48:00 -0700 Subject: [PATCH] Make call to auth metadata optional (#327) * Make call to auth metadata optional Signed-off-by: Katrina Rogan * debug Signed-off-by: Katrina Rogan * revert Signed-off-by: Katrina Rogan * undeprecate Signed-off-by: Katrina Rogan * Add unit tests Signed-off-by: Katrina Rogan * codecov is not very good Signed-off-by: Katrina Rogan Signed-off-by: Katrina Rogan --- flyteidl/clients/go/admin/auth_interceptor.go | 12 +++-- .../clients/go/admin/auth_interceptor_test.go | 51 +++++++++++++++++++ flyteidl/clients/go/admin/client.go | 12 +++-- flyteidl/clients/go/admin/client_test.go | 31 +++++++++++ flyteidl/clients/go/admin/config.go | 3 +- flyteidl/clients/go/admin/config_flags.go | 2 +- .../clients/go/admin/config_flags_test.go | 2 +- flyteidl/clients/go/admin/integration_test.go | 3 +- flyteidl/go.mod | 1 + flyteidl/go.sum | 2 + 10 files changed, 106 insertions(+), 13 deletions(-) diff --git a/flyteidl/clients/go/admin/auth_interceptor.go b/flyteidl/clients/go/admin/auth_interceptor.go index 400293e937..b1ede68dbd 100644 --- a/flyteidl/clients/go/admin/auth_interceptor.go +++ b/flyteidl/clients/go/admin/auth_interceptor.go @@ -27,9 +27,13 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T return fmt.Errorf("failed to initialized token source provider. Err: %w", err) } - clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{}) - if err != nil { - return fmt.Errorf("failed to fetch client metadata. Error: %v", err) + authorizationMetadataKey := cfg.AuthorizationHeader + if len(authorizationMetadataKey) == 0 { + clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{}) + if err != nil { + return fmt.Errorf("failed to fetch client metadata. Error: %v", err) + } + authorizationMetadataKey = clientMetadata.AuthorizationMetadataKey } tokenSource, err := tokenSourceProvider.GetTokenSource(ctx) @@ -37,7 +41,7 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T return err } - wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, clientMetadata.AuthorizationMetadataKey) + wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey) perRPCCredentials.Store(wrappedTokenSource) return nil } diff --git a/flyteidl/clients/go/admin/auth_interceptor_test.go b/flyteidl/clients/go/admin/auth_interceptor_test.go index 6cbf5441e5..0299269025 100644 --- a/flyteidl/clients/go/admin/auth_interceptor_test.go +++ b/flyteidl/clients/go/admin/auth_interceptor_test.go @@ -2,6 +2,7 @@ package admin import ( "context" + "errors" "fmt" "io" "net" @@ -233,3 +234,53 @@ func Test_newAuthInterceptor(t *testing.T) { assert.Falsef(t, f.IsInitialized(), "PerRPCCredentialFuture should not be initialized") }) } + +func TestMaterializeCredentials(t *testing.T) { + port := rand.IntnRange(10000, 60000) + t.Run("No public client config or oauth2 metadata endpoint lookup", func(t *testing.T) { + m := &mocks2.AuthMetadataServiceServer{} + m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata")) + m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get public client config")) + s := newAuthMetadataServer(t, port, m) + ctx := context.Background() + assert.NoError(t, s.Start(ctx)) + defer s.Close() + + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + assert.NoError(t, err) + + f := NewPerRPCCredentialsFuture() + err = MaterializeCredentials(ctx, &Config{ + Endpoint: config.URL{URL: *u}, + UseInsecureConnection: true, + AuthType: AuthTypeClientSecret, + TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port), + Scopes: []string{"all"}, + AuthorizationHeader: "authorization", + }, &mocks.TokenCache{}, f) + assert.NoError(t, err) + }) + t.Run("Failed to fetch client metadata", func(t *testing.T) { + m := &mocks2.AuthMetadataServiceServer{} + m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata")) + failedPublicClientConfigLookup := errors.New("expected err") + m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, failedPublicClientConfigLookup) + s := newAuthMetadataServer(t, port, m) + ctx := context.Background() + assert.NoError(t, s.Start(ctx)) + defer s.Close() + + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + assert.NoError(t, err) + + f := NewPerRPCCredentialsFuture() + err = MaterializeCredentials(ctx, &Config{ + Endpoint: config.URL{URL: *u}, + UseInsecureConnection: true, + AuthType: AuthTypeClientSecret, + TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port), + Scopes: []string{"all"}, + }, &mocks.TokenCache{}, f) + assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err") + }) +} diff --git a/flyteidl/clients/go/admin/client.go b/flyteidl/clients/go/admin/client.go index 52959ceb33..11095ebce8 100644 --- a/flyteidl/clients/go/admin/client.go +++ b/flyteidl/clients/go/admin/client.go @@ -86,9 +86,13 @@ func getAuthenticationDialOption(ctx context.Context, cfg *Config, tokenSourcePr return nil, errors.New("can't create authenticated channel without a TokenSourceProvider") } - clientMetadata, err := authClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{}) - if err != nil { - return nil, fmt.Errorf("failed to fetch client metadata. Error: %v", err) + authorizationMetadataKey := cfg.AuthorizationHeader + if len(authorizationMetadataKey) == 0 { + clientMetadata, err := authClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{}) + if err != nil { + return nil, fmt.Errorf("failed to fetch client metadata. Error: %v", err) + } + authorizationMetadataKey = clientMetadata.AuthorizationMetadataKey } tokenSource, err := tokenSourceProvider.GetTokenSource(ctx) @@ -96,7 +100,7 @@ func getAuthenticationDialOption(ctx context.Context, cfg *Config, tokenSourcePr return nil, err } - wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, clientMetadata.AuthorizationMetadataKey) + wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey) return grpc.WithPerRPCCredentials(wrappedTokenSource), nil } diff --git a/flyteidl/clients/go/admin/client_test.go b/flyteidl/clients/go/admin/client_test.go index 469ebaea35..017f4e8ff8 100644 --- a/flyteidl/clients/go/admin/client_test.go +++ b/flyteidl/clients/go/admin/client_test.go @@ -3,6 +3,7 @@ package admin import ( "context" "encoding/json" + "errors" "fmt" "io/ioutil" "net/http" @@ -11,6 +12,7 @@ import ( "testing" "time" + "github.com/jinzhu/copier" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "golang.org/x/oauth2" @@ -135,6 +137,20 @@ func TestGetAuthenticationDialOptionClientSecret(t *testing.T) { assert.Nil(t, dialOption) assert.NotNil(t, err) }) + t.Run("legal-no-external-calls", func(t *testing.T) { + mockAuthClient := new(mocks.AuthMetadataServiceClient) + mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata")) + mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get public client config")) + var adminCfg Config + err := copier.Copy(&adminCfg, adminServiceConfig) + assert.NoError(t, err) + adminCfg.TokenURL = "http://localhost:1000/api/v1/token" + adminCfg.Scopes = []string{"all"} + adminCfg.AuthorizationHeader = "authorization" + dialOption, err := getAuthenticationDialOption(ctx, &adminCfg, nil, mockAuthClient) + assert.Nil(t, dialOption) + assert.NotNil(t, err) + }) t.Run("error during oauth2Metatdata", func(t *testing.T) { mockAuthClient := new(mocks.AuthMetadataServiceClient) mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed")) @@ -142,6 +158,21 @@ func TestGetAuthenticationDialOptionClientSecret(t *testing.T) { assert.Nil(t, dialOption) assert.NotNil(t, err) }) + t.Run("error during public client config", func(t *testing.T) { + mockAuthClient := new(mocks.AuthMetadataServiceClient) + mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata")) + failedPublicClientConfigLookup := errors.New("expected err") + mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, failedPublicClientConfigLookup) + var adminCfg Config + err := copier.Copy(&adminCfg, adminServiceConfig) + assert.NoError(t, err) + adminCfg.TokenURL = "http://localhost:1000/api/v1/token" + adminCfg.Scopes = []string{"all"} + tokenProvider := ClientCredentialsTokenSourceProvider{} + dialOption, err := getAuthenticationDialOption(ctx, &adminCfg, tokenProvider, mockAuthClient) + assert.Nil(t, dialOption) + assert.EqualError(t, err, "failed to fetch client metadata. Error: expected err") + }) t.Run("error during flyte client", func(t *testing.T) { metatdata := &service.OAuth2MetadataResponse{ TokenEndpoint: "/token", diff --git a/flyteidl/clients/go/admin/config.go b/flyteidl/clients/go/admin/config.go index d7d0abcf71..b776451c40 100644 --- a/flyteidl/clients/go/admin/config.go +++ b/flyteidl/clients/go/admin/config.go @@ -64,8 +64,7 @@ type Config struct { // See the implementation of the 'grpcAuthorizationHeader' option in Flyte Admin for more information. But // basically we want to be able to use a different string to pass the token from this client to the the Admin service // because things might be running in a service mesh (like Envoy) that already uses the default 'authorization' header - // Deprecated: It will automatically be discovered through an anonymously accessible auth metadata service. - DeprecatedAuthorizationHeader string `json:"authorizationHeader" pflag:",Custom metadata header to pass JWT"` + AuthorizationHeader string `json:"authorizationHeader" pflag:",Custom metadata header to pass JWT"` PkceConfig pkce.Config `json:"pkceConfig" pflag:",Config for Pkce authentication flow."` diff --git a/flyteidl/clients/go/admin/config_flags.go b/flyteidl/clients/go/admin/config_flags.go index 6d4ffd8091..6d65d30f81 100755 --- a/flyteidl/clients/go/admin/config_flags.go +++ b/flyteidl/clients/go/admin/config_flags.go @@ -66,7 +66,7 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "scopes"), defaultConfig.Scopes, "List of scopes to request") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authorizationServerUrl"), defaultConfig.DeprecatedAuthorizationServerURL, "This is the URL to your IdP's authorization server. It'll default to Endpoint") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "tokenUrl"), defaultConfig.TokenURL, "OPTIONAL: Your IdP's token endpoint. It'll be discovered from flyte admin's OAuth Metadata endpoint if not provided.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authorizationHeader"), defaultConfig.DeprecatedAuthorizationHeader, "Custom metadata header to pass JWT") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authorizationHeader"), defaultConfig.AuthorizationHeader, "Custom metadata header to pass JWT") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "pkceConfig.timeout"), defaultConfig.PkceConfig.BrowserSessionTimeout.String(), "Amount of time the browser session would be active for authentication from client app.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "pkceConfig.refreshTime"), defaultConfig.PkceConfig.TokenRefreshGracePeriod.String(), "grace period from the token expiry after which it would refresh the token.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "deviceFlowConfig.refreshTime"), defaultConfig.DeviceFlowConfig.TokenRefreshGracePeriod.String(), "grace period from the token expiry after which it would refresh the token.") diff --git a/flyteidl/clients/go/admin/config_flags_test.go b/flyteidl/clients/go/admin/config_flags_test.go index cf97cdd92c..a44948d873 100755 --- a/flyteidl/clients/go/admin/config_flags_test.go +++ b/flyteidl/clients/go/admin/config_flags_test.go @@ -330,7 +330,7 @@ func TestConfig_SetFlags(t *testing.T) { cmdFlags.Set("authorizationHeader", testValue) if vString, err := cmdFlags.GetString("authorizationHeader"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.DeprecatedAuthorizationHeader) + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthorizationHeader) } else { assert.FailNow(t, err.Error()) diff --git a/flyteidl/clients/go/admin/integration_test.go b/flyteidl/clients/go/admin/integration_test.go index 3349bbe203..366f086f27 100644 --- a/flyteidl/clients/go/admin/integration_test.go +++ b/flyteidl/clients/go/admin/integration_test.go @@ -1,3 +1,4 @@ +//go:build integration // +build integration package admin @@ -31,7 +32,7 @@ func TestLiveAdminClient(t *testing.T) { ClientSecretLocation: "/Users/username/.ssh/admin/propeller_secret", DeprecatedAuthorizationServerURL: "https://lyft.okta.com/oauth2/ausc5wmjw96cRKvTd1t7", Scopes: []string{"svc"}, - DeprecatedAuthorizationHeader: "Flyte-Authorization", + AuthorizationHeader: "Flyte-Authorization", }) resp, err := client.ListProjects(ctx, &admin.ProjectListRequest{}) diff --git a/flyteidl/go.mod b/flyteidl/go.mod index d84d7370c6..65df4ab4b7 100644 --- a/flyteidl/go.mod +++ b/flyteidl/go.mod @@ -11,6 +11,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/grpc-ecosystem/grpc-gateway v1.16.0 + github.com/jinzhu/copier v0.3.5 github.com/mitchellh/mapstructure v1.4.1 github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 github.com/pkg/errors v0.9.1 diff --git a/flyteidl/go.sum b/flyteidl/go.sum index 60e1c96bb3..3064296375 100644 --- a/flyteidl/go.sum +++ b/flyteidl/go.sum @@ -338,6 +338,8 @@ github.com/imdario/mergo v0.3.5/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJ github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= +github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg= +github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.3/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=