Skip to content

Commit

Permalink
enable standard forms of GCP auth for oci sources
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Stevens <[email protected]>
  • Loading branch information
thejosephstevens committed Oct 16, 2024
1 parent ac1007b commit d48b493
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 147 deletions.
79 changes: 33 additions & 46 deletions oci/auth/gcp/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ package gcp

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
Expand All @@ -31,6 +28,8 @@ import (
"github.com/google/go-containerregistry/pkg/name"

"github.com/fluxcd/pkg/oci"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)

type gceToken struct {
Expand All @@ -50,8 +49,8 @@ func ValidHost(host string) bool {
// Client is a GCP GCR client which can log into the registry and return
// authorization information.
type Client struct {
tokenURL string
proxyURL *url.URL
proxyURL *url.URL
tokenSource oauth2.TokenSource
}

// Option is a functional option for configuring the client.
Expand All @@ -64,69 +63,57 @@ func WithProxyURL(proxyURL *url.URL) Option {
}
}

// WithTokenSource sets a custom token source for the client.
func (c *Client) WithTokenSource(ts oauth2.TokenSource) *Client {
c.tokenSource = ts
return c
}

// NewClient creates a new GCR client with default configurations.
func NewClient(opts ...Option) *Client {
client := &Client{tokenURL: GCP_TOKEN_URL}
client := &Client{}
for _, opt := range opts {
opt(client)
}
return client
}

// WithTokenURL sets the token URL used by the GCR client.
func (c *Client) WithTokenURL(url string) *Client {
c.tokenURL = url
return c
}

// getLoginAuth obtains authentication by getting a token from the metadata API
// on GCP. This assumes that the pod has right to pull the image which would be
// the case if it is hosted on GCP. It works with both service account and
// workload identity enabled clusters.
// getLoginAuth obtains authentication using the default GCP credential chain.
// This supports various authentication methods including service account JSON,
// external account JSON, user credentials, and GCE metadata service.
func (c *Client) getLoginAuth(ctx context.Context) (authn.AuthConfig, time.Time, error) {
var authConfig authn.AuthConfig

request, err := http.NewRequestWithContext(ctx, http.MethodGet, c.tokenURL, nil)
if err != nil {
return authConfig, time.Time{}, err
}
// Define the required scopes for accessing GCR.
scopes := []string{"https://www.googleapis.com/auth/cloud-platform"}

request.Header.Add("Metadata-Flavor", "Google")
var tokenSource oauth2.TokenSource
var err error

var transport http.RoundTripper
if c.proxyURL != nil {
t := http.DefaultTransport.(*http.Transport).Clone()
t.Proxy = http.ProxyURL(c.proxyURL)
transport = t
// Use the injected token source if available; otherwise, use the default.
if c.tokenSource != nil {
tokenSource = c.tokenSource
} else {
// Obtain the default token source.
tokenSource, err = google.DefaultTokenSource(ctx, scopes...)
if err != nil {
return authConfig, time.Time{}, fmt.Errorf("failed to get default token source: %w", err)
}
}

client := &http.Client{Transport: transport}
response, err := client.Do(request)
// Retrieve the token.
token, err := tokenSource.Token()
if err != nil {
return authConfig, time.Time{}, err
}
defer response.Body.Close()
defer io.Copy(io.Discard, response.Body)

if response.StatusCode != http.StatusOK {
return authConfig, time.Time{}, fmt.Errorf("unexpected status from metadata service: %s", response.Status)
}

var accessToken gceToken
decoder := json.NewDecoder(response.Body)
if err := decoder.Decode(&accessToken); err != nil {
return authConfig, time.Time{}, err
return authConfig, time.Time{}, fmt.Errorf("failed to obtain token: %w", err)
}

// Set up the authentication configuration.
authConfig = authn.AuthConfig{
Username: "oauth2accesstoken",
Password: accessToken.AccessToken,
Password: token.AccessToken,
}

// add expiresIn seconds to the current time to get the expiry time
expiresAt := time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)

return authConfig, expiresAt, nil
return authConfig, token.Expiry, nil
}

// Login attempts to get the authentication material for GCR.
Expand Down
134 changes: 68 additions & 66 deletions oci/auth/gcp/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,72 +18,69 @@ package gcp

import (
"context"
"net/http"
"net/http/httptest"
"fmt"
"testing"
"time"

"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/name"
. "github.com/onsi/gomega"
"golang.org/x/oauth2"
)

const testValidGCRImage = "gcr.io/foo/bar:v1"

type fakeTokenSource struct {
token *oauth2.Token
err error
}

func (f *fakeTokenSource) Token() (*oauth2.Token, error) {
return f.token, f.err
}

func TestGetLoginAuth(t *testing.T) {
tests := []struct {
name string
responseBody string
statusCode int
token *oauth2.Token
tokenErr error
wantErr bool
wantAuthConfig authn.AuthConfig
}{
{
name: "success",
responseBody: `{
"access_token": "some-token",
"expires_in": 10,
"token_type": "foo"
}`,
statusCode: http.StatusOK,
token: &oauth2.Token{
AccessToken: "some-token",
TokenType: "Bearer",
Expiry: time.Now().Add(10 * time.Second),
},
wantAuthConfig: authn.AuthConfig{
Username: "oauth2accesstoken",
Password: "some-token",
},
},
{
name: "fail",
statusCode: http.StatusInternalServerError,
wantErr: true,
},
{
name: "invalid response",
responseBody: "foo",
statusCode: http.StatusOK,
wantErr: true,
name: "fail",
tokenErr: fmt.Errorf("token error"),
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := NewWithT(t)

handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.statusCode)
w.Write([]byte(tt.responseBody))
// Create fake token source
fakeTS := &fakeTokenSource{
token: tt.token,
err: tt.tokenErr,
}
srv := httptest.NewServer(http.HandlerFunc(handler))
t.Cleanup(func() {
srv.Close()
})

gc := NewClient().WithTokenURL(srv.URL)
gc := NewClient().WithTokenSource(fakeTS)
a, expiresAt, err := gc.getLoginAuth(context.TODO())
g.Expect(err != nil).To(Equal(tt.wantErr))
if !tt.wantErr {
g.Expect(expiresAt).To(BeTemporally("~", time.Now().Add(10*time.Second), time.Second))
}
if tt.statusCode == http.StatusOK {
g.Expect(expiresAt).To(BeTemporally("~", tt.token.Expiry, time.Second))
g.Expect(a).To(Equal(tt.wantAuthConfig))
}
})
Expand Down Expand Up @@ -111,60 +108,65 @@ func TestValidHost(t *testing.T) {

func TestLogin(t *testing.T) {
tests := []struct {
name string
autoLogin bool
image string
statusCode int
testOIDC bool
wantErr bool
name string
autoLogin bool
image string
token *oauth2.Token
tokenErr error
testOIDC bool
wantErr bool
}{
{
name: "no auto login",
autoLogin: false,
image: testValidGCRImage,
statusCode: http.StatusOK,
wantErr: true,
name: "no auto login",
autoLogin: false,
image: testValidGCRImage,
wantErr: true,
},
{
name: "with auto login",
autoLogin: true,
image: testValidGCRImage,
testOIDC: true,
statusCode: http.StatusOK,
name: "with auto login",
autoLogin: true,
image: testValidGCRImage,
testOIDC: true,
token: &oauth2.Token{
AccessToken: "some-token",
TokenType: "Bearer",
Expiry: time.Now().Add(10 * time.Second),
},
},
{
name: "login failure",
autoLogin: true,
image: testValidGCRImage,
statusCode: http.StatusInternalServerError,
testOIDC: true,
wantErr: true,
name: "login failure",
autoLogin: true,
image: testValidGCRImage,
tokenErr: fmt.Errorf("token error"),
testOIDC: true,
wantErr: true,
},
{
name: "non GCR image",
autoLogin: true,
image: "foo/bar:v1",
statusCode: http.StatusOK,
name: "non GCR image",
autoLogin: true,
image: "foo/bar:v1",
token: &oauth2.Token{
AccessToken: "some-token",
TokenType: "Bearer",
Expiry: time.Now().Add(10 * time.Second),
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := NewWithT(t)

handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.statusCode)
w.Write([]byte(`{"access_token": "some-token","expires_in": 10, "token_type": "foo"}`))
}
srv := httptest.NewServer(http.HandlerFunc(handler))
t.Cleanup(func() {
srv.Close()
})

ref, err := name.ParseReference(tt.image)
g.Expect(err).ToNot(HaveOccurred())

gc := NewClient().WithTokenURL(srv.URL)
// Create fake token source
fakeTS := &fakeTokenSource{
token: tt.token,
err: tt.tokenErr,
}

gc := NewClient().WithTokenSource(fakeTS)

_, err = gc.Login(context.TODO(), tt.autoLogin, tt.image, ref)
g.Expect(err != nil).To(Equal(tt.wantErr))
Expand Down
Loading

0 comments on commit d48b493

Please sign in to comment.