-
-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e8e33ab
commit cb01b09
Showing
17 changed files
with
858 additions
and
360 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
package jwkset | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"log/slog" | ||
"net/url" | ||
"time" | ||
) | ||
|
||
var ( | ||
// ErrNewClient fails to create a new JWK Set client. | ||
ErrNewClient = errors.New("failed to create new JWK Set client") | ||
) | ||
|
||
// HTTPClientOptions are options for creating a new JWK Set client. | ||
type HTTPClientOptions struct { | ||
// Given contains keys known from outside HTTP URLs. | ||
Given Storage | ||
// HTTPURLs are a mapping of HTTP URLs to JWK Set endpoints to storage implementations for the keys located at the | ||
// URL. If empty, HTTP will not be used. | ||
HTTPURLs map[string]Storage | ||
// PrioritizeHTTP is a flag that indicates whether keys from the HTTP URL should be prioritized over keys from the | ||
// given storage. | ||
PrioritizeHTTP bool | ||
} | ||
|
||
// Client is a JWK Set client. | ||
type httpClient struct { | ||
given Storage | ||
httpURLs map[string]Storage | ||
prioritizeHTTP bool | ||
} | ||
|
||
// NewHTTPClient creates a new JWK Set client from remote HTTP resources. | ||
func NewHTTPClient(options HTTPClientOptions) (Storage, error) { | ||
if options.Given == nil && len(options.HTTPURLs) == 0 { | ||
return nil, fmt.Errorf("%w: no given keys or HTTP URLs", ErrNewClient) | ||
} | ||
for u, store := range options.HTTPURLs { | ||
if store == nil { | ||
options.HTTPURLs[u] = NewMemoryStorage() | ||
} | ||
} | ||
given := options.Given | ||
if given == nil { | ||
given = NewMemoryStorage() | ||
} | ||
c := httpClient{ | ||
given: given, | ||
httpURLs: options.HTTPURLs, | ||
prioritizeHTTP: options.PrioritizeHTTP, | ||
} | ||
return c, nil | ||
} | ||
|
||
// NewDefaultHTTPClient creates a new JWK Set client with default options from remote HTTP resources. | ||
func NewDefaultHTTPClient(urls []string) (Storage, error) { | ||
clientOptions := HTTPClientOptions{ | ||
HTTPURLs: make(map[string]Storage), | ||
} | ||
for _, u := range urls { | ||
parsed, err := url.ParseRequestURI(u) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to parse given URL %q: %w", u, errors.Join(err, ErrNewClient)) | ||
} | ||
u = parsed.String() | ||
refreshErrorHandler := func(ctx context.Context, err error) { | ||
slog.Default().ErrorContext(ctx, "Failed to refresh HTTP JWK Set from remote HTTP resource.", | ||
"error", err, | ||
"url", u, | ||
) | ||
} | ||
options := HTTPClientStorageOptions{ | ||
NoErrorReturnFirstHTTPReq: true, | ||
RefreshErrorHandler: refreshErrorHandler, | ||
RefreshInterval: time.Hour, | ||
} | ||
c, err := NewStorageFromHTTP(parsed, options) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to create HTTP client storage for %q: %w", u, errors.Join(err, ErrNewClient)) | ||
} | ||
clientOptions.HTTPURLs[u] = c | ||
} | ||
return NewHTTPClient(clientOptions) | ||
} | ||
|
||
func (c httpClient) KeyDelete(ctx context.Context, keyID string) (ok bool, err error) { | ||
ok, err = c.given.KeyDelete(ctx, keyID) | ||
if err != nil && !errors.Is(err, ErrKeyNotFound) { | ||
return false, fmt.Errorf("failed to delete key with ID %q from given storage due to error: %w", keyID, err) | ||
} | ||
if ok { | ||
return true, nil | ||
} | ||
for _, store := range c.httpURLs { | ||
ok, err = store.KeyDelete(ctx, keyID) | ||
if err != nil && !errors.Is(err, ErrKeyNotFound) { | ||
return false, fmt.Errorf("failed to delete key with ID %q from HTTP storage due to error: %w", keyID, err) | ||
} | ||
if ok { | ||
return true, nil | ||
} | ||
} | ||
return false, nil | ||
} | ||
func (c httpClient) KeyRead(ctx context.Context, keyID string) (jwk JWK, err error) { | ||
if !c.prioritizeHTTP { | ||
jwk, err = c.given.KeyRead(ctx, keyID) | ||
switch { | ||
case errors.Is(err, ErrKeyNotFound): | ||
// Do nothing. | ||
case err != nil: | ||
return JWK{}, fmt.Errorf("failed to find JWT key with ID %q in given storage due to error: %w", keyID, err) | ||
default: | ||
return jwk, nil | ||
} | ||
} | ||
for _, store := range c.httpURLs { | ||
jwk, err = store.KeyRead(ctx, keyID) | ||
switch { | ||
case errors.Is(err, ErrKeyNotFound): | ||
continue | ||
case err != nil: | ||
return JWK{}, fmt.Errorf("failed to find JWT key with ID %q in HTTP storage due to error: %w", keyID, err) | ||
default: | ||
return jwk, nil | ||
} | ||
} | ||
if c.prioritizeHTTP { | ||
jwk, err = c.given.KeyRead(ctx, keyID) | ||
switch { | ||
case errors.Is(err, ErrKeyNotFound): | ||
// Do nothing. | ||
case err != nil: | ||
return JWK{}, fmt.Errorf("failed to find JWT key with ID %q in given storage due to error: %w", keyID, err) | ||
default: | ||
return jwk, nil | ||
} | ||
} | ||
return JWK{}, fmt.Errorf("%w %q", ErrKeyNotFound, keyID) | ||
} | ||
func (c httpClient) KeyReadAll(ctx context.Context) ([]JWK, error) { | ||
jwks, err := c.given.KeyReadAll(ctx) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to snapshot given keys due to error: %w", err) | ||
} | ||
for u, store := range c.httpURLs { | ||
j, err := store.KeyReadAll(ctx) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to snapshot HTTP keys from %q due to error: %w", u, err) | ||
} | ||
jwks = append(jwks, j...) | ||
} | ||
return jwks, nil | ||
} | ||
func (c httpClient) KeyWrite(ctx context.Context, jwk JWK) error { | ||
return c.given.KeyWrite(ctx, jwk) | ||
} | ||
|
||
func (c httpClient) JSON(ctx context.Context) (json.RawMessage, error) { | ||
m, err := c.combineStorage(ctx) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to combine storage due to error: %w", err) | ||
} | ||
return m.JSON(ctx) | ||
} | ||
func (c httpClient) JSONPublic(ctx context.Context) (json.RawMessage, error) { | ||
m, err := c.combineStorage(ctx) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to combine storage due to error: %w", err) | ||
} | ||
return m.JSONPublic(ctx) | ||
} | ||
func (c httpClient) JSONPrivate(ctx context.Context) (json.RawMessage, error) { | ||
m, err := c.combineStorage(ctx) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to combine storage due to error: %w", err) | ||
} | ||
return m.JSONPrivate(ctx) | ||
} | ||
func (c httpClient) JSONWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (json.RawMessage, error) { | ||
m, err := c.combineStorage(ctx) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to combine storage due to error: %w", err) | ||
} | ||
return m.JSONWithOptions(ctx, marshalOptions, validationOptions) | ||
} | ||
func (c httpClient) Marshal(ctx context.Context) (JWKSMarshal, error) { | ||
m, err := c.combineStorage(ctx) | ||
if err != nil { | ||
return JWKSMarshal{}, fmt.Errorf("failed to combine storage due to error: %w", err) | ||
} | ||
return m.Marshal(ctx) | ||
} | ||
func (c httpClient) MarshalWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (JWKSMarshal, error) { | ||
m, err := c.combineStorage(ctx) | ||
if err != nil { | ||
return JWKSMarshal{}, fmt.Errorf("failed to combine storage due to error: %w", err) | ||
} | ||
return m.MarshalWithOptions(ctx, marshalOptions, validationOptions) | ||
} | ||
|
||
func (c httpClient) combineStorage(ctx context.Context) (Storage, error) { | ||
jwks, err := c.KeyReadAll(ctx) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to snapshot keys due to error: %w", err) | ||
} | ||
m := NewMemoryStorage() | ||
for _, jwk := range jwks { | ||
err = m.KeyWrite(ctx, jwk) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to write key to memory storage due to error: %w", err) | ||
} | ||
} | ||
return m, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
package jwkset | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
) | ||
|
||
func TestClient(t *testing.T) { | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
|
||
kid := "my-key-id" | ||
secret := []byte("my-hmac-secret") | ||
serverStore := NewMemoryStorage() | ||
marshalOptions := JWKMarshalOptions{ | ||
Private: true, | ||
} | ||
metadata := JWKMetadataOptions{ | ||
KID: kid, | ||
} | ||
options := JWKOptions{ | ||
Marshal: marshalOptions, | ||
Metadata: metadata, | ||
} | ||
jwk, err := NewJWKFromKey(secret, options) | ||
if err != nil { | ||
t.Fatalf("Failed to create a JWK from the given HMAC secret.\nError: %s", err) | ||
} | ||
err = serverStore.KeyWrite(ctx, jwk) | ||
if err != nil { | ||
t.Fatalf("Failed to write the given JWK to the store.\nError: %s", err) | ||
} | ||
rawJWKS, err := serverStore.JSON(ctx) | ||
if err != nil { | ||
t.Fatalf("Failed to get the JSON.\nError: %s", err) | ||
} | ||
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
_, _ = w.Write(rawJWKS) | ||
})) | ||
|
||
clientStore, err := NewDefaultHTTPClient([]string{server.URL}) | ||
if err != nil { | ||
t.Fatalf("Failed to create a new HTTP client.\nError: %s", err) | ||
} | ||
|
||
jwk, err = clientStore.KeyRead(ctx, kid) | ||
if err != nil { | ||
t.Fatalf("Failed to read the JWK.\nError: %s", err) | ||
} | ||
|
||
if !bytes.Equal(jwk.Key().([]byte), secret) { | ||
t.Fatalf("The key read from the HTTP client did not match the original key.") | ||
} | ||
|
||
jwks, err := clientStore.KeyReadAll(ctx) | ||
if err != nil { | ||
t.Fatalf("Failed to read all the JWKs.\nError: %s", err) | ||
} | ||
if len(jwks) != 1 { | ||
t.Fatalf("Expected to read 1 JWK, but got %d.", len(jwks)) | ||
} | ||
if !bytes.Equal(jwks[0].Key().([]byte), secret) { | ||
t.Fatalf("The key read from the HTTP client did not match the original key.") | ||
} | ||
|
||
ok, err := clientStore.KeyDelete(ctx, kid) | ||
if err != nil { | ||
t.Fatalf("Failed to delete the JWK.\nError: %s", err) | ||
} | ||
if !ok { | ||
t.Fatalf("Expected the key to be deleted.") | ||
} | ||
|
||
err = clientStore.KeyWrite(ctx, jwk) | ||
if err != nil { | ||
t.Fatalf("Failed to write the JWK.\nError: %s", err) | ||
} | ||
jwk, err = clientStore.KeyRead(ctx, kid) | ||
if err != nil { | ||
t.Fatalf("Failed to read the JWK.\nError: %s", err) | ||
} | ||
if !bytes.Equal(jwk.Key().([]byte), secret) { | ||
t.Fatalf("The key read from the HTTP client did not match the original key.") | ||
} | ||
} | ||
|
||
func TestClientError(t *testing.T) { | ||
_, err := NewHTTPClient(HTTPClientOptions{}) | ||
if err == nil { | ||
t.Fatalf("Expected an error when creating a new HTTP client without any URLs.") | ||
} | ||
} | ||
|
||
func TestClientJSON(t *testing.T) { | ||
c := httpClient{ | ||
given: NewMemoryStorage(), | ||
} | ||
testJSON(context.Background(), t, c) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
go 1.21.4 | ||
go 1.21.5 | ||
|
||
use ( | ||
../.. | ||
. | ||
../.. | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.