From b28bc17d72e1cdfb37cd0c3df77d205983afd31b Mon Sep 17 00:00:00 2001 From: hackerman <3372410+aeneasr@users.noreply.github.com> Date: Tue, 5 Nov 2024 13:27:25 +0100 Subject: [PATCH] =?UTF-8?q?Revert=20"fix:=20cpu=20contention=20when=20read?= =?UTF-8?q?ing=20JWKs=20and=20suppress=20generating=20duplica=E2=80=A6"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit d5f65c570a02a999867f323630d2a5d099526054. --- cmd/server/helper_cert.go | 2 +- jwk/handler.go | 13 +++++-- jwk/helper.go | 78 ++++++++++++++++++++------------------- jwk/helper_test.go | 23 ++++++------ jwk/jwt_strategy.go | 3 +- 5 files changed, 63 insertions(+), 56 deletions(-) diff --git a/cmd/server/helper_cert.go b/cmd/server/helper_cert.go index e2012292be2..6cef67bc362 100644 --- a/cmd/server/helper_cert.go +++ b/cmd/server/helper_cert.go @@ -58,7 +58,7 @@ func GetOrCreateTLSCertificate(ctx context.Context, d driver.Registry, iface con } // no certificates configured: self-sign a new cert - priv, err := jwk.GetOrGenerateKeySetPrivateKey(ctx, d, d.SoftwareKeyManager(), TlsKeyName, uuid.Must(uuid.NewV4()).String(), "RS256") + priv, err := jwk.GetOrGenerateKeys(ctx, d, d.SoftwareKeyManager(), TlsKeyName, uuid.Must(uuid.NewV4()).String(), "RS256") if err != nil { d.Logger().WithError(err).Fatal("Unable to fetch or generate HTTPS TLS key pair") return nil // in case Fatal is hooked diff --git a/jwk/handler.go b/jwk/handler.go index 5e12f9f31ad..7d48445321e 100644 --- a/jwk/handler.go +++ b/jwk/handler.go @@ -13,6 +13,7 @@ import ( "github.com/ory/x/httprouterx" "github.com/gofrs/uuid" + "github.com/pkg/errors" "github.com/ory/x/urlx" @@ -100,11 +101,17 @@ func (h *Handler) discoverJsonWebKeys(w http.ResponseWriter, r *http.Request) { for _, set := range wellKnownKeys { set := set eg.Go(func() error { - keySet, err := GetOrGenerateKeySet(ctx, h.r, h.r.KeyManager(), set, uuid.Must(uuid.NewV4()).String(), string(jose.RS256)) - if err != nil { + k, err := h.r.KeyManager().GetKeySet(ctx, set) + if errors.Is(err, x.ErrNotFound) { + h.r.Logger().Warnf("JSON Web Key Set %q does not exist yet, generating new key pair...", set) + k, err = h.r.KeyManager().GenerateAndPersistKeySet(ctx, set, uuid.Must(uuid.NewV4()).String(), string(jose.RS256), "sig") + if err != nil { + return err + } + } else if err != nil { return err } - keys <- ExcludePrivateKeys(keySet) + keys <- ExcludePrivateKeys(k) return nil }) } diff --git a/jwk/helper.go b/jwk/helper.go index 194dbffd3fe..50f3a28b2d2 100644 --- a/jwk/helper.go +++ b/jwk/helper.go @@ -12,67 +12,69 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" - - "golang.org/x/sync/singleflight" + "sync" hydra "github.com/ory/hydra-client-go/v2" - "github.com/ory/hydra/v2/x" "github.com/ory/x/josex" "github.com/ory/x/errorsx" + "github.com/ory/hydra/v2/x" + jose "github.com/go-jose/go-jose/v3" "github.com/pkg/errors" ) -var jwkGenFlightGroup singleflight.Group +var mapLock sync.RWMutex +var locks = map[string]*sync.RWMutex{} + +func getLock(set string) *sync.RWMutex { + mapLock.Lock() + defer mapLock.Unlock() + if _, ok := locks[set]; !ok { + locks[set] = new(sync.RWMutex) + } + return locks[set] +} func EnsureAsymmetricKeypairExists(ctx context.Context, r InternalRegistry, alg, set string) error { - _, err := GetOrGenerateKeySetPrivateKey(ctx, r, r.KeyManager(), set, set, alg) + _, err := GetOrGenerateKeys(ctx, r, r.KeyManager(), set, set, alg) return err } -func GetOrGenerateKeySetPrivateKey(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (private *jose.JSONWebKey, err error) { - keySet, err := GetOrGenerateKeySet(ctx, r, m, set, kid, alg) - if err != nil { - return nil, err - } - - privKey, err := FindPrivateKey(keySet) - if err == nil { - return privKey, nil - } - - keySet, err = generateKeySet(ctx, r, m, set, kid, alg) - if err != nil { - return nil, err - } - - return FindPrivateKey(keySet) -} +func GetOrGenerateKeys(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (private *jose.JSONWebKey, err error) { + getLock(set).Lock() + defer getLock(set).Unlock() -func GetOrGenerateKeySet(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (*jose.JSONWebKeySet, error) { keys, err := m.GetKeySet(ctx, set) - if err != nil && !errors.Is(err, x.ErrNotFound) { + if errors.Is(err, x.ErrNotFound) || keys != nil && len(keys.Keys) == 0 { + r.Logger().Warnf("JSON Web Key Set \"%s\" does not exist yet, generating new key pair...", set) + keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") + if err != nil { + return nil, err + } + } else if err != nil { return nil, err - } else if keys != nil && len(keys.Keys) > 0 { - return keys, nil } - return generateKeySet(ctx, r, m, set, kid, alg) -} - -func generateKeySet(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (*jose.JSONWebKeySet, error) { - // Suppress duplicate key set generation jobs where the set+alg match. - keysResult, err, _ := jwkGenFlightGroup.Do(set+alg, func() (any, error) { + privKey, privKeyErr := FindPrivateKey(keys) + if privKeyErr == nil { + return privKey, nil + } else { r.Logger().WithField("jwks", set).Warnf("JSON Web Key not found in JSON Web Key Set %s, generating new key pair...", set) - return m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") - }) - if err != nil { - return nil, err + + keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") + if err != nil { + return nil, err + } + + privKey, err := FindPrivateKey(keys) + if err != nil { + return nil, err + } + return privKey, nil } - return keysResult.(*jose.JSONWebKeySet), nil } func First(keys []jose.JSONWebKey) *jose.JSONWebKey { diff --git a/jwk/helper_test.go b/jwk/helper_test.go index c4f4d6e18b5..c1a5ee46387 100644 --- a/jwk/helper_test.go +++ b/jwk/helper_test.go @@ -27,11 +27,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/x/contextx" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/x" + "github.com/ory/x/contextx" ) type fakeSigner struct { @@ -227,46 +226,46 @@ func TestGetOrGenerateKeys(t *testing.T) { return NewMockManager(ctrl) } - t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GetKeySetError", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySetError", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.New("GetKeySetError")) - privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") assert.Nil(t, privKey) assert.EqualError(t, err, "GetKeySetError") }) - t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySetError", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySetError", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.Wrap(x.ErrNotFound, "")) keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError")) - privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") assert.Nil(t, privKey) assert.EqualError(t, err, "GetKeySetError") }) - t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySetError", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySetError", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError")) - privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") assert.Nil(t, privKey) assert.EqualError(t, err, "GetKeySetError") }) - t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GetKeySet_ContainsMissingPrivateKey", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySet_ContainsMissingPrivateKey", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySet, nil) - privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") assert.NoError(t, err) assert.Equal(t, privKey, &keySet.Keys[0]) }) - t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySet_ContainsMissingPrivateKey", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySet_ContainsMissingPrivateKey", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySetWithoutPrivateKey, nil).Times(1) - privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") assert.Nil(t, privKey) assert.EqualError(t, err, "key not found") }) diff --git a/jwk/jwt_strategy.go b/jwk/jwt_strategy.go index 9fdd6c48374..6154066459b 100644 --- a/jwk/jwt_strategy.go +++ b/jwk/jwt_strategy.go @@ -13,7 +13,6 @@ import ( "github.com/gofrs/uuid" "github.com/ory/fosite" - "github.com/ory/hydra/v2/driver/config" "github.com/pkg/errors" @@ -41,7 +40,7 @@ func NewDefaultJWTSigner(c *config.DefaultProvider, r InternalRegistry, setID st } func (j *DefaultJWTSigner) getKeys(ctx context.Context) (private *jose.JSONWebKey, err error) { - private, err = GetOrGenerateKeySetPrivateKey(ctx, j.r, j.r.KeyManager(), j.setID, uuid.Must(uuid.NewV4()).String(), string(jose.RS256)) + private, err = GetOrGenerateKeys(ctx, j.r, j.r.KeyManager(), j.setID, uuid.Must(uuid.NewV4()).String(), string(jose.RS256)) if err == nil { return private, nil }