Skip to content

Commit

Permalink
Unit test ValidateCAPIAuthTokenAccessHandler (#885)
Browse files Browse the repository at this point in the history
  • Loading branch information
eaudetcobello authored Dec 11, 2024
1 parent 378d6cd commit e0561eb
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 38 deletions.
90 changes: 90 additions & 0 deletions src/k8s/pkg/k8sd/api/capi_access_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package api_test

import (
"context"
"database/sql"
"net/http"
"testing"

"github.com/canonical/k8s/pkg/k8sd/api"
"github.com/canonical/k8s/pkg/k8sd/database"
testenv "github.com/canonical/k8s/pkg/utils/microcluster"
"github.com/canonical/microcluster/v2/state"
. "github.com/onsi/gomega"
)

func TestValidateCAPIAuthTokenAccessHandler(t *testing.T) {
g := NewWithT(t)

for _, tc := range []struct {
name string
tokenHeaderContent string
tokenDBContent string
expectErr bool
}{
{
name: "valid token",
tokenHeaderContent: "test-token",
tokenDBContent: "test-token",
expectErr: false,
},
{
name: "wrong token in header",
tokenHeaderContent: "invalid-token",
tokenDBContent: "expected-token",
expectErr: true,
},
{
name: "wrong token in db",
tokenHeaderContent: "expected-token",
tokenDBContent: "invalid-token",
expectErr: true,
},
{
name: "empty token in header",
tokenHeaderContent: "",
tokenDBContent: "test-token",
expectErr: true,
},
{
name: "empty token in db",
tokenHeaderContent: "test-token",
tokenDBContent: "",
expectErr: true,
},
{
name: "empty token in header and db",
tokenHeaderContent: "",
tokenDBContent: "",
expectErr: true,
},
} {
t.Run(tc.name, func(t *testing.T) {
testenv.WithState(t, func(ctx context.Context, s state.State) {
var err error
if tc.tokenDBContent != "" {
err = s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
return database.SetClusterAPIToken(ctx, tx, tc.tokenDBContent)
})
g.Expect(err).To(Not(HaveOccurred()))
}

req := &http.Request{
Header: make(http.Header),
}
req.Header.Set("Capi-Auth-Token", tc.tokenHeaderContent)

handler := api.ValidateCAPIAuthTokenAccessHandler("Capi-Auth-Token")
valid, resp := handler(s, req)

if tc.expectErr {
g.Expect(valid).To(BeFalse())
g.Expect(resp).To(Not(BeNil()))
} else {
g.Expect(valid).To(BeTrue())
g.Expect(resp).To(BeNil())
}
})
})
}
}
10 changes: 6 additions & 4 deletions src/k8s/pkg/k8sd/database/capi_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ import (
"testing"

"github.com/canonical/k8s/pkg/k8sd/database"
testenv "github.com/canonical/k8s/pkg/utils/microcluster"
"github.com/canonical/microcluster/v2/state"
. "github.com/onsi/gomega"
)

func TestClusterAPIAuthTokens(t *testing.T) {
WithDB(t, func(ctx context.Context, db DB) {
testenv.WithState(t, func(ctx context.Context, s state.State) {
var token string = "test-token"

t.Run("SetAuthToken", func(t *testing.T) {
g := NewWithT(t)
err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := database.SetClusterAPIToken(ctx, tx, token)
g.Expect(err).To(Not(HaveOccurred()))
return nil
Expand All @@ -26,7 +28,7 @@ func TestClusterAPIAuthTokens(t *testing.T) {
t.Run("CheckAuthToken", func(t *testing.T) {
t.Run("ValidToken", func(t *testing.T) {
g := NewWithT(t)
err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
valid, err := database.ValidateClusterAPIToken(ctx, tx, token)
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(valid).To(BeTrue())
Expand All @@ -37,7 +39,7 @@ func TestClusterAPIAuthTokens(t *testing.T) {

t.Run("InvalidToken", func(t *testing.T) {
g := NewWithT(t)
err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
valid, err := database.ValidateClusterAPIToken(ctx, tx, "invalid-token")
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(valid).To(BeFalse())
Expand Down
16 changes: 9 additions & 7 deletions src/k8s/pkg/k8sd/database/cluster_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ import (
"github.com/canonical/k8s/pkg/k8sd/database"
"github.com/canonical/k8s/pkg/k8sd/types"
"github.com/canonical/k8s/pkg/utils"
testenv "github.com/canonical/k8s/pkg/utils/microcluster"
"github.com/canonical/microcluster/v2/state"
. "github.com/onsi/gomega"
)

func TestClusterConfig(t *testing.T) {
WithDB(t, func(ctx context.Context, d DB) {
testenv.WithState(t, func(ctx context.Context, s state.State) {
t.Run("Set", func(t *testing.T) {
g := NewWithT(t)
expectedClusterConfig := types.ClusterConfig{
Expand All @@ -24,15 +26,15 @@ func TestClusterConfig(t *testing.T) {
expectedClusterConfig.SetDefaults()

// Write some config to the database
err := d.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
_, err := database.SetClusterConfig(context.Background(), tx, expectedClusterConfig)
g.Expect(err).To(Not(HaveOccurred()))
return nil
})
g.Expect(err).To(Not(HaveOccurred()))

// Retrieve it and map it to the struct
err = d.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err = s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
clusterConfig, err := database.GetClusterConfig(ctx, tx)
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(clusterConfig).To(Equal(expectedClusterConfig))
Expand All @@ -52,7 +54,7 @@ func TestClusterConfig(t *testing.T) {
}
expectedClusterConfig.SetDefaults()

err := d.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
_, err := database.SetClusterConfig(context.Background(), tx, types.ClusterConfig{
Certificates: types.Certificates{
CACert: utils.Pointer("CA CERT NEW DATA"),
Expand All @@ -63,7 +65,7 @@ func TestClusterConfig(t *testing.T) {
})
g.Expect(err).To(HaveOccurred())

err = d.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err = s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
clusterConfig, err := database.GetClusterConfig(ctx, tx)
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(clusterConfig).To(Equal(expectedClusterConfig))
Expand All @@ -90,7 +92,7 @@ func TestClusterConfig(t *testing.T) {
}
expectedClusterConfig.SetDefaults()

err := d.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
returnedConfig, err := database.SetClusterConfig(context.Background(), tx, types.ClusterConfig{
Kubelet: types.Kubelet{
ClusterDNS: utils.Pointer("10.152.183.10"),
Expand All @@ -109,7 +111,7 @@ func TestClusterConfig(t *testing.T) {
})
g.Expect(err).To(Not(HaveOccurred()))

err = d.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err = s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
clusterConfig, err := database.GetClusterConfig(ctx, tx)
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(clusterConfig).To(Equal(expectedClusterConfig))
Expand Down
6 changes: 4 additions & 2 deletions src/k8s/pkg/k8sd/database/feature_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ import (
"github.com/canonical/k8s/pkg/k8sd/database"
"github.com/canonical/k8s/pkg/k8sd/features"
"github.com/canonical/k8s/pkg/k8sd/types"
testenv "github.com/canonical/k8s/pkg/utils/microcluster"
"github.com/canonical/microcluster/v2/state"
. "github.com/onsi/gomega"
)

func TestFeatureStatus(t *testing.T) {
WithDB(t, func(ctx context.Context, db DB) {
_ = db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
testenv.WithState(t, func(ctx context.Context, s state.State) {
_ = s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
t0, _ := time.Parse(time.RFC3339, time.Now().Format(time.RFC3339))
networkStatus := types.FeatureStatus{
Enabled: true,
Expand Down
14 changes: 8 additions & 6 deletions src/k8s/pkg/k8sd/database/kubernetes_auth_tokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ import (
"testing"

"github.com/canonical/k8s/pkg/k8sd/database"
testenv "github.com/canonical/k8s/pkg/utils/microcluster"
"github.com/canonical/microcluster/v2/state"
. "github.com/onsi/gomega"
)

func TestKubernetesAuthTokens(t *testing.T) {
WithDB(t, func(ctx context.Context, db DB) {
testenv.WithState(t, func(ctx context.Context, s state.State) {
var token1, token2 string

t.Run("GetOrCreateToken", func(t *testing.T) {
g := NewWithT(t)
err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
var err error

token1, err = database.GetOrCreateToken(ctx, tx, "user1", []string{"group1", "group2"})
Expand All @@ -33,7 +35,7 @@ func TestKubernetesAuthTokens(t *testing.T) {

t.Run("Existing", func(t *testing.T) {
g := NewWithT(t)
err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
token, err := database.GetOrCreateToken(ctx, tx, "user1", []string{"group1", "group2"})
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(token).To(Equal(token1))
Expand All @@ -46,7 +48,7 @@ func TestKubernetesAuthTokens(t *testing.T) {
t.Run("CheckToken", func(t *testing.T) {
t.Run("user1", func(t *testing.T) {
g := NewWithT(t)
err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
username, groups, err := database.CheckToken(ctx, tx, token1)
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(username).To(Equal("user1"))
Expand All @@ -57,7 +59,7 @@ func TestKubernetesAuthTokens(t *testing.T) {
})
t.Run("user2", func(t *testing.T) {
g := NewWithT(t)
err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
username, groups, err := database.CheckToken(ctx, tx, token2)
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(username).To(Equal("user2"))
Expand All @@ -70,7 +72,7 @@ func TestKubernetesAuthTokens(t *testing.T) {

t.Run("DeleteToken", func(t *testing.T) {
g := NewWithT(t)
err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
err := database.DeleteToken(ctx, tx, token2)
g.Expect(err).To(Not(HaveOccurred()))

Expand Down
6 changes: 4 additions & 2 deletions src/k8s/pkg/k8sd/database/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ import (
"time"

"github.com/canonical/k8s/pkg/k8sd/database"
testenv "github.com/canonical/k8s/pkg/utils/microcluster"
"github.com/canonical/microcluster/v2/state"
. "github.com/onsi/gomega"
)

func TestWorkerNodeToken(t *testing.T) {
WithDB(t, func(ctx context.Context, db DB) {
_ = db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
testenv.WithState(t, func(ctx context.Context, s state.State) {
_ = s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
tokenExpiry := time.Now().Add(time.Hour)
t.Run("Default", func(t *testing.T) {
g := NewWithT(t)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package database_test
package testenv

import (
"context"
"database/sql"
"fmt"
"testing"
"time"
Expand All @@ -21,21 +20,17 @@ const (
// nextIdx is used to pick different listen ports for each microcluster instance.
var nextIdx int

// DB is an interface for the internal microcluster DB type.
type DB interface {
Transaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error
}

// WithDB can be used to run isolated tests against the microcluster database.
// WithState can be used to run isolated tests against the microcluster database.
// The Database() can be accessed by calling s.Database().
//
// Example usage:
//
// func TestKubernetesAuthTokens(t *testing.T) {
// t.Run("ValidToken", func(t *testing.T) {
// g := NewWithT(t)
// WithDB(t, func(ctx context.Context, db DB) {
// WithState(t, func(ctx context.Context, s state.State) {
// err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
// token, err := database.GetOrCreateToken(ctx, tx, "user1", []string{"group1", "group2"})
// token, err := s.Database().GetOrCreateToken(ctx, tx, "user1", []string{"group1", "group2"})
// if !g.Expect(err).To(Not(HaveOccurred())) {
// return err
// }
Expand All @@ -46,7 +41,7 @@ type DB interface {
// })
// })
// }
func WithDB(t *testing.T, f func(context.Context, DB)) {
func WithState(t *testing.T, f func(context.Context, state.State)) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -57,16 +52,16 @@ func WithDB(t *testing.T, f func(context.Context, DB)) {
t.Fatalf("failed to create microcluster app: %v", err)
}

databaseCh := make(chan DB, 1)
stateChan := make(chan state.State, 1)
doneCh := make(chan error, 1)
defer close(databaseCh)
defer close(stateChan)
defer close(doneCh)

// app.Run() is blocking, so we get the database handle through a channel
// app.Run() is blocking, so we get the state handle through a channel
go func() {
doneCh <- app.Run(ctx, &state.Hooks{
PostBootstrap: func(ctx context.Context, s state.State, initConfig map[string]string) error {
databaseCh <- s.Database()
stateChan <- s
return nil
},
OnStart: func(ctx context.Context, s state.State) error {
Expand Down Expand Up @@ -95,8 +90,8 @@ func WithDB(t *testing.T, f func(context.Context, DB)) {
select {
case <-time.After(microclusterDatabaseInitTimeout):
t.Fatalf("timed out waiting for microcluster to start")
case db := <-databaseCh:
f(ctx, db)
case state := <-stateChan:
f(ctx, state)
}

// cancel context to stop the microcluster instance, and wait for it to shutdown
Expand Down

0 comments on commit e0561eb

Please sign in to comment.