Skip to content

Commit

Permalink
fix new listener failures
Browse files Browse the repository at this point in the history
Signed-off-by: clyang82 <[email protected]>
  • Loading branch information
clyang82 committed Dec 16, 2024
1 parent b70bfd7 commit 4824514
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 18 deletions.
54 changes: 37 additions & 17 deletions pkg/db/db_session/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (f *Default) Init(config *config.DatabaseConfig) {
))
}

dbx = stdlib.OpenDB(*connConfig, stdlib.OptionBeforeConnect(f.setPassword()))
dbx = stdlib.OpenDB(*connConfig, stdlib.OptionBeforeConnect(setPassword(config)))
dbx.SetMaxOpenConns(config.MaxOpenConnections)

// Connect GORM to use the same connection
Expand Down Expand Up @@ -94,32 +94,39 @@ func (f *Default) Init(config *config.DatabaseConfig) {
})
}

func (f *Default) setPassword() func(ctx context.Context, connConfig *pgx.ConnConfig) error {

func setPassword(dbConfig *config.DatabaseConfig) func(ctx context.Context, connConfig *pgx.ConnConfig) error {
return func(ctx context.Context, connConfig *pgx.ConnConfig) error {
if f.config.AuthMethod == constants.AuthMethodPassword {
connConfig.Password = f.config.Password
if dbConfig.AuthMethod == constants.AuthMethodPassword {
connConfig.Password = dbConfig.Password
return nil
} else if f.config.AuthMethod == constants.AuthMethodMicrosoftEntra {
if isExpired(f.config.Token) {
// ARO-HCP environment variable configuration is set by the Azure workload identity webhook.
// Use [WorkloadIdentityCredential] directly when not using the webhook or needing more control over its configuration.
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return err
}
token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: []string{f.config.TokenRequestScope}})
} else if dbConfig.AuthMethod == constants.AuthMethodMicrosoftEntra {
if isExpired(dbConfig.Token) {
token, err := getAccessToken(ctx, dbConfig)
if err != nil {
return err
}
connConfig.Password = token.Token
f.config.Token = &token
dbConfig.Token = token
}
}
return nil
}
}

func getAccessToken(ctx context.Context, dbConfig *config.DatabaseConfig) (*azcore.AccessToken, error) {
// ARO-HCP environment variable configuration is set by the Azure workload identity webhook.
// Use [WorkloadIdentityCredential] directly when not using the webhook or needing more control over its configuration.
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return nil, err
}
token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: []string{dbConfig.TokenRequestScope}})
if err != nil {
return nil, err
}
return &token, nil
}

func isExpired(accessToken *azcore.AccessToken) bool {
return accessToken == nil ||
time.Until(accessToken.ExpiresOn).Seconds() < constants.MinTokenLifeThreshold
Expand All @@ -144,20 +151,33 @@ func waitForNotification(ctx context.Context, l *pq.Listener, callback func(id s
case <-time.After(10 * time.Second):
logger.V(10).Infof("Received no events on channel during interval. Pinging source")
go func() {
// TODO: Need to handle the error, especially in cases of network failure.
l.Ping()
}()
}
}
}

func newListener(ctx context.Context, connstr, channel string, callback func(id string)) {
func newListener(ctx context.Context, dbConfig *config.DatabaseConfig, channel string, callback func(id string)) {
logger := ocmlogger.NewOCMLogger(ctx)

plog := func(ev pq.ListenerEventType, err error) {
if err != nil {
logger.Error(err.Error())
}
}
connstr := dbConfig.ConnectionString(true)
// append the password to the connection string
if dbConfig.AuthMethod == constants.AuthMethodPassword {
connstr += fmt.Sprintf(" password='%s'", dbConfig.Password)
} else if dbConfig.AuthMethod == constants.AuthMethodMicrosoftEntra {
token, err := getAccessToken(ctx, dbConfig)
if err != nil {
panic(err)
}
connstr += fmt.Sprintf(" password='%s'", token.Token)
}

listener := pq.NewListener(connstr, 10*time.Second, time.Minute, plog)
err := listener.Listen(channel)
if err != nil {
Expand All @@ -169,7 +189,7 @@ func newListener(ctx context.Context, connstr, channel string, callback func(id
}

func (f *Default) NewListener(ctx context.Context, channel string, callback func(id string)) {
newListener(ctx, f.config.ConnectionString(true), channel, callback)
newListener(ctx, f.config, channel, callback)
}

func (f *Default) New(ctx context.Context) *gorm.DB {
Expand Down
2 changes: 1 addition & 1 deletion pkg/db/db_session/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,5 +217,5 @@ func (f *Test) ResetDB() {
}

func (f *Test) NewListener(ctx context.Context, channel string, callback func(id string)) {
newListener(ctx, f.config.ConnectionString(true), channel, callback)
newListener(ctx, f.config, channel, callback)
}

0 comments on commit 4824514

Please sign in to comment.