Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(authn): improve jwks refresh retry mechanism #1319

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 45 additions & 19 deletions internal/authn/oidc/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ type Authn struct {
backoffFrequency time.Duration

// Global backoff state
globalRetryCount int
globalFirstSeen time.Time
mu sync.Mutex
globalRetryCount int
globalFirstSeen time.Time
globalRetryKeyIds map[string]bool
mu sync.Mutex
}

// NewOidcAuthn initializes a new instance of the Authn struct with OpenID Connect (OIDC) configuration.
Expand Down Expand Up @@ -98,6 +99,7 @@ func NewOidcAuthn(ctx context.Context, conf config.Oidc) (*Authn, error) {
backoffMaxRetries: backoffMaxRetries,
backoffFrequency: backoffFrequency,
globalRetryCount: 0,
globalRetryKeyIds: make(map[string]bool),
globalFirstSeen: time.Time{},
mu: sync.Mutex{},
}
Expand Down Expand Up @@ -203,6 +205,7 @@ func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface
slog.Info("resetting state as interval has passed or first seen is zero", "keyID", keyID)
oidc.globalFirstSeen = now
oidc.globalRetryCount = 0
oidc.globalRetryKeyIds = make(map[string]bool)
} else if oidc.globalRetryCount >= oidc.backoffMaxRetries {
// If max retries reached within the interval, unlock and check keyID once
slog.Warn("max retries reached within interval, will check keyID once", "keyID", keyID)
Expand All @@ -211,11 +214,16 @@ func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface
// Try to fetch the keyID once
rawKey, err = oidc.fetchKey(keyID, ctx)
if err == nil {
// Reset global backoff state if a valid key is found
slog.Info("valid key found during backoff period, resetting state", "keyID", keyID)
oidc.mu.Lock()
oidc.globalRetryCount = 0
oidc.globalFirstSeen = time.Time{}
if _, ok := oidc.globalRetryKeyIds[keyID]; ok {
// Reset global backoff state if a valid key is found and that key had been retried.
// Use case would be someone trying to exploit with bad KeyIDs, and along comes a valid KeyID
// The valid KeyID should not reset the counters for a bad key
slog.Info("valid key found during backoff period, resetting state", "keyID", keyID)
oidc.globalRetryCount = 0
oidc.globalFirstSeen = time.Time{}
oidc.globalRetryKeyIds = make(map[string]bool)
}
oidc.mu.Unlock()
return rawKey, nil
}
Expand All @@ -229,6 +237,26 @@ func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface
// Retry mechanism
retries := 0
for retries <= oidc.backoffMaxRetries {
rawKey, err = oidc.fetchKey(keyID, ctx)
if err == nil {
if retries != 0 {
oidc.mu.Lock()
oidc.globalRetryCount = 0
oidc.globalFirstSeen = time.Time{}
oidc.globalRetryKeyIds = make(map[string]bool)
oidc.mu.Unlock()
}
return rawKey, nil
}
oidc.mu.Lock()
initialGlobalRetryCount := oidc.globalRetryCount
oidc.globalRetryKeyIds[keyID] = true
if oidc.globalRetryCount > oidc.backoffMaxRetries {
slog.Error("key ID not found in JWKS due to global retries", "keyID", keyID, "globalRetryCount", oidc.globalRetryCount)
oidc.mu.Unlock()
return nil, errors.New("too many attempts, backoff in effect due to global retry count")
}
oidc.mu.Unlock()
if retries > 0 {
select {
case <-time.After(oidc.backoffFrequency):
Expand All @@ -240,28 +268,26 @@ func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface
}
}

rawKey, err = oidc.fetchKey(keyID, ctx)
if err == nil {
// Log the successful key fetch and reset global state
slog.Info("successfully fetched key", "keyID", keyID)
oidc.mu.Lock()
oidc.globalRetryCount = 0
oidc.globalFirstSeen = time.Time{}
oidc.mu.Lock()
if oidc.globalRetryCount > initialGlobalRetryCount {
// Concurrent requests in retry loop at same time, another concurrent request already refreshed the JWKS
retries++
slog.Warn("another concurrent request already refreshed the JWKS")
oidc.mu.Unlock()
return rawKey, nil
continue
}

oidc.globalRetryCount++
slog.Warn("retrying to fetch JWKS due to error", "keyID", keyID, "retries", retries, "error", err)
retries++

oidc.mu.Lock()
oidc.globalRetryCount++
oidc.mu.Unlock()

if _, refreshErr := oidc.jwksSet.Refresh(ctx, oidc.JwksURI); refreshErr != nil {
oidc.mu.Unlock()
slog.Error("failed to refresh JWKS", "error", refreshErr)
return nil, refreshErr
}
// Unlock needs to follow Refresh to ensure that concurrent requests don't make duplicate calls to Refresh
oidc.mu.Unlock()
}

// Mark the global state to prevent further retries for the backoff interval
Expand Down
Loading
Loading