From cb01b09f5aff1091e1194535259f16ff33f3135f Mon Sep 17 00:00:00 2001 From: Micah Parks <66095735+MicahParks@users.noreply.github.com> Date: Sun, 10 Dec 2023 20:18:57 -0500 Subject: [PATCH] Add HTTP Client (#7) --- README.md | 9 - client.go | 219 +++++++++++++++++++++++ client_test.go | 103 +++++++++++ cmd/jwksetinfer/go.work | 4 +- cmd/jwksetinfer/main.go | 6 +- constants_test.go | 24 ++- error_test.go | 37 ++-- examples/http_server/main.go | 4 +- examples/storage_operations/main.go | 20 +-- jwk.go | 3 +- jwk_test.go | 71 ++++---- jwkset.go | 91 ---------- marshal.go | 33 ++++ marshal_test.go | 256 +++++++++++++-------------- storage.go | 264 +++++++++++++++++++++++++--- storage_test.go | 72 ++++---- website/handle/template/inspect.go | 2 +- 17 files changed, 858 insertions(+), 360 deletions(-) create mode 100644 client.go create mode 100644 client_test.go delete mode 100644 jwkset.go diff --git a/README.md b/README.md index 5afedc2..3b4af70 100644 --- a/README.md +++ b/README.md @@ -82,15 +82,6 @@ not implement any cryptographic algorithms itself. * This project does not currently support JWK Set encryption using JWE. This would involve implementing the relevant JWE specifications. It may be implemented in the future if there is interest. Open a GitHub issue to express interest. -# Test coverage - -``` -$ go test -cover -PASS -coverage: 85.5% of statements -ok github.com/MicahParks/jwkset 0.013s -``` - # See also * [`github.com/MicahParks/jcp`](https://github.com/MicahParks/jcp) diff --git a/client.go b/client.go new file mode 100644 index 0000000..2d8e50c --- /dev/null +++ b/client.go @@ -0,0 +1,219 @@ +package jwkset + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/url" + "time" +) + +var ( + // ErrNewClient fails to create a new JWK Set client. + ErrNewClient = errors.New("failed to create new JWK Set client") +) + +// HTTPClientOptions are options for creating a new JWK Set client. +type HTTPClientOptions struct { + // Given contains keys known from outside HTTP URLs. + Given Storage + // HTTPURLs are a mapping of HTTP URLs to JWK Set endpoints to storage implementations for the keys located at the + // URL. If empty, HTTP will not be used. + HTTPURLs map[string]Storage + // PrioritizeHTTP is a flag that indicates whether keys from the HTTP URL should be prioritized over keys from the + // given storage. + PrioritizeHTTP bool +} + +// Client is a JWK Set client. +type httpClient struct { + given Storage + httpURLs map[string]Storage + prioritizeHTTP bool +} + +// NewHTTPClient creates a new JWK Set client from remote HTTP resources. +func NewHTTPClient(options HTTPClientOptions) (Storage, error) { + if options.Given == nil && len(options.HTTPURLs) == 0 { + return nil, fmt.Errorf("%w: no given keys or HTTP URLs", ErrNewClient) + } + for u, store := range options.HTTPURLs { + if store == nil { + options.HTTPURLs[u] = NewMemoryStorage() + } + } + given := options.Given + if given == nil { + given = NewMemoryStorage() + } + c := httpClient{ + given: given, + httpURLs: options.HTTPURLs, + prioritizeHTTP: options.PrioritizeHTTP, + } + return c, nil +} + +// NewDefaultHTTPClient creates a new JWK Set client with default options from remote HTTP resources. +func NewDefaultHTTPClient(urls []string) (Storage, error) { + clientOptions := HTTPClientOptions{ + HTTPURLs: make(map[string]Storage), + } + for _, u := range urls { + parsed, err := url.ParseRequestURI(u) + if err != nil { + return nil, fmt.Errorf("failed to parse given URL %q: %w", u, errors.Join(err, ErrNewClient)) + } + u = parsed.String() + refreshErrorHandler := func(ctx context.Context, err error) { + slog.Default().ErrorContext(ctx, "Failed to refresh HTTP JWK Set from remote HTTP resource.", + "error", err, + "url", u, + ) + } + options := HTTPClientStorageOptions{ + NoErrorReturnFirstHTTPReq: true, + RefreshErrorHandler: refreshErrorHandler, + RefreshInterval: time.Hour, + } + c, err := NewStorageFromHTTP(parsed, options) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client storage for %q: %w", u, errors.Join(err, ErrNewClient)) + } + clientOptions.HTTPURLs[u] = c + } + return NewHTTPClient(clientOptions) +} + +func (c httpClient) KeyDelete(ctx context.Context, keyID string) (ok bool, err error) { + ok, err = c.given.KeyDelete(ctx, keyID) + if err != nil && !errors.Is(err, ErrKeyNotFound) { + return false, fmt.Errorf("failed to delete key with ID %q from given storage due to error: %w", keyID, err) + } + if ok { + return true, nil + } + for _, store := range c.httpURLs { + ok, err = store.KeyDelete(ctx, keyID) + if err != nil && !errors.Is(err, ErrKeyNotFound) { + return false, fmt.Errorf("failed to delete key with ID %q from HTTP storage due to error: %w", keyID, err) + } + if ok { + return true, nil + } + } + return false, nil +} +func (c httpClient) KeyRead(ctx context.Context, keyID string) (jwk JWK, err error) { + if !c.prioritizeHTTP { + jwk, err = c.given.KeyRead(ctx, keyID) + switch { + case errors.Is(err, ErrKeyNotFound): + // Do nothing. + case err != nil: + return JWK{}, fmt.Errorf("failed to find JWT key with ID %q in given storage due to error: %w", keyID, err) + default: + return jwk, nil + } + } + for _, store := range c.httpURLs { + jwk, err = store.KeyRead(ctx, keyID) + switch { + case errors.Is(err, ErrKeyNotFound): + continue + case err != nil: + return JWK{}, fmt.Errorf("failed to find JWT key with ID %q in HTTP storage due to error: %w", keyID, err) + default: + return jwk, nil + } + } + if c.prioritizeHTTP { + jwk, err = c.given.KeyRead(ctx, keyID) + switch { + case errors.Is(err, ErrKeyNotFound): + // Do nothing. + case err != nil: + return JWK{}, fmt.Errorf("failed to find JWT key with ID %q in given storage due to error: %w", keyID, err) + default: + return jwk, nil + } + } + return JWK{}, fmt.Errorf("%w %q", ErrKeyNotFound, keyID) +} +func (c httpClient) KeyReadAll(ctx context.Context) ([]JWK, error) { + jwks, err := c.given.KeyReadAll(ctx) + if err != nil { + return nil, fmt.Errorf("failed to snapshot given keys due to error: %w", err) + } + for u, store := range c.httpURLs { + j, err := store.KeyReadAll(ctx) + if err != nil { + return nil, fmt.Errorf("failed to snapshot HTTP keys from %q due to error: %w", u, err) + } + jwks = append(jwks, j...) + } + return jwks, nil +} +func (c httpClient) KeyWrite(ctx context.Context, jwk JWK) error { + return c.given.KeyWrite(ctx, jwk) +} + +func (c httpClient) JSON(ctx context.Context) (json.RawMessage, error) { + m, err := c.combineStorage(ctx) + if err != nil { + return nil, fmt.Errorf("failed to combine storage due to error: %w", err) + } + return m.JSON(ctx) +} +func (c httpClient) JSONPublic(ctx context.Context) (json.RawMessage, error) { + m, err := c.combineStorage(ctx) + if err != nil { + return nil, fmt.Errorf("failed to combine storage due to error: %w", err) + } + return m.JSONPublic(ctx) +} +func (c httpClient) JSONPrivate(ctx context.Context) (json.RawMessage, error) { + m, err := c.combineStorage(ctx) + if err != nil { + return nil, fmt.Errorf("failed to combine storage due to error: %w", err) + } + return m.JSONPrivate(ctx) +} +func (c httpClient) JSONWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (json.RawMessage, error) { + m, err := c.combineStorage(ctx) + if err != nil { + return nil, fmt.Errorf("failed to combine storage due to error: %w", err) + } + return m.JSONWithOptions(ctx, marshalOptions, validationOptions) +} +func (c httpClient) Marshal(ctx context.Context) (JWKSMarshal, error) { + m, err := c.combineStorage(ctx) + if err != nil { + return JWKSMarshal{}, fmt.Errorf("failed to combine storage due to error: %w", err) + } + return m.Marshal(ctx) +} +func (c httpClient) MarshalWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (JWKSMarshal, error) { + m, err := c.combineStorage(ctx) + if err != nil { + return JWKSMarshal{}, fmt.Errorf("failed to combine storage due to error: %w", err) + } + return m.MarshalWithOptions(ctx, marshalOptions, validationOptions) +} + +func (c httpClient) combineStorage(ctx context.Context) (Storage, error) { + jwks, err := c.KeyReadAll(ctx) + if err != nil { + return nil, fmt.Errorf("failed to snapshot keys due to error: %w", err) + } + m := NewMemoryStorage() + for _, jwk := range jwks { + err = m.KeyWrite(ctx, jwk) + if err != nil { + return nil, fmt.Errorf("failed to write key to memory storage due to error: %w", err) + } + } + return m, nil +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..06a7ca9 --- /dev/null +++ b/client_test.go @@ -0,0 +1,103 @@ +package jwkset + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "testing" +) + +func TestClient(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + kid := "my-key-id" + secret := []byte("my-hmac-secret") + serverStore := NewMemoryStorage() + marshalOptions := JWKMarshalOptions{ + Private: true, + } + metadata := JWKMetadataOptions{ + KID: kid, + } + options := JWKOptions{ + Marshal: marshalOptions, + Metadata: metadata, + } + jwk, err := NewJWKFromKey(secret, options) + if err != nil { + t.Fatalf("Failed to create a JWK from the given HMAC secret.\nError: %s", err) + } + err = serverStore.KeyWrite(ctx, jwk) + if err != nil { + t.Fatalf("Failed to write the given JWK to the store.\nError: %s", err) + } + rawJWKS, err := serverStore.JSON(ctx) + if err != nil { + t.Fatalf("Failed to get the JSON.\nError: %s", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(rawJWKS) + })) + + clientStore, err := NewDefaultHTTPClient([]string{server.URL}) + if err != nil { + t.Fatalf("Failed to create a new HTTP client.\nError: %s", err) + } + + jwk, err = clientStore.KeyRead(ctx, kid) + if err != nil { + t.Fatalf("Failed to read the JWK.\nError: %s", err) + } + + if !bytes.Equal(jwk.Key().([]byte), secret) { + t.Fatalf("The key read from the HTTP client did not match the original key.") + } + + jwks, err := clientStore.KeyReadAll(ctx) + if err != nil { + t.Fatalf("Failed to read all the JWKs.\nError: %s", err) + } + if len(jwks) != 1 { + t.Fatalf("Expected to read 1 JWK, but got %d.", len(jwks)) + } + if !bytes.Equal(jwks[0].Key().([]byte), secret) { + t.Fatalf("The key read from the HTTP client did not match the original key.") + } + + ok, err := clientStore.KeyDelete(ctx, kid) + if err != nil { + t.Fatalf("Failed to delete the JWK.\nError: %s", err) + } + if !ok { + t.Fatalf("Expected the key to be deleted.") + } + + err = clientStore.KeyWrite(ctx, jwk) + if err != nil { + t.Fatalf("Failed to write the JWK.\nError: %s", err) + } + jwk, err = clientStore.KeyRead(ctx, kid) + if err != nil { + t.Fatalf("Failed to read the JWK.\nError: %s", err) + } + if !bytes.Equal(jwk.Key().([]byte), secret) { + t.Fatalf("The key read from the HTTP client did not match the original key.") + } +} + +func TestClientError(t *testing.T) { + _, err := NewHTTPClient(HTTPClientOptions{}) + if err == nil { + t.Fatalf("Expected an error when creating a new HTTP client without any URLs.") + } +} + +func TestClientJSON(t *testing.T) { + c := httpClient{ + given: NewMemoryStorage(), + } + testJSON(context.Background(), t, c) +} diff --git a/cmd/jwksetinfer/go.work b/cmd/jwksetinfer/go.work index 86c9102..664d035 100644 --- a/cmd/jwksetinfer/go.work +++ b/cmd/jwksetinfer/go.work @@ -1,6 +1,6 @@ -go 1.21.4 +go 1.21.5 use ( - ../.. . + ../.. ) diff --git a/cmd/jwksetinfer/main.go b/cmd/jwksetinfer/main.go index 46bc3eb..e7feeb4 100644 --- a/cmd/jwksetinfer/main.go +++ b/cmd/jwksetinfer/main.go @@ -44,7 +44,7 @@ func main() { allPEM = s.String() } - jwks := jwkset.NewMemory() + jwks := jwkset.NewMemoryStorage() i := 0 const kidPrefix = "UniqueKeyID" @@ -81,7 +81,7 @@ func main() { ) os.Exit(1) } - err = jwks.Store.WriteKey(ctx, jwk) + err = jwks.KeyWrite(ctx, jwk) if err != nil { l.Error("Failed to write JWK.", logErr, err, @@ -110,7 +110,7 @@ func main() { ) os.Exit(1) } - err = jwks.Store.WriteKey(ctx, jwk) + err = jwks.KeyWrite(ctx, jwk) if err != nil { l.Error("Failed to write JWK.", logErr, err, diff --git a/constants_test.go b/constants_test.go index 98add4c..a143d25 100644 --- a/constants_test.go +++ b/constants_test.go @@ -4,11 +4,22 @@ import ( "testing" ) +const ( + invalid = "invalid" +) + func TestALG(t *testing.T) { a := AlgHS256 if a.String() != string(a) { t.Errorf("Failed to get proper string from String method.") } + if !a.IANARegistered() { + t.Errorf("Failed to validate valid ALG.") + } + a = invalid + if a.IANARegistered() { + t.Errorf("Do not validate invalid ALG.") + } } func TestCRV(t *testing.T) { @@ -16,6 +27,13 @@ func TestCRV(t *testing.T) { if c.String() != string(c) { t.Errorf("Failed to get proper string from String method.") } + if !c.IANARegistered() { + t.Errorf("Failed to validate valid CRV.") + } + c = invalid + if c.IANARegistered() { + t.Errorf("Do not validate invalid CRV.") + } } func TestKEYOPS(t *testing.T) { @@ -26,7 +44,7 @@ func TestKEYOPS(t *testing.T) { if !k.IANARegistered() { t.Errorf("Failed to validate valid KEYOPS.") } - k = "invalid" + k = invalid if k.IANARegistered() { t.Errorf("Do not validate invalid KEYOPS.") } @@ -40,7 +58,7 @@ func TestKTY(t *testing.T) { if !k.IANARegistered() { t.Errorf("Failed to validate valid KTY.") } - k = "invalid" + k = invalid if k.IANARegistered() { t.Errorf("Do not validate invalid KTY.") } @@ -54,7 +72,7 @@ func TestUSE(t *testing.T) { if !u.IANARegistered() { t.Errorf("Failed to validate valid USE.") } - u = "invalid" + u = invalid if u.IANARegistered() { t.Errorf("Do not validate invalid USE.") } diff --git a/error_test.go b/error_test.go index a3a8779..698a0a9 100644 --- a/error_test.go +++ b/error_test.go @@ -1,12 +1,11 @@ -package jwkset_test +package jwkset import ( "context" + "encoding/json" "errors" "testing" "time" - - "github.com/MicahParks/jwkset" ) var ( @@ -15,25 +14,43 @@ var ( type storageError struct{} -func (s storageError) DeleteKey(_ context.Context, _ string) (ok bool, err error) { +func (s storageError) KeyDelete(_ context.Context, _ string) (ok bool, err error) { return false, errStorage } -func (s storageError) ReadKey(_ context.Context, _ string) (jwkset.JWK, error) { - return jwkset.JWK{}, errStorage +func (s storageError) KeyRead(_ context.Context, _ string) (JWK, error) { + return JWK{}, errStorage } -func (s storageError) SnapshotKeys(_ context.Context) ([]jwkset.JWK, error) { +func (s storageError) KeyReadAll(_ context.Context) ([]JWK, error) { return nil, errStorage } -func (s storageError) WriteKey(_ context.Context, _ jwkset.JWK) error { +func (s storageError) KeyWrite(_ context.Context, _ JWK) error { return errStorage } +func (s storageError) JSON(_ context.Context) (json.RawMessage, error) { + return nil, errStorage +} +func (s storageError) JSONPublic(_ context.Context) (json.RawMessage, error) { + return nil, errStorage +} +func (s storageError) JSONPrivate(_ context.Context) (json.RawMessage, error) { + return nil, errStorage +} +func (s storageError) JSONWithOptions(_ context.Context, _ JWKMarshalOptions, _ JWKValidateOptions) (json.RawMessage, error) { + return nil, errStorage +} +func (s storageError) Marshal(_ context.Context) (JWKSMarshal, error) { + return JWKSMarshal{}, errStorage +} +func (s storageError) MarshalWithOptions(_ context.Context, _ JWKMarshalOptions, _ JWKValidateOptions) (JWKSMarshal, error) { + return JWKSMarshal{}, errStorage +} + func TestStorageError(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - jwks := jwkset.NewMemory() - jwks.Store = storageError{} + jwks := storageError{} _, err := jwks.JSONPublic(ctx) if err == nil { diff --git a/examples/http_server/main.go b/examples/http_server/main.go index 1abe859..757f4a4 100644 --- a/examples/http_server/main.go +++ b/examples/http_server/main.go @@ -19,7 +19,7 @@ func main() { ctx := context.Background() logger := log.New(os.Stdout, "", 0) - jwkSet := jwkset.NewMemory() + jwkSet := jwkset.NewMemoryStorage() // Create an RSA key. key, err := rsa.GenerateKey(rand.Reader, 4096) @@ -42,7 +42,7 @@ func main() { } // Write the key to the JWK Set storage. - err = jwkSet.Store.WriteKey(ctx, jwk) + err = jwkSet.KeyWrite(ctx, jwk) if err != nil { logger.Fatalf(logFmt, "Failed to store RSA key.", err) } diff --git a/examples/storage_operations/main.go b/examples/storage_operations/main.go index dfdc9ac..f61348e 100644 --- a/examples/storage_operations/main.go +++ b/examples/storage_operations/main.go @@ -23,7 +23,7 @@ func main() { logger := log.New(os.Stdout, "", 0) // Create a new JWK Set using memory-backed storage. - jwkSet := jwkset.NewMemory() + jwkSet := jwkset.NewMemoryStorage() // Create a new ECDSA key and store it in the JWK Set. ec, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -35,7 +35,7 @@ func main() { if err != nil { logger.Fatalf(logFmt, "Failed to create JWK from ECDSA key.", err) } - err = jwkSet.Store.WriteKey(ctx, jwk) + err = jwkSet.KeyWrite(ctx, jwk) if err != nil { logger.Fatalf(logFmt, "Failed to store ECDSA key.", err) } @@ -50,7 +50,7 @@ func main() { if err != nil { logger.Fatalf(logFmt, "Failed to create JWK from EdDSA key.", err) } - err = jwkSet.Store.WriteKey(ctx, jwk) + err = jwkSet.KeyWrite(ctx, jwk) if err != nil { logger.Fatalf(logFmt, "Failed to store EdDSA key.", err) } @@ -65,7 +65,7 @@ func main() { if err != nil { logger.Fatalf(logFmt, "Failed to create JWK from RSA key.", err) } - err = jwkSet.Store.WriteKey(ctx, jwk) + err = jwkSet.KeyWrite(ctx, jwk) if err != nil { logger.Fatalf(logFmt, "Failed to store RSA key.", err) } @@ -77,7 +77,7 @@ func main() { if err != nil { logger.Fatalf(logFmt, "Failed to create JWK from HMAC key.", err) } - err = jwkSet.Store.WriteKey(ctx, jwk) + err = jwkSet.KeyWrite(ctx, jwk) if err != nil { logger.Fatalf(logFmt, "Failed to store HMAC key.", err) } @@ -91,7 +91,7 @@ func main() { logger.Println(string(jsonRepresentation)) // Delete the previously added RSA key from the JWK Set, then reprint the JSON representation. - _, err = jwkSet.Store.DeleteKey(ctx, rID) + _, err = jwkSet.KeyDelete(ctx, rID) if err != nil { logger.Fatalf(logFmt, "Failed to delete RSA key.", err) } @@ -103,7 +103,7 @@ func main() { logger.Println(string(jsonRepresentation)) // Delete the previously added ECDSA key from the JWK Set, add a new one, then reprint the JSON representation. - _, err = jwkSet.Store.DeleteKey(ctx, ecID) + _, err = jwkSet.KeyDelete(ctx, ecID) if err != nil { logger.Fatalf(logFmt, "Failed to delete ECDSA key.", err) } @@ -115,7 +115,7 @@ func main() { if err != nil { logger.Fatalf(logFmt, "Failed to create JWK from ECDSA key.", err) } - err = jwkSet.Store.WriteKey(ctx, jwk) + err = jwkSet.KeyWrite(ctx, jwk) if err != nil { logger.Fatalf(logFmt, "Failed to store ECDSA key.", err) } @@ -127,7 +127,7 @@ func main() { logger.Println(string(jsonRepresentation)) // Read the previously added EdDSA key from the JWK Set, the print its private key. - jwk, err = jwkSet.Store.ReadKey(ctx, edID) + jwk, err = jwkSet.KeyRead(ctx, edID) if err != nil { logger.Fatalf(logFmt, "Failed to read EdDSA key.", err) } @@ -138,7 +138,7 @@ func main() { logger.Printf("Retrieved EdDSA private key Base64RawURL: %s", base64.RawURLEncoding.EncodeToString(edKey)) // Read the previously added HMAC key from the JWK Set, the print it. - jwk, err = jwkSet.Store.ReadKey(ctx, hid) + jwk, err = jwkSet.KeyRead(ctx, hid) if err != nil { logger.Fatalf(logFmt, "Failed to read HMAC key.", err) } diff --git a/jwk.go b/jwk.go index a91a10c..4343b5b 100644 --- a/jwk.go +++ b/jwk.go @@ -153,9 +153,8 @@ func NewJWKFromX5C(options JWKOptions) (JWK, error) { if cert.PublicKeyAlgorithm == x509.Ed25519 { if options.Metadata.ALG != "" && options.Metadata.ALG != AlgEdDSA { return JWK{}, fmt.Errorf("%w: ALG in metadata does not match ALG in X.509 certificate", errors.Join(ErrOptions, ErrX509Mismatch)) - } else { - options.Metadata.ALG = AlgEdDSA } + options.Metadata.ALG = AlgEdDSA } j := JWK{ diff --git a/jwk_test.go b/jwk_test.go index 6597c28..5dd075e 100644 --- a/jwk_test.go +++ b/jwk_test.go @@ -1,4 +1,4 @@ -package jwkset_test +package jwkset import ( "context" @@ -10,15 +10,13 @@ import ( "encoding/pem" "testing" "time" - - "github.com/MicahParks/jwkset" ) func TestNewJWKFromRawJSON(t *testing.T) { - marshalOptions := jwkset.JWKMarshalOptions{ + marshalOptions := JWKMarshalOptions{ Private: true, } - jwk, err := jwkset.NewJWKFromRawJSON([]byte(edExpected), marshalOptions, jwkset.JWKValidateOptions{}) + jwk, err := NewJWKFromRawJSON([]byte(edExpected), marshalOptions, JWKValidateOptions{}) if err != nil { t.Fatalf("Failed to create JWK from raw JSON. %s", err) } @@ -26,7 +24,7 @@ func TestNewJWKFromRawJSON(t *testing.T) { t.Fatalf("Incorrect KID. %s", jwk.Marshal().KID) } - _, err = jwkset.NewJWKFromRawJSON([]byte("invalid"), jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromRawJSON([]byte("invalid"), JWKMarshalOptions{}, JWKValidateOptions{}) if err == nil { t.Fatal("Expected an error.") } @@ -36,8 +34,11 @@ func TestJSON(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - jwks := jwkset.NewMemory() + jwks := NewMemoryStorage() + testJSON(ctx, t, jwks) +} +func testJSON(ctx context.Context, t *testing.T, jwks Storage) { b, err := base64.RawURLEncoding.DecodeString(x25519PrivateKey) if err != nil { t.Fatalf("Failed to decode ECDH X25519 private key. %s", err) @@ -46,14 +47,14 @@ func TestJSON(t *testing.T) { if err != nil { t.Fatalf("Failed to generate ECDH X25519 key. %s", err) } - writeKey(ctx, t, jwks, x25519Priv, x25519ID, false) + writeKey(ctx, t, jwks, x25519Priv, x25519ID, true) block, _ := pem.Decode([]byte(ecPrivateKey)) eKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { t.Fatalf("Failed to parse EC private key. %s", err) } - writeKey(ctx, t, jwks, eKey, eID, false) + writeKey(ctx, t, jwks, eKey, eID, true) edPriv, err := base64.RawURLEncoding.DecodeString(edPrivateKey) if err != nil { @@ -64,15 +65,14 @@ func TestJSON(t *testing.T) { t.Fatalf("Failed to decode EdDSA public key. %s", err) } ed := ed25519.PrivateKey(append(edPriv, edPub...)) - writeKey(ctx, t, jwks, ed, edID, false) + writeKey(ctx, t, jwks, ed, edID, true) block, _ = pem.Decode([]byte(rsaPrivateKey)) rKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { t.Fatalf("Failed to parse RSA private key. %s", err) } - const rID = "myRSAKey" - writeKey(ctx, t, jwks, rKey, rID, false) + writeKey(ctx, t, jwks, rKey, rID, true) hKey := []byte(hmacSecret) writeKey(ctx, t, jwks, hKey, hID, true) @@ -88,19 +88,6 @@ func TestJSON(t *testing.T) { t.Fatalf("Failed to get JSON. %s", err) } compareJSON(t, jsonRepresentation, true) - - jwks = jwkset.NewMemory() - writeKey(ctx, t, jwks, x25519Priv, x25519ID, true) - writeKey(ctx, t, jwks, eKey, eID, true) - writeKey(ctx, t, jwks, ed, edID, true) - writeKey(ctx, t, jwks, rKey, rID, true) - writeKey(ctx, t, jwks, hKey, hID, true) - - jsonRepresentation, err = jwks.JSON(ctx) - if err != nil { - t.Fatalf("Failed to get JSON. %s", err) - } - compareJSON(t, jsonRepresentation, true) } func compareJSON(t *testing.T, actual json.RawMessage, private bool) { @@ -115,13 +102,16 @@ func compareJSON(t *testing.T, actual json.RawMessage, private bool) { } wrongLength := false + var expectedKeys int if private && len(keys.Keys) != 5 { + expectedKeys = 5 wrongLength = true } else if !private && len(keys.Keys) != 4 { + expectedKeys = 4 wrongLength = true } if wrongLength { - t.Fatalf("Expected 3 keys. Got %d. HMAC keys should not have a JSON representation.", len(keys.Keys)) + t.Fatalf("Expected %d keys. Got %d. HMAC keys should not have a JSON representation.", expectedKeys, len(keys.Keys)) } for _, key := range keys.Keys { @@ -132,34 +122,34 @@ func compareJSON(t *testing.T, actual json.RawMessage, private bool) { var expectedJSON json.RawMessage var matchingAttributes []string - switch jwkset.KTY(kty) { - case jwkset.KtyEC: + switch KTY(kty) { + case KtyEC: expectedJSON = json.RawMessage(ecExpected) matchingAttributes = []string{"kty", "kid", "crv", "x", "y"} if private { matchingAttributes = append(matchingAttributes, "d") } - case jwkset.KtyOKP: + case KtyOKP: matchingAttributes = []string{"crv", "kty", "kid", "x"} if private { matchingAttributes = append(matchingAttributes, "d") } - switch jwkset.CRV(key["crv"].(string)) { - case jwkset.CrvEd25519: + switch CRV(key["crv"].(string)) { + case CrvEd25519: matchingAttributes = append(matchingAttributes, "alg") expectedJSON = json.RawMessage(edExpected) - case jwkset.CrvX25519: + case CrvX25519: expectedJSON = json.RawMessage(x25519Expected) default: t.Fatalf("Unknown OKP curve %q.", key["crv"].(string)) } - case jwkset.KtyRSA: + case KtyRSA: expectedJSON = json.RawMessage(rsaExpected) matchingAttributes = []string{"kty", "kid", "n", "e"} if private { matchingAttributes = append(matchingAttributes, "d", "p", "q", "dp", "dq", "qi") } - case jwkset.KtyOct: + case KtyOct: if private { expectedJSON = json.RawMessage(hmacExpected) matchingAttributes = []string{"kty", "kid", "k"} @@ -189,22 +179,22 @@ func compareJSON(t *testing.T, actual json.RawMessage, private bool) { } } -func writeKey(ctx context.Context, t *testing.T, jwks jwkset.JWKSet, key any, keyID string, private bool) { - marshal := jwkset.JWKMarshalOptions{ +func writeKey(ctx context.Context, t *testing.T, jwks Storage, key any, keyID string, private bool) { + marshal := JWKMarshalOptions{ Private: private, } - metadata := jwkset.JWKMetadataOptions{ + metadata := JWKMetadataOptions{ KID: keyID, } - options := jwkset.JWKOptions{ + options := JWKOptions{ Marshal: marshal, Metadata: metadata, } - jwk, err := jwkset.NewJWKFromKey(key, options) + jwk, err := NewJWKFromKey(key, options) if err != nil { t.Fatalf("Failed to create JWK from key ID %q. %s", keyID, err) } - err = jwks.Store.WriteKey(ctx, jwk) + err = jwks.KeyWrite(ctx, jwk) if err != nil { t.Fatalf("Failed to write key ID %q. %s", keyID, err) } @@ -215,6 +205,7 @@ const ( eID = "myECKey" edID = "myEdDSAKey" hID = "myHMACKey" + rID = "myRSAKey" ) /* diff --git a/jwkset.go b/jwkset.go deleted file mode 100644 index bb4b542..0000000 --- a/jwkset.go +++ /dev/null @@ -1,91 +0,0 @@ -package jwkset - -import ( - "context" - "encoding/json" - "errors" - "fmt" -) - -// JWKSet is a set of JSON Web Keys. -type JWKSet struct { - Store Storage -} - -// NewMemory creates a new in-memory JWKSet. -func NewMemory() JWKSet { - return JWKSet{ - Store: NewMemoryStorage(), - } -} - -func (j JWKSet) JSON(ctx context.Context) (json.RawMessage, error) { - jwks, err := j.Marshal(ctx) - if err != nil { - return nil, fmt.Errorf("failed to marshal JWK Set: %w", err) - } - return json.Marshal(jwks) -} - -// JSONPublic creates the JSON representation of the public keys in JWKSet. -func (j JWKSet) JSONPublic(ctx context.Context) (json.RawMessage, error) { - return j.JSONWithOptions(ctx, JWKMarshalOptions{}, JWKValidateOptions{}) -} - -// JSONPrivate creates the JSON representation of the JWKSet public and private key material. -func (j JWKSet) JSONPrivate(ctx context.Context) (json.RawMessage, error) { - marshalOptions := JWKMarshalOptions{ - Private: true, - } - return j.JSONWithOptions(ctx, marshalOptions, JWKValidateOptions{}) -} - -// JSONWithOptions creates the JSON representation of the JWKSet with the given options. These options override whatever -// options are set on the individual JWKs. -func (j JWKSet) JSONWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (json.RawMessage, error) { - jwks, err := j.MarshalWithOptions(ctx, marshalOptions, validationOptions) - if err != nil { - return nil, fmt.Errorf("failed to marshal JWK Set with options: %w", err) - } - return json.Marshal(jwks) -} - -// Marshal transforms the JWK Set's current state into a Go type that can be marshaled into JSON. -func (j JWKSet) Marshal(ctx context.Context) (JWKSMarshal, error) { - keys, err := j.Store.SnapshotKeys(ctx) - if err != nil { - return JWKSMarshal{}, fmt.Errorf("failed to read snapshot of all keys from storage: %w", err) - } - jwks := JWKSMarshal{} - for _, key := range keys { - jwks.Keys = append(jwks.Keys, key.Marshal()) - } - return jwks, nil -} - -// MarshalWithOptions transforms the JWK Set's current state into a Go type that can be marshaled into JSON with the -// given options. These options override whatever options are set on the individual JWKs. -func (j JWKSet) MarshalWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (JWKSMarshal, error) { - jwks := JWKSMarshal{} - - keys, err := j.Store.SnapshotKeys(ctx) - if err != nil { - return JWKSMarshal{}, fmt.Errorf("failed to read snapshot of all keys from storage: %w", err) - } - - for _, key := range keys { - options := key.options - options.Marshal = marshalOptions - options.Validate = validationOptions - marshal, err := keyMarshal(key.Key(), options) - if err != nil { - if errors.Is(err, ErrOptions) { - continue - } - return JWKSMarshal{}, fmt.Errorf("failed to marshal key: %w", err) - } - jwks.Keys = append(jwks.Keys, marshal) - } - - return jwks, nil -} diff --git a/marshal.go b/marshal.go index e7dbef4..eba7a6c 100644 --- a/marshal.go +++ b/marshal.go @@ -1,6 +1,7 @@ package jwkset import ( + "context" "crypto/ecdh" "crypto/ecdsa" "crypto/ed25519" @@ -76,6 +77,38 @@ type JWKSMarshal struct { Keys []JWKMarshal `json:"keys"` } +// JWKSlice converts the JWKSMarshal to a []JWK. +func (j JWKSMarshal) JWKSlice() ([]JWK, error) { + slice := make([]JWK, len(j.Keys)) + for i, key := range j.Keys { + marshalOptions := JWKMarshalOptions{ + Private: true, + } + jwk, err := keyUnmarshal(key, marshalOptions, JWKValidateOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal JWK: %w", err) + } + slice[i] = jwk + } + return slice, nil +} + +// ToStorage converts the JWKSMarshal to a Storage. +func (j JWKSMarshal) ToStorage() (Storage, error) { + m := NewMemoryStorage() + jwks, err := j.JWKSlice() + if err != nil { + return nil, fmt.Errorf("failed to create a slice of JWK from JWKSMarshal: %w", err) + } + for _, jwk := range jwks { + err = m.KeyWrite(context.Background(), jwk) + if err != nil { + return nil, fmt.Errorf("failed to write JWK to storage: %w", err) + } + } + return m, nil +} + func keyMarshal(key any, options JWKOptions) (JWKMarshal, error) { m := JWKMarshal{} m.ALG = options.Metadata.ALG diff --git a/marshal_test.go b/marshal_test.go index 9392034..4d23f8b 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -1,4 +1,4 @@ -package jwkset_test +package jwkset import ( "bytes" @@ -12,8 +12,6 @@ import ( "math/big" "slices" "testing" - - "github.com/MicahParks/jwkset" ) const ( @@ -53,11 +51,11 @@ const ( ) func TestMarshalECDH(t *testing.T) { - checkMarshal := func(marshal jwkset.JWKMarshal, options jwkset.JWKOptions) { + checkMarshal := func(marshal JWKMarshal, options JWKOptions) { if marshal.ALG != "" { t.Fatal(`Marshaled key parameter "alg" should be empty when not set.`) } - if marshal.CRV != jwkset.CrvX25519 { + if marshal.CRV != CrvX25519 { t.Fatal(`Marshaled key parameter "crv" does not match original key.`) } if options.Marshal.Private { @@ -69,7 +67,7 @@ func TestMarshalECDH(t *testing.T) { t.Fatalf("Asymmetric private key should be unsupported for given options.") } } - if marshal.KTY != jwkset.KtyOKP { + if marshal.KTY != KtyOKP { t.Fatal(`Marshaled key parameter "kty" does not match original key.`) } if marshal.X != ecdhX25519X { @@ -78,22 +76,22 @@ func TestMarshalECDH(t *testing.T) { } private := makeECDHX25519Private(t) - options := jwkset.JWKOptions{} - jwk, err := jwkset.NewJWKFromKey(private, options) + options := JWKOptions{} + jwk, err := NewJWKFromKey(private, options) if err != nil { t.Fatalf("Failed to marshal key with correct options. %s", err) } checkMarshal(jwk.Marshal(), options) options.Marshal.Private = true - jwk, err = jwkset.NewJWKFromKey(private, options) + jwk, err = NewJWKFromKey(private, options) if err != nil { t.Fatalf("Failed to marshal key with correct options. %s", err) } checkMarshal(jwk.Marshal(), options) options.Marshal.Private = false - jwk, err = jwkset.NewJWKFromKey(private.Public(), options) + jwk, err = NewJWKFromKey(private.Public(), options) if err != nil { t.Fatalf("Failed to marshal key with correct options. %s", err) } @@ -103,16 +101,16 @@ func TestMarshalECDH(t *testing.T) { func TestUnmarshalECDH(t *testing.T) { private := makeECDHX25519Private(t) - marshal := jwkset.JWKMarshal{ - CRV: jwkset.CrvX25519, + marshal := JWKMarshal{ + CRV: CrvX25519, D: ecdhX25519D, KID: myKeyID, - KTY: jwkset.KtyOKP, + KTY: KtyOKP, X: ecdhX25519X, } - marshalOptions := jwkset.JWKMarshalOptions{} - jwk, err := jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + marshalOptions := JWKMarshalOptions{} + jwk, err := NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err != nil { t.Fatalf("Failed to unmarshal key with correct options. %s", err) } @@ -121,7 +119,7 @@ func TestUnmarshalECDH(t *testing.T) { } marshalOptions.Private = true - jwk, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + jwk, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err != nil { t.Fatalf("Failed to unmarshal key with correct options. %s", err) } @@ -133,50 +131,50 @@ func TestUnmarshalECDH(t *testing.T) { } marshal.D = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "d" is invalid raw Base64URL. %s`, err) } marshal.X = "" - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrKeyUnmarshalParameter) { + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if !errors.Is(err, ErrKeyUnmarshalParameter) { t.Fatalf(`Should get ErrKeyUnmarshalParameter when parameter "x" is empty. %s`, err) } marshal.X = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "x" is invalid raw Base64URL. %s`, err) } marshal.CRV = "" - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrKeyUnmarshalParameter) { + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if !errors.Is(err, ErrKeyUnmarshalParameter) { t.Fatalf(`Should get ErrKeyUnmarshalParameter when parameter "crv" is empty. %s`, err) } - marshal.CRV = jwkset.CrvX25519 + marshal.CRV = CrvX25519 invalidSize := base64.RawURLEncoding.EncodeToString([]byte("invalidSize")) marshal.X = invalidSize marshal.D = ecdhX25519D - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrKeyUnmarshalParameter) { + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if !errors.Is(err, ErrKeyUnmarshalParameter) { t.Fatalf(`Should get ErrKeyUnmarshalParameter when parameter "x" is invalid size. %s`, err) } marshal.X = ecdhX25519X marshal.D = invalidSize - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrKeyUnmarshalParameter) { + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if !errors.Is(err, ErrKeyUnmarshalParameter) { t.Fatalf(`Should get ErrKeyUnmarshalParameter when parameter "d" is invalid size. %s`, err) } } func TestMarshalECDSA(t *testing.T) { - keyOps := []jwkset.KEYOPS{jwkset.KeyOpsSign, jwkset.KeyOpsVerify} - checkMarshal := func(marshal jwkset.JWKMarshal, options jwkset.JWKOptions) { - if marshal.ALG != jwkset.AlgES256 { + keyOps := []KEYOPS{KeyOpsSign, KeyOpsVerify} + checkMarshal := func(marshal JWKMarshal, options JWKOptions) { + if marshal.ALG != AlgES256 { t.Fatal(`Marshaled parameter "alg" does not match original key.`) } if marshal.KID != myKeyID { @@ -185,10 +183,10 @@ func TestMarshalECDSA(t *testing.T) { if !slices.Equal(marshal.KEYOPS, keyOps) { t.Fatal(`Marshaled parameter "key_ops" does not match original key.`) } - if marshal.USE != jwkset.UseSig { + if marshal.USE != UseSig { t.Fatal(`Marshaled parameter "use" does not match original key.`) } - if marshal.CRV != jwkset.CrvP256 { + if marshal.CRV != CrvP256 { t.Fatal(`Marshaled parameter "crv" does not match original key.`) } if options.Marshal.Private { @@ -200,7 +198,7 @@ func TestMarshalECDSA(t *testing.T) { t.Fatal("Asymmetric private key should be unsupported for given options.") } } - if marshal.KTY != jwkset.KtyEC { + if marshal.KTY != KtyEC { t.Fatal(`Marshaled parameter "kty" does not match original key.`) } if marshal.X != ecdsaP256X { @@ -212,13 +210,13 @@ func TestMarshalECDSA(t *testing.T) { } private := makeECDSAP256(t) - metadata := jwkset.JWKMetadataOptions{ - ALG: jwkset.AlgES256, + metadata := JWKMetadataOptions{ + ALG: AlgES256, KID: myKeyID, KEYOPS: keyOps, - USE: jwkset.UseSig, + USE: UseSig, } - options := jwkset.JWKOptions{ + options := JWKOptions{ Metadata: metadata, } jwk := newJWK(t, private, options) @@ -235,7 +233,7 @@ func TestMarshalECDSA(t *testing.T) { } func TestUnmarshalECDSA(t *testing.T) { - checkUnmarshal := func(jwk jwkset.JWK, options jwkset.JWKMarshalOptions, original *ecdsa.PrivateKey) { + checkUnmarshal := func(jwk JWK, options JWKMarshalOptions, original *ecdsa.PrivateKey) { var public *ecdsa.PublicKey if options.Private { private := jwk.Key().(*ecdsa.PrivateKey) @@ -257,16 +255,16 @@ func TestUnmarshalECDSA(t *testing.T) { } } - marshal := jwkset.JWKMarshal{ - CRV: jwkset.CrvP256, + marshal := JWKMarshal{ + CRV: CrvP256, D: ecdsaP256D, - KTY: jwkset.KtyEC, + KTY: KtyEC, X: ecdsaP256X, Y: ecdsaP256Y, } key := makeECDSAP256(t) - marshalOptions := jwkset.JWKMarshalOptions{} + marshalOptions := JWKMarshalOptions{} jwk := newJWKFromMarshal(t, marshal, marshalOptions) checkUnmarshal(jwk, marshalOptions, key) @@ -275,7 +273,7 @@ func TestUnmarshalECDSA(t *testing.T) { checkUnmarshal(jwk, marshalOptions, key) key = makeECDSAP384(t) - marshal.CRV = jwkset.CrvP384 + marshal.CRV = CrvP384 marshal.D = ecdsaP384D marshal.X = ecdsaP384X marshal.Y = ecdsaP384Y @@ -283,7 +281,7 @@ func TestUnmarshalECDSA(t *testing.T) { checkUnmarshal(jwk, marshalOptions, key) key = makeECDSAP521(t) - marshal.CRV = jwkset.CrvP521 + marshal.CRV = CrvP521 marshal.D = ecdsaP521D marshal.X = ecdsaP521X marshal.Y = ecdsaP521Y @@ -291,34 +289,34 @@ func TestUnmarshalECDSA(t *testing.T) { checkUnmarshal(jwk, marshalOptions, key) marshal.CRV = "" - _, err := jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrKeyUnmarshalParameter) { + _, err := NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if !errors.Is(err, ErrKeyUnmarshalParameter) { t.Fatalf(`Should get ErrKeyUnmarshalParameter when parameter "crv" is empty. %s`, err) } marshal.CRV = "invalid" - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrKeyUnmarshalParameter) { + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if !errors.Is(err, ErrKeyUnmarshalParameter) { t.Fatalf(`Should get ErrKeyUnmarshalParameter when parameter "crv" is invalid. %s`, err) } - marshal.CRV = jwkset.CrvP521 + marshal.CRV = CrvP521 marshal.D = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "d" is invalid raw Base64 URL. %s`, err) } marshal.D = ecdsaP521D marshal.X = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "x" is invalid raw Base64 URL. %s`, err) } marshal.X = ecdsaP521X marshal.Y = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "y" is invalid raw Base64 URL. %s`, err) } @@ -326,11 +324,11 @@ func TestUnmarshalECDSA(t *testing.T) { } func TestMarshalEdDSA(t *testing.T) { - checkJWK := func(marshal jwkset.JWKMarshal, options jwkset.JWKOptions) { - if marshal.ALG != jwkset.AlgEdDSA { + checkJWK := func(marshal JWKMarshal, options JWKOptions) { + if marshal.ALG != AlgEdDSA { t.Fatal(`Marshaled key parameter "alg" does not match original key.`) } - if marshal.CRV != jwkset.CrvEd25519 { + if marshal.CRV != CrvEd25519 { t.Fatal(`Marshaled key parameter "crv" does not match original key.`) } if options.Marshal.Private { @@ -342,7 +340,7 @@ func TestMarshalEdDSA(t *testing.T) { t.Fatalf("Asymmetric private key should be unsupported for given options.") } } - if marshal.KTY != jwkset.KtyOKP { + if marshal.KTY != KtyOKP { t.Fatal(`Marshaled key parameter "kty" does not match original key.`) } if marshal.X != eddsaPublic { @@ -351,15 +349,15 @@ func TestMarshalEdDSA(t *testing.T) { } private := makeEdDSA(t) - options := jwkset.JWKOptions{} - jwk, err := jwkset.NewJWKFromKey(private, options) + options := JWKOptions{} + jwk, err := NewJWKFromKey(private, options) if err != nil { t.Fatalf("Failed to marshal key with correct options. %s", err) } checkJWK(jwk.Marshal(), options) options.Marshal.Private = true - jwk, err = jwkset.NewJWKFromKey(private, options) + jwk, err = NewJWKFromKey(private, options) if err != nil { t.Fatalf("Failed to marshal key with correct options. %s", err) @@ -367,7 +365,7 @@ func TestMarshalEdDSA(t *testing.T) { checkJWK(jwk.Marshal(), options) options.Marshal.Private = false - jwk, err = jwkset.NewJWKFromKey(private.Public(), options) + jwk, err = NewJWKFromKey(private.Public(), options) if err != nil { t.Fatalf("Failed to marshal key with correct options. %s", err) } @@ -377,17 +375,17 @@ func TestMarshalEdDSA(t *testing.T) { func TestUnmarshalEdDSA(t *testing.T) { private := makeEdDSA(t) - marshal := jwkset.JWKMarshal{ - ALG: jwkset.AlgEdDSA, - CRV: jwkset.CrvEd25519, + marshal := JWKMarshal{ + ALG: AlgEdDSA, + CRV: CrvEd25519, D: eddsaPrivate, KID: myKeyID, - KTY: jwkset.KtyOKP, + KTY: KtyOKP, X: eddsaPublic, } - marshalOptions := jwkset.JWKMarshalOptions{} - jwk, err := jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + marshalOptions := JWKMarshalOptions{} + jwk, err := NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err != nil { t.Fatalf("Failed to unmarshal key with correct options. %s", err) } @@ -396,7 +394,7 @@ func TestUnmarshalEdDSA(t *testing.T) { } marshalOptions.Private = true - jwk, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + jwk, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err != nil { t.Fatalf("Failed to unmarshal key with correct options. %s", err) } @@ -408,56 +406,56 @@ func TestUnmarshalEdDSA(t *testing.T) { } marshal.D = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "d" is invalid raw Base64URL. %s`, err) } marshal.X = "" - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrKeyUnmarshalParameter) { + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if !errors.Is(err, ErrKeyUnmarshalParameter) { t.Fatalf(`Should get ErrKeyUnmarshalParameter when parameter "x" is empty. %s`, err) } marshal.X = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "x" is invalid raw Base64URL. %s`, err) } marshal.CRV = "" - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrKeyUnmarshalParameter) { + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if !errors.Is(err, ErrKeyUnmarshalParameter) { t.Fatalf(`Should get ErrKeyUnmarshalParameter when parameter "crv" is empty. %s`, err) } - marshal.CRV = jwkset.CrvEd25519 + marshal.CRV = CrvEd25519 invalidSize := base64.RawURLEncoding.EncodeToString([]byte("invalidSize")) marshal.X = invalidSize marshal.D = eddsaPrivate - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrKeyUnmarshalParameter) { + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if !errors.Is(err, ErrKeyUnmarshalParameter) { t.Fatalf(`Should get ErrKeyUnmarshalParameter when parameter "x" is invalid size. %s`, err) } marshal.X = eddsaPublic marshal.D = invalidSize - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrKeyUnmarshalParameter) { + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if !errors.Is(err, ErrKeyUnmarshalParameter) { t.Fatalf(`Should get ErrKeyUnmarshalParameter when parameter "d" is invalid size. %s`, err) } } func TestMarshalOct(t *testing.T) { key := []byte(hmacSecret) - options := jwkset.JWKOptions{} - _, err := jwkset.NewJWKFromKey(key, options) - if !errors.Is(err, jwkset.ErrOptions) { + options := JWKOptions{} + _, err := NewJWKFromKey(key, options) + if !errors.Is(err, ErrOptions) { t.Fatalf("Symmetric key should be unsupported for given options. %s", err) } options.Marshal.Private = true - jwk, err := jwkset.NewJWKFromKey(key, options) + jwk, err := NewJWKFromKey(key, options) if err != nil { t.Fatalf("Failed to marshal key with correct options. %s", err) } @@ -465,26 +463,26 @@ func TestMarshalOct(t *testing.T) { if jwk.Marshal().K != base64.RawURLEncoding.EncodeToString(jwk.Key().([]byte)) { t.Fatalf("Unmarshaled key does not match original key.") } - if jwk.Marshal().KTY != jwkset.KtyOct { + if jwk.Marshal().KTY != KtyOct { t.Fatalf("Key type does not match original key.") } } func TestUnmarshalOct(t *testing.T) { - marshal := jwkset.JWKMarshal{ + marshal := JWKMarshal{ K: base64.RawURLEncoding.EncodeToString([]byte(hmacSecret)), KID: myKeyID, - KTY: jwkset.KtyOct, + KTY: KtyOct, } - options := jwkset.JWKMarshalOptions{} - _, err := jwkset.NewJWKFromMarshal(marshal, options, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrOptions) { + options := JWKMarshalOptions{} + _, err := NewJWKFromMarshal(marshal, options, JWKValidateOptions{}) + if !errors.Is(err, ErrOptions) { t.Fatalf("Symmetric key should be unsupported for given options. %s", err) } options.Private = true - jwk, err := jwkset.NewJWKFromMarshal(marshal, options, jwkset.JWKValidateOptions{}) + jwk, err := NewJWKFromMarshal(marshal, options, JWKValidateOptions{}) if err != nil { t.Fatalf("Failed to unmarshal key with correct options. %s", err) } @@ -496,13 +494,13 @@ func TestUnmarshalOct(t *testing.T) { } marshal.K = "" - _, err = jwkset.NewJWKFromMarshal(marshal, options, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrKeyUnmarshalParameter) { + _, err = NewJWKFromMarshal(marshal, options, JWKValidateOptions{}) + if !errors.Is(err, ErrKeyUnmarshalParameter) { t.Fatalf(`Should get ErrKeyUnmarshalParameter when parameter "k" is empty. %s`, err) } marshal.K = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, options, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, options, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "k" is invalid raw Base64URL. %s`, err) } @@ -510,11 +508,11 @@ func TestUnmarshalOct(t *testing.T) { func TestMarshalRSA(t *testing.T) { private := makeRSA(t) - checkMarshal := func(marshal jwkset.JWKMarshal, options jwkset.JWKOptions) { + checkMarshal := func(marshal JWKMarshal, options JWKOptions) { if marshal.E != rsa2048E { t.Fatal(`Marshal parameter "e" does not match original key.`) } - if marshal.KTY != jwkset.KtyRSA { + if marshal.KTY != KtyRSA { t.Fatal(`Marshal parameter "kty" does not match original key.`) } if marshal.N != rsa2048N { @@ -594,22 +592,22 @@ func TestMarshalRSA(t *testing.T) { } } - options := jwkset.JWKOptions{} - jwk, err := jwkset.NewJWKFromKey(private, options) + options := JWKOptions{} + jwk, err := NewJWKFromKey(private, options) if err != nil { t.Fatalf("Failed to marshal key with correct options. %s", err) } checkMarshal(jwk.Marshal(), options) options.Marshal.Private = true - jwk, err = jwkset.NewJWKFromKey(private, options) + jwk, err = NewJWKFromKey(private, options) if err != nil { t.Fatalf("Failed to marshal key with correct options. %s", err) } checkMarshal(jwk.Marshal(), options) options.Marshal.Private = false - jwk, err = jwkset.NewJWKFromKey(&jwk.Key().(*rsa.PrivateKey).PublicKey, options) + jwk, err = NewJWKFromKey(&jwk.Key().(*rsa.PrivateKey).PublicKey, options) if err != nil { t.Fatalf("Failed to marshal key with correct options. %s", err) } @@ -617,7 +615,7 @@ func TestMarshalRSA(t *testing.T) { } func TestUnmarshalRSA(t *testing.T) { - checkJWK := func(jwk jwkset.JWK, options jwkset.JWKMarshalOptions, original *rsa.PrivateKey) { + checkJWK := func(jwk JWK, options JWKMarshalOptions, original *rsa.PrivateKey) { var public *rsa.PublicKey var ok bool if options.Private { @@ -670,17 +668,17 @@ func TestUnmarshalRSA(t *testing.T) { } private := makeRSA(t) - marshal := jwkset.JWKMarshal{ + marshal := JWKMarshal{ E: rsa2048E, D: rsa2048D, DP: rsa2048DP, DQ: rsa2048DQ, - KTY: jwkset.KtyRSA, + KTY: KtyRSA, N: rsa2048N, P: rsa2048P, Q: rsa2048Q, QI: rsa2048QI, - OTH: []jwkset.OtherPrimes{ + OTH: []OtherPrimes{ { D: rsa2048OthD1, R: rsa2048OthR1, @@ -699,104 +697,104 @@ func TestUnmarshalRSA(t *testing.T) { }, } - marshalOptions := jwkset.JWKMarshalOptions{} - jwk, err := jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + marshalOptions := JWKMarshalOptions{} + jwk, err := NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err != nil { t.Fatalf("Failed to unmarshal key with correct options. %s", err) } checkJWK(jwk, marshalOptions, private) marshalOptions.Private = true - jwk, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + jwk, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err != nil { t.Fatalf("Failed to unmarshal key with correct options. %s", err) } checkJWK(jwk, marshalOptions, private) marshal.N = "" - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatal(`Should get error when parameter "n" is empty.`) } marshal.N = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "n" is invalid raw Base64 URL. %s`, err) } marshal.N = rsa2048N marshal.E = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "e" is invalid raw Base64 URL. %s`, err) } marshal.E = rsa2048E marshal.D = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "d" is invalid raw Base64 URL. %s`, err) } marshal.D = rsa2048D marshal.DP = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "dp" is invalid raw Base64 URL. %s`, err) } marshal.DP = rsa2048DP marshal.DQ = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "dq" is invalid raw Base64 URL. %s`, err) } marshal.DQ = rsa2048DQ marshal.P = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "p" is invalid raw Base64 URL. %s`, err) } marshal.P = rsa2048P marshal.Q = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "q" is invalid raw Base64 URL. %s`, err) } marshal.Q = rsa2048Q marshal.QI = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "qi" is invalid raw Base64 URL. %s`, err) } marshal.QI = rsa2048QI marshal.OTH[0].D = "" - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrKeyUnmarshalParameter) { + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if !errors.Is(err, ErrKeyUnmarshalParameter) { t.Fatalf(`Should get error when parameter "oth" "d" is empty. %s`, err) } marshal.OTH[0].D = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "oth" "d"" is invalid raw Base64 URL. %s`, err) } marshal.OTH[0].D = rsa2048OthD1 marshal.OTH[0].R = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "oth" "r"" is invalid raw Base64 URL. %s`, err) } marshal.OTH[0].R = rsa2048OthR1 marshal.OTH[0].T = invalidB64URL - _, err = jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) + _, err = NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err == nil { t.Fatalf(`Should get error when parameter "oth" "t"" is invalid raw Base64 URL. %s`, err) } @@ -804,20 +802,20 @@ func TestUnmarshalRSA(t *testing.T) { } func TestMarshalUnsupported(t *testing.T) { - _, err := jwkset.NewJWKFromMarshal(jwkset.JWKMarshal{}, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrUnsupportedKey) { + _, err := NewJWKFromMarshal(JWKMarshal{}, JWKMarshalOptions{}, JWKValidateOptions{}) + if !errors.Is(err, ErrUnsupportedKey) { t.Fatalf("Unsupported key type should be unsupported for given options. %s", err) } } func TestUnmarshalUnsupported(t *testing.T) { - marshal := jwkset.JWKMarshal{ + marshal := JWKMarshal{ KTY: "unsupported", } - marshalOptions := jwkset.JWKMarshalOptions{} - _, err := jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) - if !errors.Is(err, jwkset.ErrUnsupportedKey) { + marshalOptions := JWKMarshalOptions{} + _, err := NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if !errors.Is(err, ErrUnsupportedKey) { t.Fatalf("Unsupported key type should return ErrUnsupportedKey. %s", err) } } @@ -1018,15 +1016,15 @@ func makeRSA(t *testing.T) *rsa.PrivateKey { return private } -func newJWK(t *testing.T, key any, options jwkset.JWKOptions) jwkset.JWK { - jwk, err := jwkset.NewJWKFromKey(key, options) +func newJWK(t *testing.T, key any, options JWKOptions) JWK { + jwk, err := NewJWKFromKey(key, options) if err != nil { t.Fatalf("Failed to marshal key with correct options. %s", err) } return jwk } -func newJWKFromMarshal(t *testing.T, marshal jwkset.JWKMarshal, marshalOptions jwkset.JWKMarshalOptions) jwkset.JWK { - jwk, err := jwkset.NewJWKFromMarshal(marshal, marshalOptions, jwkset.JWKValidateOptions{}) +func newJWKFromMarshal(t *testing.T, marshal JWKMarshal, marshalOptions JWKMarshalOptions) JWK { + jwk, err := NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) if err != nil { t.Fatalf("Failed to marshal key with correct options. %s", err) } diff --git a/storage.go b/storage.go index 7d63ac6..9d26d0c 100644 --- a/storage.go +++ b/storage.go @@ -2,31 +2,51 @@ package jwkset import ( "context" + "encoding/json" "errors" "fmt" + "net/http" + "net/url" "slices" "sync" + "time" ) -// ErrKeyNotFound is returned by a Storage implementation when a key is not found. -var ErrKeyNotFound = errors.New("key not found") +var ( + // ErrKeyNotFound is returned by a Storage implementation when a key is not found. + ErrKeyNotFound = errors.New("key not found") + // ErrInvalidHTTPStatusCode is returned when the HTTP status code is invalid. + ErrInvalidHTTPStatusCode = errors.New("invalid HTTP status code") +) // Storage handles storage operations for a JWKSet. type Storage interface { - // DeleteKey deletes a key from the storage. It will return ok as true if the key was present for deletion. - DeleteKey(ctx context.Context, keyID string) (ok bool, err error) - - // ReadKey reads a key from the storage. If the key is not present, it returns ErrKeyNotFound. Any pointers returned + // KeyDelete deletes a key from the storage. It will return ok as true if the key was present for deletion. + KeyDelete(ctx context.Context, keyID string) (ok bool, err error) + // KeyRead reads a key from the storage. If the key is not present, it returns ErrKeyNotFound. Any pointers returned // should be considered read-only. - ReadKey(ctx context.Context, keyID string) (JWK, error) - - // SnapshotKeys reads a snapshot of all keys from storage. As with ReadKey, any pointers returned should be + KeyRead(ctx context.Context, keyID string) (JWK, error) + // KeyReadAll reads a snapshot of all keys from storage. As with ReadKey, any pointers returned should be // considered read-only. - SnapshotKeys(ctx context.Context) ([]JWK, error) - - // WriteKey writes a key to the storage. If the key already exists, it will be overwritten. After writing a key, + KeyReadAll(ctx context.Context) ([]JWK, error) + // KeyWrite writes a key to the storage. If the key already exists, it will be overwritten. After writing a key, // any pointers written should be considered owned by the underlying storage. - WriteKey(ctx context.Context, jwk JWK) error + KeyWrite(ctx context.Context, jwk JWK) error + + // JSON creates the JSON representation of the JWKSet. + JSON(ctx context.Context) (json.RawMessage, error) + // JSONPublic creates the JSON representation of the public keys in JWKSet. + JSONPublic(ctx context.Context) (json.RawMessage, error) + // JSONPrivate creates the JSON representation of the JWKSet public and private key material. + JSONPrivate(ctx context.Context) (json.RawMessage, error) + // JSONWithOptions creates the JSON representation of the JWKSet with the given options. These options override whatever + // options are set on the individual JWKs. + JSONWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (json.RawMessage, error) + // Marshal transforms the JWK Set's current state into a Go type that can be marshaled into JSON. + Marshal(ctx context.Context) (JWKSMarshal, error) + // MarshalWithOptions transforms the JWK Set's current state into a Go type that can be marshaled into JSON with the + // given options. These options override whatever options are set on the individual JWKs. + MarshalWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (JWKSMarshal, error) } var _ Storage = &memoryJWKSet{} @@ -41,12 +61,7 @@ func NewMemoryStorage() Storage { return &memoryJWKSet{} } -func (m *memoryJWKSet) SnapshotKeys(_ context.Context) ([]JWK, error) { - m.mux.RLock() - defer m.mux.RUnlock() - return slices.Clone(m.set), nil -} -func (m *memoryJWKSet) DeleteKey(_ context.Context, keyID string) (ok bool, err error) { +func (m *memoryJWKSet) KeyDelete(_ context.Context, keyID string) (ok bool, err error) { m.mux.Lock() defer m.mux.Unlock() for i, jwk := range m.set { @@ -57,7 +72,7 @@ func (m *memoryJWKSet) DeleteKey(_ context.Context, keyID string) (ok bool, err } return ok, nil } -func (m *memoryJWKSet) ReadKey(_ context.Context, keyID string) (JWK, error) { +func (m *memoryJWKSet) KeyRead(_ context.Context, keyID string) (JWK, error) { m.mux.RLock() defer m.mux.RUnlock() for _, jwk := range m.set { @@ -67,7 +82,12 @@ func (m *memoryJWKSet) ReadKey(_ context.Context, keyID string) (JWK, error) { } return JWK{}, fmt.Errorf("%w: kid %q", ErrKeyNotFound, keyID) } -func (m *memoryJWKSet) WriteKey(_ context.Context, jwk JWK) error { +func (m *memoryJWKSet) KeyReadAll(_ context.Context) ([]JWK, error) { + m.mux.RLock() + defer m.mux.RUnlock() + return slices.Clone(m.set), nil +} +func (m *memoryJWKSet) KeyWrite(_ context.Context, jwk JWK) error { m.mux.Lock() defer m.mux.Unlock() for i, j := range m.set { @@ -79,3 +99,205 @@ func (m *memoryJWKSet) WriteKey(_ context.Context, jwk JWK) error { m.set = append(m.set, jwk) return nil } + +func (m *memoryJWKSet) JSON(ctx context.Context) (json.RawMessage, error) { + jwks, err := m.Marshal(ctx) + if err != nil { + return nil, fmt.Errorf("failed to marshal JWK Set: %w", err) + } + return json.Marshal(jwks) +} +func (m *memoryJWKSet) JSONPublic(ctx context.Context) (json.RawMessage, error) { + return m.JSONWithOptions(ctx, JWKMarshalOptions{}, JWKValidateOptions{}) +} +func (m *memoryJWKSet) JSONPrivate(ctx context.Context) (json.RawMessage, error) { + marshalOptions := JWKMarshalOptions{ + Private: true, + } + return m.JSONWithOptions(ctx, marshalOptions, JWKValidateOptions{}) +} +func (m *memoryJWKSet) JSONWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (json.RawMessage, error) { + jwks, err := m.MarshalWithOptions(ctx, marshalOptions, validationOptions) + if err != nil { + return nil, fmt.Errorf("failed to marshal JWK Set with options: %w", err) + } + return json.Marshal(jwks) +} +func (m *memoryJWKSet) Marshal(ctx context.Context) (JWKSMarshal, error) { + keys, err := m.KeyReadAll(ctx) + if err != nil { + return JWKSMarshal{}, fmt.Errorf("failed to read snapshot of all keys from storage: %w", err) + } + jwks := JWKSMarshal{} + for _, key := range keys { + jwks.Keys = append(jwks.Keys, key.Marshal()) + } + return jwks, nil +} +func (m *memoryJWKSet) MarshalWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (JWKSMarshal, error) { + jwks := JWKSMarshal{} + + keys, err := m.KeyReadAll(ctx) + if err != nil { + return JWKSMarshal{}, fmt.Errorf("failed to read snapshot of all keys from storage: %w", err) + } + + for _, key := range keys { + options := key.options + options.Marshal = marshalOptions + options.Validate = validationOptions + marshal, err := keyMarshal(key.Key(), options) + if err != nil { + if errors.Is(err, ErrOptions) { + continue + } + return JWKSMarshal{}, fmt.Errorf("failed to marshal key: %w", err) + } + jwks.Keys = append(jwks.Keys, marshal) + } + + return jwks, nil +} + +// HTTPClientStorageOptions are used to configure the behavior of NewStorageFromHTTP. +type HTTPClientStorageOptions struct { + // Client is the HTTP client to use for requests. + // + // This defaults to http.DefaultClient. + Client *http.Client + + // Ctx is used when performing HTTP requests. It is also used to end the refresh goroutine when it's no longer + // needed. + // + // This defaults to context.Background(). + Ctx context.Context + + // HTTPExpectedStatus is the expected HTTP status code for the HTTP request. + // + // This defaults to http.StatusOK. + HTTPExpectedStatus int + + // HTTPMethod is the HTTP method to use for the HTTP request. + // + // This defaults to http.MethodGet. + HTTPMethod string + + // HTTPTimeout is the timeout for the HTTP request. When the Ctx option is also provided, this value is used for a + // child context. + // + // This defaults to time.Minute. + HTTPTimeout time.Duration + + // NoErrorReturnFirstHTTPReq will create the Storage without error if the first HTTP request fails. + NoErrorReturnFirstHTTPReq bool + + // RefreshErrorHandler is a function that consumes errors that happen during an HTTP refresh. This is only effectual + // if RefreshInterval is set. + // + // If NoErrorReturnFirstHTTPReq is set, this function will be called when if the first HTTP request fails. + RefreshErrorHandler func(ctx context.Context, err error) + + // RefreshInterval is the interval at which the HTTP URL is refreshed and the JWK Set is processed. This option will + // launch a "refresh goroutine" to refresh the remote HTTP resource at the given interval. + // + // Provide the Ctx option to end the goroutine when it's no longer needed. + RefreshInterval time.Duration + + // Storage is the underlying storage implementation to use. + // + // This defaults to NewMemoryStorage(). + Storage Storage +} + +// NewStorageFromHTTP creates a new Storage implementation that processes a remote HTTP resource for a JWK Set. If +// the RefreshInterval option is not set, the remote HTTP resource will be requested and processed before returning. If +// the RefreshInterval option is set, a background goroutine will be launched to refresh the remote HTTP resource and +// not block the return of this function. +func NewStorageFromHTTP(u *url.URL, options HTTPClientStorageOptions) (Storage, error) { + if options.Client == nil { + options.Client = http.DefaultClient + } + if options.Ctx == nil { + options.Ctx = context.Background() + } + if options.HTTPExpectedStatus == 0 { + options.HTTPExpectedStatus = http.StatusOK + } + if options.HTTPTimeout == 0 { + options.HTTPTimeout = time.Minute + } + if options.HTTPMethod == "" { + options.HTTPMethod = http.MethodGet + } + store := options.Storage + if store == nil { + store = NewMemoryStorage() + } + + refresh := func(ctx context.Context) error { + req, err := http.NewRequestWithContext(ctx, options.HTTPMethod, u.String(), nil) + if err != nil { + return fmt.Errorf("failed to create HTTP request for JWK Set refresh: %w", err) + } + resp, err := options.Client.Do(req) + if err != nil { + return fmt.Errorf("failed to perform HTTP request for JWK Set refresh: %w", err) + } + //goland:noinspection GoUnhandledErrorResult + defer resp.Body.Close() + if resp.StatusCode != options.HTTPExpectedStatus { + return fmt.Errorf("%w: %d", ErrInvalidHTTPStatusCode, resp.StatusCode) + } + var jwks JWKSMarshal + err = json.NewDecoder(resp.Body).Decode(&jwks) + if err != nil { + return fmt.Errorf("failed to decode JWK Set response: %w", err) + } + for _, marshal := range jwks.Keys { + marshalOptions := JWKMarshalOptions{ + Private: true, + } + jwk, err := NewJWKFromMarshal(marshal, marshalOptions, JWKValidateOptions{}) + if err != nil { + return fmt.Errorf("failed to create JWK from JWK Marshal: %w", err) + } + err = store.KeyWrite(options.Ctx, jwk) + if err != nil { + return fmt.Errorf("failed to write JWK to memory storage: %w", err) + } + } + return nil + } + + ctx, cancel := context.WithTimeout(options.Ctx, options.HTTPTimeout) + defer cancel() + err := refresh(ctx) + cancel() + if err != nil { + if options.NoErrorReturnFirstHTTPReq { + options.RefreshErrorHandler(ctx, err) + return store, nil + } + return nil, fmt.Errorf("failed to perform first HTTP request for JWK Set: %w", err) + } + + go func() { // Refresh goroutine. + ticker := time.NewTicker(options.RefreshInterval) + defer ticker.Stop() + for { + select { + case <-options.Ctx.Done(): + return + case <-ticker.C: + ctx, cancel = context.WithTimeout(options.Ctx, options.HTTPTimeout) + err = refresh(ctx) + cancel() + if err != nil && options.RefreshErrorHandler != nil { + options.RefreshErrorHandler(ctx, err) + } + } + } + }() + + return store, nil +} diff --git a/storage_test.go b/storage_test.go index 9f98630..d57360d 100644 --- a/storage_test.go +++ b/storage_test.go @@ -1,4 +1,4 @@ -package jwkset_test +package jwkset import ( "bytes" @@ -6,8 +6,6 @@ import ( "errors" "testing" "time" - - "github.com/MicahParks/jwkset" ) const ( @@ -24,21 +22,21 @@ var ( type storageTestParams struct { ctx context.Context cancel context.CancelFunc - jwks jwkset.JWKSet + jwks Storage } -func TestMemoryDeleteKey(t *testing.T) { +func TestMemoryKeyDelete(t *testing.T) { params := setupMemory() defer params.cancel() - store := params.jwks.Store + store := params.jwks jwk := newStorageTestJWK(t, hmacKey1, kidWritten) - err := store.WriteKey(params.ctx, jwk) + err := store.KeyWrite(params.ctx, jwk) if err != nil { t.Fatalf("Failed to write key. %s", err) } - ok, err := store.DeleteKey(params.ctx, kidMissing) + ok, err := store.KeyDelete(params.ctx, kidMissing) if err != nil { t.Fatalf("Failed to delete missing key. %s", err) } @@ -46,7 +44,7 @@ func TestMemoryDeleteKey(t *testing.T) { t.Fatalf("Deleted missing key.") } - ok, err = store.DeleteKey(params.ctx, kidWritten) + ok, err = store.KeyDelete(params.ctx, kidWritten) if err != nil { t.Fatalf("Failed to delete written key. %s", err) } @@ -55,23 +53,23 @@ func TestMemoryDeleteKey(t *testing.T) { } } -func TestMemoryReadKey(t *testing.T) { +func TestMemoryKeyRead(t *testing.T) { params := setupMemory() defer params.cancel() - store := params.jwks.Store + store := params.jwks jwk := newStorageTestJWK(t, hmacKey1, kidWritten) - err := store.WriteKey(params.ctx, jwk) + err := store.KeyWrite(params.ctx, jwk) if err != nil { t.Fatalf("Failed to write key. %s", err) } - _, err = store.ReadKey(params.ctx, kidMissing) - if !errors.Is(err, jwkset.ErrKeyNotFound) { - t.Fatalf("Should have specific error when reading missing key.\n Actual: %s\n Expected: %s", err, jwkset.ErrKeyNotFound) + _, err = store.KeyRead(params.ctx, kidMissing) + if !errors.Is(err, ErrKeyNotFound) { + t.Fatalf("Should have specific error when reading missing key.\n Actual: %s\n Expected: %s", err, ErrKeyNotFound) } - key, err := store.ReadKey(params.ctx, kidWritten) + key, err := store.KeyRead(params.ctx, kidWritten) if err != nil { t.Fatalf("Failed to read written key. %s", err) } @@ -81,12 +79,12 @@ func TestMemoryReadKey(t *testing.T) { } jwk = newStorageTestJWK(t, hmacKey2, kidWritten) - err = store.WriteKey(params.ctx, jwk) + err = store.KeyWrite(params.ctx, jwk) if err != nil { t.Fatalf("Failed to overwrite key. %s", err) } - key, err = store.ReadKey(params.ctx, kidWritten) + key, err = store.KeyRead(params.ctx, kidWritten) if err != nil { t.Fatalf("Failed to read written key. %s", err) } @@ -95,35 +93,35 @@ func TestMemoryReadKey(t *testing.T) { t.Fatalf("Read key does not match written key.") } - _, err = store.DeleteKey(params.ctx, kidWritten) + _, err = store.KeyDelete(params.ctx, kidWritten) if err != nil { t.Fatalf("Failed to delete written key. %s", err) } - _, err = store.ReadKey(params.ctx, kidWritten) - if !errors.Is(err, jwkset.ErrKeyNotFound) { - t.Fatalf("Should have specific error when reading missing key.\n Actual: %s\n Expected: %s", err, jwkset.ErrKeyNotFound) + _, err = store.KeyRead(params.ctx, kidWritten) + if !errors.Is(err, ErrKeyNotFound) { + t.Fatalf("Should have specific error when reading missing key.\n Actual: %s\n Expected: %s", err, ErrKeyNotFound) } } -func TestMemorySnapshotKeys(t *testing.T) { +func TestMemoryKeyReadAll(t *testing.T) { params := setupMemory() defer params.cancel() - store := params.jwks.Store + store := params.jwks jwk := newStorageTestJWK(t, hmacKey1, kidWritten) - err := store.WriteKey(params.ctx, jwk) + err := store.KeyWrite(params.ctx, jwk) if err != nil { t.Fatalf("Failed to write key 1. %s", err) } jwk = newStorageTestJWK(t, hmacKey2, kidWritten2) - err = store.WriteKey(params.ctx, jwk) + err = store.KeyWrite(params.ctx, jwk) if err != nil { t.Fatalf("Failed to write key 2. %s", err) } - keys, err := store.SnapshotKeys(params.ctx) + keys, err := store.KeyReadAll(params.ctx) if err != nil { t.Fatalf("Failed to snapshot keys. %s", err) } @@ -150,26 +148,26 @@ func TestMemorySnapshotKeys(t *testing.T) { } } -func TestMemoryWriteKey(t *testing.T) { +func TestMemoryKeyWrite(t *testing.T) { params := setupMemory() defer params.cancel() - store := params.jwks.Store + store := params.jwks jwk := newStorageTestJWK(t, hmacKey1, kidWritten) - err := store.WriteKey(params.ctx, jwk) + err := store.KeyWrite(params.ctx, jwk) if err != nil { t.Fatalf("Failed to write key. %s", err) } jwk = newStorageTestJWK(t, hmacKey2, kidWritten) - err = store.WriteKey(params.ctx, jwk) + err = store.KeyWrite(params.ctx, jwk) if err != nil { t.Fatalf("Failed to overwrite key. %s", err) } } func setupMemory() (params storageTestParams) { - jwkSet := jwkset.NewMemory() + jwkSet := NewMemoryStorage() ctx, cancel := context.WithTimeout(context.Background(), time.Second) params = storageTestParams{ ctx: ctx, @@ -179,18 +177,18 @@ func setupMemory() (params storageTestParams) { return params } -func newStorageTestJWK(t *testing.T, key any, keyID string) jwkset.JWK { - marshal := jwkset.JWKMarshalOptions{ +func newStorageTestJWK(t *testing.T, key any, keyID string) JWK { + marshal := JWKMarshalOptions{ Private: true, } - metadata := jwkset.JWKMetadataOptions{ + metadata := JWKMetadataOptions{ KID: keyID, } - options := jwkset.JWKOptions{ + options := JWKOptions{ Marshal: marshal, Metadata: metadata, } - jwk, err := jwkset.NewJWKFromKey(key, options) + jwk, err := NewJWKFromKey(key, options) if err != nil { t.Fatalf("Failed to create JWK. %s", err) } diff --git a/website/handle/template/inspect.go b/website/handle/template/inspect.go index 1d2aa42..27a217b 100644 --- a/website/handle/template/inspect.go +++ b/website/handle/template/inspect.go @@ -22,7 +22,7 @@ func (i *Inspect) ApplyMiddleware(h http.Handler) http.Handler { cache := middleware.CreateCacheControl(middleware.CacheDefaults) return cache(middleware.EncodeGzip(h)) } -func (i *Inspect) Authorize(w http.ResponseWriter, r *http.Request) (authorized bool, modified *http.Request, skipTemplate bool) { +func (i *Inspect) Authorize(_ http.ResponseWriter, r *http.Request) (authorized bool, modified *http.Request, skipTemplate bool) { return true, r, false } func (i *Inspect) Initialize(s server.Server) error {