From 2eb263a59583e68991f165fb04d3ea248e34f577 Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Fri, 13 Dec 2024 21:39:15 +0000 Subject: [PATCH 1/9] Move request handlers out of server package The servers package, and router.go in particular, had become quite large. Address this by moving some things out to separate packages: * http request handlers all move to pkg/server/handlers. * node password bootstrap auth handler goes into pkg/nodepassword with the other nodepassword code. While we're at it, also be more consistent about calling variables that hold a config.Control struct or reference `control` instead of `config` or `server`. Signed-off-by: Brad Davidson --- pkg/cli/secretsencrypt/secrets_encrypt.go | 15 +- pkg/cli/token/token.go | 3 +- pkg/daemons/control/deps/deps.go | 6 +- pkg/nodepassword/validate.go | 226 +++++++ pkg/server/{ => handlers}/cert.go | 40 +- pkg/server/handlers/handlers.go | 400 +++++++++++++ pkg/server/handlers/router.go | 66 ++ pkg/server/{ => handlers}/secrets-encrypt.go | 130 ++-- pkg/server/{ => handlers}/token.go | 38 +- pkg/server/router.go | 595 ------------------- pkg/server/server.go | 19 +- 11 files changed, 820 insertions(+), 718 deletions(-) create mode 100644 pkg/nodepassword/validate.go rename pkg/server/{ => handlers}/cert.go (84%) create mode 100644 pkg/server/handlers/handlers.go create mode 100644 pkg/server/handlers/router.go rename pkg/server/{ => handlers}/secrets-encrypt.go (71%) rename pkg/server/{ => handlers}/token.go (64%) delete mode 100644 pkg/server/router.go diff --git a/pkg/cli/secretsencrypt/secrets_encrypt.go b/pkg/cli/secretsencrypt/secrets_encrypt.go index b0c6256e2877..b8a08535c730 100644 --- a/pkg/cli/secretsencrypt/secrets_encrypt.go +++ b/pkg/cli/secretsencrypt/secrets_encrypt.go @@ -15,6 +15,7 @@ import ( "github.com/k3s-io/k3s/pkg/proctitle" "github.com/k3s-io/k3s/pkg/secretsencrypt" "github.com/k3s-io/k3s/pkg/server" + "github.com/k3s-io/k3s/pkg/server/handlers" "github.com/k3s-io/k3s/pkg/version" "github.com/pkg/errors" "github.com/urfave/cli" @@ -54,7 +55,7 @@ func Enable(app *cli.Context) error { if err != nil { return err } - b, err := json.Marshal(server.EncryptionRequest{Enable: ptr.To(true)}) + b, err := json.Marshal(handlers.EncryptionRequest{Enable: ptr.To(true)}) if err != nil { return err } @@ -73,7 +74,7 @@ func Disable(app *cli.Context) error { if err != nil { return err } - b, err := json.Marshal(server.EncryptionRequest{Enable: ptr.To(false)}) + b, err := json.Marshal(handlers.EncryptionRequest{Enable: ptr.To(false)}) if err != nil { return err } @@ -96,7 +97,7 @@ func Status(app *cli.Context) error { if err != nil { return wrapServerError(err) } - status := server.EncryptionState{} + status := handlers.EncryptionState{} if err := json.Unmarshal(data, &status); err != nil { return err } @@ -153,7 +154,7 @@ func Prepare(app *cli.Context) error { if err != nil { return err } - b, err := json.Marshal(server.EncryptionRequest{ + b, err := json.Marshal(handlers.EncryptionRequest{ Stage: ptr.To(secretsencrypt.EncryptionPrepare), Force: cmds.ServerConfig.EncryptForce, }) @@ -175,7 +176,7 @@ func Rotate(app *cli.Context) error { if err != nil { return err } - b, err := json.Marshal(server.EncryptionRequest{ + b, err := json.Marshal(handlers.EncryptionRequest{ Stage: ptr.To(secretsencrypt.EncryptionRotate), Force: cmds.ServerConfig.EncryptForce, }) @@ -197,7 +198,7 @@ func Reencrypt(app *cli.Context) error { if err != nil { return err } - b, err := json.Marshal(server.EncryptionRequest{ + b, err := json.Marshal(handlers.EncryptionRequest{ Stage: ptr.To(secretsencrypt.EncryptionReencryptActive), Force: cmds.ServerConfig.EncryptForce, Skip: cmds.ServerConfig.EncryptSkip, @@ -220,7 +221,7 @@ func RotateKeys(app *cli.Context) error { if err != nil { return err } - b, err := json.Marshal(server.EncryptionRequest{ + b, err := json.Marshal(handlers.EncryptionRequest{ Stage: ptr.To(secretsencrypt.EncryptionRotateKeys), }) if err != nil { diff --git a/pkg/cli/token/token.go b/pkg/cli/token/token.go index 9d514d7b5286..64d6026cc7bf 100644 --- a/pkg/cli/token/token.go +++ b/pkg/cli/token/token.go @@ -16,6 +16,7 @@ import ( "github.com/k3s-io/k3s/pkg/kubeadm" "github.com/k3s-io/k3s/pkg/proctitle" "github.com/k3s-io/k3s/pkg/server" + "github.com/k3s-io/k3s/pkg/server/handlers" "github.com/k3s-io/k3s/pkg/util" "github.com/k3s-io/k3s/pkg/version" "github.com/pkg/errors" @@ -153,7 +154,7 @@ func Rotate(app *cli.Context) error { if err != nil { return err } - b, err := json.Marshal(server.TokenRotateRequest{ + b, err := json.Marshal(handlers.TokenRotateRequest{ NewToken: ptr.To(cmds.TokenConfig.NewToken), }) if err != nil { diff --git a/pkg/daemons/control/deps/deps.go b/pkg/daemons/control/deps/deps.go index ecd25347c60e..e3596ecc52ce 100644 --- a/pkg/daemons/control/deps/deps.go +++ b/pkg/daemons/control/deps/deps.go @@ -391,11 +391,11 @@ func genClientCerts(config *config.Control) error { } } - if _, err = factory(user.KubeProxy, nil, runtime.ClientKubeProxyCert, runtime.ClientKubeProxyKey); err != nil { + if _, _, err := certutil.LoadOrGenerateKeyFile(runtime.ClientKubeProxyKey, regen); err != nil { return err } - // This user (system:k3s-controller by default) must be bound to a role in rolebindings.yaml or the downstream equivalent - if _, err = factory("system:"+version.Program+"-controller", nil, runtime.ClientK3sControllerCert, runtime.ClientK3sControllerKey); err != nil { + + if _, _, err := certutil.LoadOrGenerateKeyFile(runtime.ClientK3sControllerKey, regen); err != nil { return err } diff --git a/pkg/nodepassword/validate.go b/pkg/nodepassword/validate.go new file mode 100644 index 000000000000..3832a782cd44 --- /dev/null +++ b/pkg/nodepassword/validate.go @@ -0,0 +1,226 @@ +package nodepassword + +import ( + "context" + "net/http" + "os" + "path" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/gorilla/mux" + "github.com/k3s-io/k3s/pkg/daemons/config" + "github.com/k3s-io/k3s/pkg/util" + "github.com/pkg/errors" + coreclient "github.com/rancher/wrangler/v3/pkg/generated/controllers/core/v1" + "github.com/sirupsen/logrus" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/apiserver/pkg/authentication/user" + "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/kubernetes/pkg/auth/nodeidentifier" +) + +var identifier = nodeidentifier.NewDefaultNodeIdentifier() + +// NodeAuthValidator returns a node name, or http error code and error +type NodeAuthValidator func(req *http.Request) (string, int, error) + +// nodeInfo contains information on the requesting node, derived from auth creds +// and request headers. +type nodeInfo struct { + Name string + Password string + User user.Info +} + +// GetNodeAuthValidator returns a function that will be called to validate node password authentication. +// Node password authentication is used when requesting kubelet certificates, and verifies that the +// credentials are valid for the requested node name, and that the node password is valid if it exists. +// These checks prevent a user with access to one agent from requesting kubelet certificates that +// could be used to impersonate another cluster member. +func GetNodeAuthValidator(ctx context.Context, control *config.Control) NodeAuthValidator { + runtime := control.Runtime + deferredNodes := map[string]bool{} + var secretClient coreclient.SecretController + var nodeClient coreclient.NodeController + var mu sync.Mutex + + return func(req *http.Request) (string, int, error) { + node, err := getNodeInfo(req) + if err != nil { + return "", http.StatusBadRequest, err + } + + // node identity auth uses an existing kubelet client cert instead of auth token. + // If used, validate that the node identity matches the requested node name. + nodeName, isNodeAuth := identifier.NodeIdentity(node.User) + if isNodeAuth && nodeName != node.Name { + return "", http.StatusBadRequest, errors.New("header node name does not match auth node name") + } + + if secretClient == nil || nodeClient == nil { + if runtime.Core != nil { + // initialize the client if we can + secretClient = runtime.Core.Core().V1().Secret() + nodeClient = runtime.Core.Core().V1().Node() + } else if node.Name == os.Getenv("NODE_NAME") { + // If we're verifying our own password, verify it locally and ensure a secret later. + return verifyLocalPassword(ctx, control, &mu, deferredNodes, node) + } else if control.DisableAPIServer && !isNodeAuth { + // If we're running on an etcd-only node, and the request didn't use Node Identity auth, + // defer node password verification until an apiserver joins the cluster. + return verifyRemotePassword(ctx, control, &mu, deferredNodes, node) + } else { + // Otherwise, reject the request until the core is ready. + return "", http.StatusServiceUnavailable, util.ErrCoreNotReady + } + } + + // verify that the node exists, if using Node Identity auth + if err := verifyNode(ctx, nodeClient, node); err != nil { + return "", http.StatusUnauthorized, err + } + + // verify that the node password secret matches, or create it if it does not + if err := Ensure(secretClient, node.Name, node.Password); err != nil { + // if the verification failed, reject the request + if errors.Is(err, ErrVerifyFailed) { + return "", http.StatusForbidden, err + } + // If verification failed due to an error creating the node password secret, allow + // the request, but retry verification until the outage is resolved. This behavior + // allows nodes to join the cluster during outages caused by validating webhooks + // blocking secret creation - if the outage requires new nodes to join in order to + // run the webhook pods, we must fail open here to resolve the outage. + return verifyRemotePassword(ctx, control, &mu, deferredNodes, node) + } + + return node.Name, http.StatusOK, nil + } +} + +// getNodeInfo returns node name, password, and user extracted +// from request headers and context. An error is returned +// if any critical fields are missing. +func getNodeInfo(req *http.Request) (*nodeInfo, error) { + user, ok := request.UserFrom(req.Context()) + if !ok { + return nil, errors.New("auth user not set") + } + + program := mux.Vars(req)["program"] + nodeName := req.Header.Get(program + "-Node-Name") + if nodeName == "" { + return nil, errors.New("node name not set") + } + + nodePassword := req.Header.Get(program + "-Node-Password") + if nodePassword == "" { + return nil, errors.New("node password not set") + } + + return &nodeInfo{ + Name: strings.ToLower(nodeName), + Password: nodePassword, + User: user, + }, nil +} + +// verifyLocalPassword is used to validate the local node's password secret directly against the node password file, when the apiserver is unavailable. +// This is only used early in startup, when a control-plane node's agent is starting up without a functional apiserver. +func verifyLocalPassword(ctx context.Context, control *config.Control, mu *sync.Mutex, deferredNodes map[string]bool, node *nodeInfo) (string, int, error) { + // do not attempt to verify the node password if the local host is not running an agent and does not have a node resource. + if control.DisableAgent { + return node.Name, http.StatusOK, nil + } + + // use same password file location that the agent creates + nodePasswordRoot := "/" + if control.Rootless { + nodePasswordRoot = filepath.Join(path.Dir(control.DataDir), "agent") + } + nodeConfigPath := filepath.Join(nodePasswordRoot, "etc", "rancher", "node") + nodePasswordFile := filepath.Join(nodeConfigPath, "password") + + passBytes, err := os.ReadFile(nodePasswordFile) + if err != nil { + return "", http.StatusInternalServerError, errors.Wrap(err, "unable to read node password file") + } + + passHash, err := Hasher.CreateHash(strings.TrimSpace(string(passBytes))) + if err != nil { + return "", http.StatusInternalServerError, errors.Wrap(err, "unable to hash node password file") + } + + if err := Hasher.VerifyHash(passHash, node.Password); err != nil { + return "", http.StatusForbidden, errors.Wrap(err, "unable to verify local node password") + } + + mu.Lock() + defer mu.Unlock() + + if _, ok := deferredNodes[node.Name]; !ok { + deferredNodes[node.Name] = true + go ensureSecret(ctx, control, node) + logrus.Infof("Password verified locally for node %s", node.Name) + } + + return node.Name, http.StatusOK, nil +} + +// verifyRemotePassword is used when the server does not have a local apisever, as in the case of etcd-only nodes. +// The node password is ensured once an apiserver joins the cluster. +func verifyRemotePassword(ctx context.Context, control *config.Control, mu *sync.Mutex, deferredNodes map[string]bool, node *nodeInfo) (string, int, error) { + mu.Lock() + defer mu.Unlock() + + if _, ok := deferredNodes[node.Name]; !ok { + deferredNodes[node.Name] = true + go ensureSecret(ctx, control, node) + logrus.Infof("Password verification deferred for node %s", node.Name) + } + + return node.Name, http.StatusOK, nil +} + +// verifyNode confirms that a node with the given name exists, to prevent auth +// from succeeding with a client certificate for a node that has been deleted from the cluster. +func verifyNode(ctx context.Context, nodeClient coreclient.NodeController, node *nodeInfo) error { + if nodeName, isNodeAuth := identifier.NodeIdentity(node.User); isNodeAuth { + if _, err := nodeClient.Cache().Get(nodeName); err != nil { + return errors.Wrap(err, "unable to verify node identity") + } + } + return nil +} + +// ensureSecret validates a server's node password secret once the apiserver is up. +// As the node has already joined the cluster at this point, this is purely informational. +func ensureSecret(ctx context.Context, control *config.Control, node *nodeInfo) { + runtime := control.Runtime + _ = wait.PollUntilContextCancel(ctx, time.Second*5, true, func(ctx context.Context) (bool, error) { + if runtime.Core != nil { + secretClient := runtime.Core.Core().V1().Secret() + // This is consistent with events attached to the node generated by the kubelet + // https://github.com/kubernetes/kubernetes/blob/612130dd2f4188db839ea5c2dea07a96b0ad8d1c/pkg/kubelet/kubelet.go#L479-L485 + nodeRef := &corev1.ObjectReference{ + Kind: "Node", + Name: node.Name, + UID: types.UID(node.Name), + Namespace: "", + } + if err := Ensure(secretClient, node.Name, node.Password); err != nil { + runtime.Event.Eventf(nodeRef, corev1.EventTypeWarning, "NodePasswordValidationFailed", "Deferred node password secret validation failed: %v", err) + // Return true to stop polling if the password verification failed; only retry on secret creation errors. + return errors.Is(err, ErrVerifyFailed), nil + } + runtime.Event.Event(nodeRef, corev1.EventTypeNormal, "NodePasswordValidationComplete", "Deferred node password secret validation complete") + return true, nil + } + return false, nil + }) +} diff --git a/pkg/server/cert.go b/pkg/server/handlers/cert.go similarity index 84% rename from pkg/server/cert.go rename to pkg/server/handlers/cert.go index ea03a305dfd1..f6170f8bbebf 100644 --- a/pkg/server/cert.go +++ b/pkg/server/handlers/cert.go @@ -1,4 +1,4 @@ -package server +package handlers import ( "bytes" @@ -28,14 +28,14 @@ import ( "k8s.io/client-go/util/keyutil" ) -func caCertReplaceHandler(server *config.Control) http.HandlerFunc { +func CACertReplace(control *config.Control) http.HandlerFunc { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPut { util.SendError(fmt.Errorf("method not allowed"), resp, req, http.StatusMethodNotAllowed) return } force, _ := strconv.ParseBool(req.FormValue("force")) - if err := caCertReplace(server, req.Body, force); err != nil { + if err := caCertReplace(control, req.Body, force); err != nil { util.SendErrorWithID(err, "certificate", resp, req, http.StatusInternalServerError) return } @@ -48,54 +48,54 @@ func caCertReplaceHandler(server *config.Control) http.HandlerFunc { // validated to confirm that the new certs share a common root with the existing certs, and if so are saved to // the datastore. If the functions succeeds, servers should be restarted immediately to load the new certs // from the bootstrap data. -func caCertReplace(server *config.Control, buf io.ReadCloser, force bool) error { - tmpdir, err := os.MkdirTemp(server.DataDir, ".rotate-ca-tmp-") +func caCertReplace(control *config.Control, buf io.ReadCloser, force bool) error { + tmpdir, err := os.MkdirTemp(control.DataDir, ".rotate-ca-tmp-") if err != nil { return err } defer os.RemoveAll(tmpdir) runtime := config.NewRuntime(nil) - runtime.EtcdConfig = server.Runtime.EtcdConfig - runtime.ServerToken = server.Runtime.ServerToken + runtime.EtcdConfig = control.Runtime.EtcdConfig + runtime.ServerToken = control.Runtime.ServerToken - tmpServer := &config.Control{ + tmpControl := &config.Control{ Runtime: runtime, - Token: server.Token, + Token: control.Token, DataDir: tmpdir, } - deps.CreateRuntimeCertFiles(tmpServer) + deps.CreateRuntimeCertFiles(tmpControl) bootstrapData := bootstrap.PathsDataformat{} if err := json.NewDecoder(buf).Decode(&bootstrapData); err != nil { return err } - if err := bootstrap.WriteToDiskFromStorage(bootstrapData, &tmpServer.Runtime.ControlRuntimeBootstrap); err != nil { + if err := bootstrap.WriteToDiskFromStorage(bootstrapData, &tmpControl.Runtime.ControlRuntimeBootstrap); err != nil { return err } - if err := defaultBootstrap(server, tmpServer); err != nil { + if err := defaultBootstrap(control, tmpControl); err != nil { return errors.Wrap(err, "failed to set default bootstrap values") } - if err := validateBootstrap(server, tmpServer); err != nil { + if err := validateBootstrap(control, tmpControl); err != nil { if !force { return errors.Wrap(err, "failed to validate new CA certificates and keys") } logrus.Warnf("Save of CA certificates and keys forced, ignoring validation errors: %v", err) } - return cluster.Save(context.TODO(), tmpServer, true) + return cluster.Save(context.TODO(), tmpControl, true) } // defaultBootstrap provides default values from the existing bootstrap fields // if the value is not tagged for rotation, or the current value is empty. -func defaultBootstrap(oldServer, newServer *config.Control) error { +func defaultBootstrap(oldControl, newControl *config.Control) error { errs := []error{} // Use reflection to iterate over all of the bootstrap fields, checking files at each of the new paths. - oldMeta := reflect.ValueOf(&oldServer.Runtime.ControlRuntimeBootstrap).Elem() - newMeta := reflect.ValueOf(&newServer.Runtime.ControlRuntimeBootstrap).Elem() + oldMeta := reflect.ValueOf(&oldControl.Runtime.ControlRuntimeBootstrap).Elem() + newMeta := reflect.ValueOf(&newControl.Runtime.ControlRuntimeBootstrap).Elem() // use the existing file if the new file does not exist or is empty for _, field := range reflect.VisibleFields(oldMeta.Type()) { @@ -122,12 +122,12 @@ func defaultBootstrap(oldServer, newServer *config.Control) error { // validateBootstrap checks the new certs and keys to ensure that the cluster would function properly were they to be used. // - The new leaf CA certificates must be verifiable using the same root and intermediate certs as the current leaf CA certificates. // - The new service account signing key bundle must include the currently active signing key. -func validateBootstrap(oldServer, newServer *config.Control) error { +func validateBootstrap(oldControl, newControl *config.Control) error { errs := []error{} // Use reflection to iterate over all of the bootstrap fields, checking files at each of the new paths. - oldMeta := reflect.ValueOf(&oldServer.Runtime.ControlRuntimeBootstrap).Elem() - newMeta := reflect.ValueOf(&newServer.Runtime.ControlRuntimeBootstrap).Elem() + oldMeta := reflect.ValueOf(&oldControl.Runtime.ControlRuntimeBootstrap).Elem() + newMeta := reflect.ValueOf(&newControl.Runtime.ControlRuntimeBootstrap).Elem() for _, field := range reflect.VisibleFields(oldMeta.Type()) { // Only handle bootstrap fields tagged for rotation diff --git a/pkg/server/handlers/handlers.go b/pkg/server/handlers/handlers.go new file mode 100644 index 000000000000..f060c4c17b40 --- /dev/null +++ b/pkg/server/handlers/handlers.go @@ -0,0 +1,400 @@ +package handlers + +import ( + "context" + "crypto" + "crypto/x509" + "fmt" + "net" + "net/http" + "os" + "strconv" + "strings" + "time" + + "github.com/gorilla/mux" + "github.com/k3s-io/k3s/pkg/bootstrap" + "github.com/k3s-io/k3s/pkg/cli/cmds" + "github.com/k3s-io/k3s/pkg/daemons/config" + "github.com/k3s-io/k3s/pkg/etcd" + "github.com/k3s-io/k3s/pkg/nodepassword" + "github.com/k3s-io/k3s/pkg/util" + "github.com/pkg/errors" + certutil "github.com/rancher/dynamiclistener/cert" + "github.com/sirupsen/logrus" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/json" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apiserver/pkg/authentication/user" + typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" +) + +func CACerts(config *config.Control) http.Handler { + var ca []byte + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + if ca == nil { + var err error + ca, err = os.ReadFile(config.Runtime.ServerCA) + if err != nil { + util.SendError(err, resp, req) + return + } + } + resp.Header().Set("content-type", "text/plain") + resp.Write(ca) + }) +} + +func Ping() http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + data := []byte("pong") + resp.WriteHeader(http.StatusOK) + resp.Header().Set("Content-Type", "text/plain") + resp.Header().Set("Content-Length", strconv.Itoa(len(data))) + resp.Write(data) + }) +} + +func ServingKubeletCert(control *config.Control, auth nodepassword.NodeAuthValidator) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + nodeName, errCode, err := auth(req) + if err != nil { + util.SendError(err, resp, req, errCode) + return + } + + keyFile := control.Runtime.ServingKubeletKey + caCerts, caKey, key, err := getCACertAndKeys(control.Runtime.ServerCA, control.Runtime.ServerCAKey, keyFile) + if err != nil { + util.SendError(err, resp, req) + return + } + + ips := []net.IP{net.ParseIP("127.0.0.1")} + program := mux.Vars(req)["program"] + if nodeIP := req.Header.Get(program + "-Node-IP"); nodeIP != "" { + for _, v := range strings.Split(nodeIP, ",") { + ip := net.ParseIP(v) + if ip == nil { + util.SendError(fmt.Errorf("invalid node IP address %s", ip), resp, req) + return + } + ips = append(ips, ip) + } + } + + cert, err := certutil.NewSignedCert(certutil.Config{ + CommonName: nodeName, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + AltNames: certutil.AltNames{ + DNSNames: []string{nodeName, "localhost"}, + IPs: ips, + }, + }, key, caCerts[0], caKey) + if err != nil { + util.SendError(err, resp, req) + return + } + + keyBytes, err := os.ReadFile(keyFile) + if err != nil { + http.Error(resp, err.Error(), http.StatusInternalServerError) + return + } + + resp.Write(util.EncodeCertsPEM(cert, caCerts)) + resp.Write(keyBytes) + }) +} + +func ClientKubeletCert(control *config.Control, auth nodepassword.NodeAuthValidator) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + nodeName, errCode, err := auth(req) + if err != nil { + util.SendError(err, resp, req, errCode) + return + } + + keyFile := control.Runtime.ClientKubeletKey + caCerts, caKey, key, err := getCACertAndKeys(control.Runtime.ClientCA, control.Runtime.ClientCAKey, keyFile) + if err != nil { + util.SendError(err, resp, req) + return + } + + cert, err := certutil.NewSignedCert(certutil.Config{ + CommonName: "system:node:" + nodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }, key, caCerts[0], caKey) + if err != nil { + util.SendError(err, resp, req) + return + } + + keyBytes, err := os.ReadFile(keyFile) + if err != nil { + http.Error(resp, err.Error(), http.StatusInternalServerError) + return + } + + resp.Write(util.EncodeCertsPEM(cert, caCerts)) + resp.Write(keyBytes) + }) +} + +func ClientKubeProxyCert(control *config.Control) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + keyFile := control.Runtime.ClientKubeProxyKey + caCerts, caKey, key, err := getCACertAndKeys(control.Runtime.ClientCA, control.Runtime.ClientCAKey, keyFile) + if err != nil { + util.SendError(err, resp, req) + return + } + + cert, err := certutil.NewSignedCert(certutil.Config{ + CommonName: user.KubeProxy, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }, key, caCerts[0], caKey) + if err != nil { + util.SendError(err, resp, req) + return + } + + keyBytes, err := os.ReadFile(keyFile) + if err != nil { + http.Error(resp, err.Error(), http.StatusInternalServerError) + return + } + + resp.Write(util.EncodeCertsPEM(cert, caCerts)) + resp.Write(keyBytes) + }) +} + +func ClientControllerCert(control *config.Control) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + keyFile := control.Runtime.ClientK3sControllerKey + caCerts, caKey, key, err := getCACertAndKeys(control.Runtime.ClientCA, control.Runtime.ClientCAKey, keyFile) + if err != nil { + util.SendError(err, resp, req) + return + } + + // This user (system:k3s-controller by default) must be bound to a role in rolebindings.yaml or the downstream equivalent + program := mux.Vars(req)["program"] + cert, err := certutil.NewSignedCert(certutil.Config{ + CommonName: "system:" + program + "-controller", + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }, key, caCerts[0], caKey) + if err != nil { + util.SendError(err, resp, req) + return + } + + keyBytes, err := os.ReadFile(keyFile) + if err != nil { + http.Error(resp, err.Error(), http.StatusInternalServerError) + return + } + + resp.Write(util.EncodeCertsPEM(cert, caCerts)) + resp.Write(keyBytes) + }) +} + +func File(fileName ...string) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("Content-Type", "text/plain") + + if len(fileName) == 1 { + http.ServeFile(resp, req, fileName[0]) + return + } + + for _, f := range fileName { + bytes, err := os.ReadFile(f) + if err != nil { + util.SendError(errors.Wrapf(err, "failed to read %s", f), resp, req, http.StatusInternalServerError) + return + } + resp.Write(bytes) + } + }) +} + +func APIServer(control *config.Control, cfg *cmds.Server) http.Handler { + if cfg.DisableAPIServer { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + util.SendError(util.ErrAPIDisabled, resp, req, http.StatusServiceUnavailable) + }) + } + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + if control.Runtime != nil && control.Runtime.APIServer != nil { + control.Runtime.APIServer.ServeHTTP(resp, req) + } else { + util.SendError(util.ErrAPINotReady, resp, req, http.StatusServiceUnavailable) + } + }) +} + +func APIServers(control *config.Control) http.Handler { + collectAddresses := getAddressCollector(control) + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) + defer cancel() + endpoints := collectAddresses(ctx) + resp.Header().Set("content-type", "application/json") + if err := json.NewEncoder(resp).Encode(endpoints); err != nil { + util.SendError(errors.Wrap(err, "failed to encode apiserver endpoints"), resp, req, http.StatusInternalServerError) + } + }) +} + +func Config(control *config.Control, cfg *cmds.Server) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + // Startup hooks may read and modify cmds.Server in a goroutine, but as these are copied into + // config.Control before the startup hooks are called, any modifications need to be sync'd back + // into the struct before it is sent to agents. + // At this time we don't sync all the fields, just those known to be touched by startup hooks. + control.DisableKubeProxy = cfg.DisableKubeProxy + resp.Header().Set("content-type", "application/json") + if err := json.NewEncoder(resp).Encode(control); err != nil { + util.SendError(errors.Wrap(err, "failed to encode agent config"), resp, req, http.StatusInternalServerError) + } + }) +} + +func Readyz(control *config.Control) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + if control.Runtime.Core == nil { + util.SendError(util.ErrCoreNotReady, resp, req, http.StatusServiceUnavailable) + return + } + data := []byte("ok") + resp.WriteHeader(http.StatusOK) + resp.Header().Set("Content-Type", "text/plain") + resp.Header().Set("Content-Length", strconv.Itoa(len(data))) + resp.Write(data) + }) +} + +func Bootstrap(control *config.Control) http.Handler { + if control.Runtime.HTTPBootstrap { + return bootstrap.Handler(&control.Runtime.ControlRuntimeBootstrap) + } + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + logrus.Warnf("Received HTTP bootstrap request from %s, but embedded etcd is not enabled.", req.RemoteAddr) + util.SendError(errors.New("etcd disabled"), resp, req, http.StatusBadRequest) + }) +} + +func Static(urlPrefix, staticDir string) http.Handler { + return http.StripPrefix(urlPrefix, http.FileServer(http.Dir(staticDir))) +} + +func getCACertAndKeys(caCertFile, caKeyFile, signingKeyFile string) ([]*x509.Certificate, crypto.Signer, crypto.Signer, error) { + keyBytes, err := os.ReadFile(signingKeyFile) + if err != nil { + return nil, nil, nil, err + } + + key, err := certutil.ParsePrivateKeyPEM(keyBytes) + if err != nil { + return nil, nil, nil, err + } + + caKeyBytes, err := os.ReadFile(caKeyFile) + if err != nil { + return nil, nil, nil, err + } + + caKey, err := certutil.ParsePrivateKeyPEM(caKeyBytes) + if err != nil { + return nil, nil, nil, err + } + + caBytes, err := os.ReadFile(caCertFile) + if err != nil { + return nil, nil, nil, err + } + + caCert, err := certutil.ParseCertsPEM(caBytes) + if err != nil { + return nil, nil, nil, err + } + + return caCert, caKey.(crypto.Signer), key.(crypto.Signer), nil +} + +// addressGetter is a common signature for functions that return an address channel +type addressGetter func(ctx context.Context) <-chan []string + +// kubernetesGetter returns a function that returns a channel that can be read to get apiserver addresses from kubernetes endpoints +func kubernetesGetter(control *config.Control) addressGetter { + var endpointsClient typedcorev1.EndpointsInterface + return func(ctx context.Context) <-chan []string { + ch := make(chan []string, 1) + go func() { + if endpointsClient == nil { + if control.Runtime.K8s != nil { + endpointsClient = control.Runtime.K8s.CoreV1().Endpoints(metav1.NamespaceDefault) + } + } + if endpointsClient != nil { + if endpoint, err := endpointsClient.Get(ctx, "kubernetes", metav1.GetOptions{}); err != nil { + logrus.Debugf("Failed to get apiserver addresses from kubernetes: %v", err) + } else { + ch <- util.GetAddresses(endpoint) + } + } + close(ch) + }() + return ch + } +} + +// etcdGetter returns a function that returns a channel that can be read to get apiserver addresses from etcd +func etcdGetter(control *config.Control) addressGetter { + return func(ctx context.Context) <-chan []string { + ch := make(chan []string, 1) + go func() { + if addresses, err := etcd.GetAPIServerURLsFromETCD(ctx, control); err != nil { + logrus.Debugf("Failed to get apiserver addresses from etcd: %v", err) + } else { + ch <- addresses + } + close(ch) + }() + return ch + } +} + +// getAddressCollector returns a function that can be called to return +// apiserver addresses from both kubernetes and etcd +func getAddressCollector(control *config.Control) func(ctx context.Context) []string { + getFromKubernetes := kubernetesGetter(control) + getFromEtcd := etcdGetter(control) + + // read from both kubernetes and etcd in parallel, returning the collected results + return func(ctx context.Context) []string { + a := sets.Set[string]{} + r := []string{} + k8sCh := getFromKubernetes(ctx) + k8sOk := true + etcdCh := getFromEtcd(ctx) + etcdOk := true + + for k8sOk || etcdOk { + select { + case r, k8sOk = <-k8sCh: + a.Insert(r...) + case r, etcdOk = <-etcdCh: + a.Insert(r...) + case <-ctx.Done(): + return a.UnsortedList() + } + } + return a.UnsortedList() + } +} diff --git a/pkg/server/handlers/router.go b/pkg/server/handlers/router.go new file mode 100644 index 000000000000..eef6ffe7d5d8 --- /dev/null +++ b/pkg/server/handlers/router.go @@ -0,0 +1,66 @@ +package handlers + +import ( + "context" + "net/http" + "path/filepath" + + "github.com/gorilla/mux" + "github.com/k3s-io/k3s/pkg/cli/cmds" + "github.com/k3s-io/k3s/pkg/daemons/config" + "github.com/k3s-io/k3s/pkg/nodepassword" + "github.com/k3s-io/k3s/pkg/server/auth" + "github.com/k3s-io/k3s/pkg/version" + "k8s.io/apiserver/pkg/authentication/user" + bootstrapapi "k8s.io/cluster-bootstrap/token/api" +) + +const ( + staticURL = "/static/" +) + +func NewHandler(ctx context.Context, control *config.Control, cfg *cmds.Server) http.Handler { + nodeAuth := nodepassword.GetNodeAuthValidator(ctx, control) + + prefix := "/v1-{program}" + authed := mux.NewRouter().SkipClean(true) + authed.NotFoundHandler = APIServer(control, cfg) + authed.Use(auth.HasRole(control, version.Program+":agent", user.NodesGroup, bootstrapapi.BootstrapDefaultGroup)) + authed.Handle(prefix+"/serving-kubelet.crt", ServingKubeletCert(control, nodeAuth)) + authed.Handle(prefix+"/client-kubelet.crt", ClientKubeletCert(control, nodeAuth)) + authed.Handle(prefix+"/client-kube-proxy.crt", ClientKubeProxyCert(control)) + authed.Handle(prefix+"/client-{program}-controller.crt", ClientControllerCert(control)) + authed.Handle(prefix+"/client-ca.crt", File(control.Runtime.ClientCA)) + authed.Handle(prefix+"/server-ca.crt", File(control.Runtime.ServerCA)) + authed.Handle(prefix+"/apiservers", APIServers(control)) + authed.Handle(prefix+"/config", Config(control, cfg)) + authed.Handle(prefix+"/readyz", Readyz(control)) + + nodeAuthed := mux.NewRouter().SkipClean(true) + nodeAuthed.NotFoundHandler = authed + nodeAuthed.Use(auth.HasRole(control, user.NodesGroup)) + nodeAuthed.Handle(prefix+"/connect", control.Runtime.Tunnel) + + serverAuthed := mux.NewRouter().SkipClean(true) + serverAuthed.NotFoundHandler = nodeAuthed + serverAuthed.Use(auth.HasRole(control, version.Program+":server")) + serverAuthed.Handle(prefix+"/encrypt/status", EncryptionStatus(control)) + serverAuthed.Handle(prefix+"/encrypt/config", EncryptionConfig(ctx, control)) + serverAuthed.Handle(prefix+"/cert/cacerts", CACertReplace(control)) + serverAuthed.Handle(prefix+"/server-bootstrap", Bootstrap(control)) + serverAuthed.Handle(prefix+"/token", TokenRequest(ctx, control)) + + systemAuthed := mux.NewRouter().SkipClean(true) + systemAuthed.NotFoundHandler = serverAuthed + systemAuthed.MethodNotAllowedHandler = serverAuthed + systemAuthed.Use(auth.HasRole(control, user.SystemPrivilegedGroup)) + systemAuthed.Methods(http.MethodConnect).Handler(control.Runtime.Tunnel) + + router := mux.NewRouter().SkipClean(true) + router.NotFoundHandler = systemAuthed + router.PathPrefix(staticURL).Handler(Static(staticURL, filepath.Join(control.DataDir, "static"))) + router.Handle("/cacerts", CACerts(control)) + router.Handle("/ping", Ping()) + + return router +} diff --git a/pkg/server/secrets-encrypt.go b/pkg/server/handlers/secrets-encrypt.go similarity index 71% rename from pkg/server/secrets-encrypt.go rename to pkg/server/handlers/secrets-encrypt.go index a3759d9617c4..a6e04048420b 100644 --- a/pkg/server/secrets-encrypt.go +++ b/pkg/server/handlers/secrets-encrypt.go @@ -1,4 +1,4 @@ -package server +package handlers import ( "context" @@ -60,9 +60,9 @@ func getEncryptionRequest(req *http.Request) (*EncryptionRequest, error) { return result, err } -func encryptionStatusHandler(server *config.Control) http.Handler { +func EncryptionStatus(control *config.Control) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - status, err := encryptionStatus(server) + status, err := encryptionStatus(control) if err != nil { util.SendErrorWithID(err, "secret-encrypt", resp, req, http.StatusInternalServerError) return @@ -77,9 +77,9 @@ func encryptionStatusHandler(server *config.Control) http.Handler { }) } -func encryptionStatus(server *config.Control) (EncryptionState, error) { +func encryptionStatus(control *config.Control) (EncryptionState, error) { state := EncryptionState{} - providers, err := secretsencrypt.GetEncryptionProviders(server.Runtime) + providers, err := secretsencrypt.GetEncryptionProviders(control.Runtime) if os.IsNotExist(err) { return state, nil } else if err != nil { @@ -87,17 +87,17 @@ func encryptionStatus(server *config.Control) (EncryptionState, error) { } if providers[1].Identity != nil && providers[0].AESCBC != nil { state.Enable = ptr.To(true) - } else if providers[0].Identity != nil && providers[1].AESCBC != nil || !server.EncryptSecrets { + } else if providers[0].Identity != nil && providers[1].AESCBC != nil || !control.EncryptSecrets { state.Enable = ptr.To(false) } - if err := verifyEncryptionHashAnnotation(server.Runtime, server.Runtime.Core.Core(), ""); err != nil { + if err := verifyEncryptionHashAnnotation(control.Runtime, control.Runtime.Core.Core(), ""); err != nil { state.HashMatch = false state.HashError = err.Error() } else { state.HashMatch = true } - stage, _, err := getEncryptionHashAnnotation(server.Runtime.Core.Core()) + stage, _, err := getEncryptionHashAnnotation(control.Runtime.Core.Core()) if err != nil { return state, err } @@ -122,21 +122,21 @@ func encryptionStatus(server *config.Control) (EncryptionState, error) { return state, nil } -func encryptionEnable(ctx context.Context, server *config.Control, enable bool) error { - providers, err := secretsencrypt.GetEncryptionProviders(server.Runtime) +func encryptionEnable(ctx context.Context, control *config.Control, enable bool) error { + providers, err := secretsencrypt.GetEncryptionProviders(control.Runtime) if err != nil { return err } if len(providers) > 2 { return fmt.Errorf("more than 2 providers (%d) found in secrets encryption", len(providers)) } - curKeys, err := secretsencrypt.GetEncryptionKeys(server.Runtime, false) + curKeys, err := secretsencrypt.GetEncryptionKeys(control.Runtime, false) if err != nil { return err } if providers[1].Identity != nil && providers[0].AESCBC != nil && !enable { logrus.Infoln("Disabling secrets encryption") - if err := secretsencrypt.WriteEncryptionConfig(server.Runtime, curKeys, enable); err != nil { + if err := secretsencrypt.WriteEncryptionConfig(control.Runtime, curKeys, enable); err != nil { return err } } else if !enable { @@ -144,7 +144,7 @@ func encryptionEnable(ctx context.Context, server *config.Control, enable bool) return nil } else if providers[0].Identity != nil && providers[1].AESCBC != nil && enable { logrus.Infoln("Enabling secrets encryption") - if err := secretsencrypt.WriteEncryptionConfig(server.Runtime, curKeys, enable); err != nil { + if err := secretsencrypt.WriteEncryptionConfig(control.Runtime, curKeys, enable); err != nil { return err } } else if enable { @@ -153,13 +153,13 @@ func encryptionEnable(ctx context.Context, server *config.Control, enable bool) } else { return fmt.Errorf("unable to enable/disable secrets encryption, unknown configuration") } - if err := cluster.Save(ctx, server, true); err != nil { + if err := cluster.Save(ctx, control, true); err != nil { return err } - return reencryptAndRemoveKey(ctx, server, true, os.Getenv("NODE_NAME")) + return reencryptAndRemoveKey(ctx, control, true, os.Getenv("NODE_NAME")) } -func encryptionConfigHandler(ctx context.Context, server *config.Control) http.Handler { +func EncryptionConfig(ctx context.Context, control *config.Control) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPut { util.SendError(fmt.Errorf("method not allowed"), resp, req, http.StatusMethodNotAllowed) @@ -173,18 +173,18 @@ func encryptionConfigHandler(ctx context.Context, server *config.Control) http.H if encryptReq.Stage != nil { switch *encryptReq.Stage { case secretsencrypt.EncryptionPrepare: - err = encryptionPrepare(ctx, server, encryptReq.Force) + err = encryptionPrepare(ctx, control, encryptReq.Force) case secretsencrypt.EncryptionRotate: - err = encryptionRotate(ctx, server, encryptReq.Force) + err = encryptionRotate(ctx, control, encryptReq.Force) case secretsencrypt.EncryptionRotateKeys: - err = encryptionRotateKeys(ctx, server) + err = encryptionRotateKeys(ctx, control) case secretsencrypt.EncryptionReencryptActive: - err = encryptionReencrypt(ctx, server, encryptReq.Force, encryptReq.Skip) + err = encryptionReencrypt(ctx, control, encryptReq.Force, encryptReq.Skip) default: err = fmt.Errorf("unknown stage %s requested", *encryptReq.Stage) } } else if encryptReq.Enable != nil { - err = encryptionEnable(ctx, server, *encryptReq.Enable) + err = encryptionEnable(ctx, control, *encryptReq.Enable) } if err != nil { @@ -199,13 +199,13 @@ func encryptionConfigHandler(ctx context.Context, server *config.Control) http.H }) } -func encryptionPrepare(ctx context.Context, server *config.Control, force bool) error { +func encryptionPrepare(ctx context.Context, control *config.Control, force bool) error { states := secretsencrypt.EncryptionStart + "-" + secretsencrypt.EncryptionReencryptFinished - if err := verifyEncryptionHashAnnotation(server.Runtime, server.Runtime.Core.Core(), states); err != nil && !force { + if err := verifyEncryptionHashAnnotation(control.Runtime, control.Runtime.Core.Core(), states); err != nil && !force { return err } - curKeys, err := secretsencrypt.GetEncryptionKeys(server.Runtime, false) + curKeys, err := secretsencrypt.GetEncryptionKeys(control.Runtime, false) if err != nil { return err } @@ -215,29 +215,29 @@ func encryptionPrepare(ctx context.Context, server *config.Control, force bool) } logrus.Infoln("Adding secrets-encryption key: ", curKeys[len(curKeys)-1]) - if err := secretsencrypt.WriteEncryptionConfig(server.Runtime, curKeys, true); err != nil { + if err := secretsencrypt.WriteEncryptionConfig(control.Runtime, curKeys, true); err != nil { return err } nodeName := os.Getenv("NODE_NAME") err = retry.RetryOnConflict(retry.DefaultRetry, func() error { - node, err := server.Runtime.Core.Core().V1().Node().Get(nodeName, metav1.GetOptions{}) + node, err := control.Runtime.Core.Core().V1().Node().Get(nodeName, metav1.GetOptions{}) if err != nil { return err } - return secretsencrypt.WriteEncryptionHashAnnotation(server.Runtime, node, false, secretsencrypt.EncryptionPrepare) + return secretsencrypt.WriteEncryptionHashAnnotation(control.Runtime, node, false, secretsencrypt.EncryptionPrepare) }) if err != nil { return err } - return cluster.Save(ctx, server, true) + return cluster.Save(ctx, control, true) } -func encryptionRotate(ctx context.Context, server *config.Control, force bool) error { - if err := verifyEncryptionHashAnnotation(server.Runtime, server.Runtime.Core.Core(), secretsencrypt.EncryptionPrepare); err != nil && !force { +func encryptionRotate(ctx context.Context, control *config.Control, force bool) error { + if err := verifyEncryptionHashAnnotation(control.Runtime, control.Runtime.Core.Core(), secretsencrypt.EncryptionPrepare); err != nil && !force { return err } - curKeys, err := secretsencrypt.GetEncryptionKeys(server.Runtime, false) + curKeys, err := secretsencrypt.GetEncryptionKeys(control.Runtime, false) if err != nil { return err } @@ -245,49 +245,49 @@ func encryptionRotate(ctx context.Context, server *config.Control, force bool) e // Right rotate elements rotatedKeys := append(curKeys[len(curKeys)-1:], curKeys[:len(curKeys)-1]...) - if err = secretsencrypt.WriteEncryptionConfig(server.Runtime, rotatedKeys, true); err != nil { + if err = secretsencrypt.WriteEncryptionConfig(control.Runtime, rotatedKeys, true); err != nil { return err } logrus.Infoln("Encryption keys right rotated") nodeName := os.Getenv("NODE_NAME") err = retry.RetryOnConflict(retry.DefaultRetry, func() error { - node, err := server.Runtime.Core.Core().V1().Node().Get(nodeName, metav1.GetOptions{}) + node, err := control.Runtime.Core.Core().V1().Node().Get(nodeName, metav1.GetOptions{}) if err != nil { return err } - return secretsencrypt.WriteEncryptionHashAnnotation(server.Runtime, node, false, secretsencrypt.EncryptionRotate) + return secretsencrypt.WriteEncryptionHashAnnotation(control.Runtime, node, false, secretsencrypt.EncryptionRotate) }) if err != nil { return err } - return cluster.Save(ctx, server, true) + return cluster.Save(ctx, control, true) } -func encryptionReencrypt(ctx context.Context, server *config.Control, force bool, skip bool) error { - if err := verifyEncryptionHashAnnotation(server.Runtime, server.Runtime.Core.Core(), secretsencrypt.EncryptionRotate); err != nil && !force { +func encryptionReencrypt(ctx context.Context, control *config.Control, force bool, skip bool) error { + if err := verifyEncryptionHashAnnotation(control.Runtime, control.Runtime.Core.Core(), secretsencrypt.EncryptionRotate); err != nil && !force { return err } // Set the reencrypt-active annotation so other nodes know we are in the process of reencrypting. // As this stage is not persisted, we do not write the annotation to file nodeName := os.Getenv("NODE_NAME") if err := retry.RetryOnConflict(retry.DefaultRetry, func() error { - node, err := server.Runtime.Core.Core().V1().Node().Get(nodeName, metav1.GetOptions{}) + node, err := control.Runtime.Core.Core().V1().Node().Get(nodeName, metav1.GetOptions{}) if err != nil { return err } - return secretsencrypt.WriteEncryptionHashAnnotation(server.Runtime, node, true, secretsencrypt.EncryptionReencryptActive) + return secretsencrypt.WriteEncryptionHashAnnotation(control.Runtime, node, true, secretsencrypt.EncryptionReencryptActive) }); err != nil { return err } // We use a timeout of 10s for the reencrypt call, so finish the process as a go routine and return immediately. // No errors are returned to the user via CLI, any errors will be logged on the server - go reencryptAndRemoveKey(ctx, server, skip, nodeName) + go reencryptAndRemoveKey(ctx, control, skip, nodeName) return nil } -func addAndRotateKeys(server *config.Control) error { - curKeys, err := secretsencrypt.GetEncryptionKeys(server.Runtime, false) +func addAndRotateKeys(control *config.Control) error { + curKeys, err := secretsencrypt.GetEncryptionKeys(control.Runtime, false) if err != nil { return err } @@ -297,29 +297,29 @@ func addAndRotateKeys(server *config.Control) error { } logrus.Infoln("Adding secrets-encryption key: ", curKeys[len(curKeys)-1]) - if err := secretsencrypt.WriteEncryptionConfig(server.Runtime, curKeys, true); err != nil { + if err := secretsencrypt.WriteEncryptionConfig(control.Runtime, curKeys, true); err != nil { return err } // Right rotate elements rotatedKeys := append(curKeys[len(curKeys)-1:], curKeys[:len(curKeys)-1]...) logrus.Infoln("Rotating secrets-encryption keys") - return secretsencrypt.WriteEncryptionConfig(server.Runtime, rotatedKeys, true) + return secretsencrypt.WriteEncryptionConfig(control.Runtime, rotatedKeys, true) } // encryptionRotateKeys is both adds and rotates keys, and sets the annotaiton that triggers the // reencryption process. It is the preferred way to rotate keys, starting with v1.28 -func encryptionRotateKeys(ctx context.Context, server *config.Control) error { +func encryptionRotateKeys(ctx context.Context, control *config.Control) error { states := secretsencrypt.EncryptionStart + "-" + secretsencrypt.EncryptionReencryptFinished - if err := verifyEncryptionHashAnnotation(server.Runtime, server.Runtime.Core.Core(), states); err != nil { + if err := verifyEncryptionHashAnnotation(control.Runtime, control.Runtime.Core.Core(), states); err != nil { return err } - if err := verifyRotateKeysSupport(server.Runtime.Core.Core()); err != nil { + if err := verifyRotateKeysSupport(control.Runtime.Core.Core()); err != nil { return err } - reloadTime, reloadSuccesses, err := secretsencrypt.GetEncryptionConfigMetrics(server.Runtime, true) + reloadTime, reloadSuccesses, err := secretsencrypt.GetEncryptionConfigMetrics(control.Runtime, true) if err != nil { return err } @@ -328,72 +328,72 @@ func encryptionRotateKeys(ctx context.Context, server *config.Control) error { // As this stage is not persisted, we do not write the annotation to file nodeName := os.Getenv("NODE_NAME") if err = retry.RetryOnConflict(retry.DefaultRetry, func() error { - node, err := server.Runtime.Core.Core().V1().Node().Get(nodeName, metav1.GetOptions{}) + node, err := control.Runtime.Core.Core().V1().Node().Get(nodeName, metav1.GetOptions{}) if err != nil { return err } - return secretsencrypt.WriteEncryptionHashAnnotation(server.Runtime, node, true, secretsencrypt.EncryptionReencryptActive) + return secretsencrypt.WriteEncryptionHashAnnotation(control.Runtime, node, true, secretsencrypt.EncryptionReencryptActive) }); err != nil { return err } - if err := addAndRotateKeys(server); err != nil { + if err := addAndRotateKeys(control); err != nil { return err } - if err := secretsencrypt.WaitForEncryptionConfigReload(server.Runtime, reloadSuccesses, reloadTime); err != nil { + if err := secretsencrypt.WaitForEncryptionConfigReload(control.Runtime, reloadSuccesses, reloadTime); err != nil { return err } - return reencryptAndRemoveKey(ctx, server, false, nodeName) + return reencryptAndRemoveKey(ctx, control, false, nodeName) } -func reencryptAndRemoveKey(ctx context.Context, server *config.Control, skip bool, nodeName string) error { - if err := updateSecrets(ctx, server, nodeName); err != nil { +func reencryptAndRemoveKey(ctx context.Context, control *config.Control, skip bool, nodeName string) error { + if err := updateSecrets(ctx, control, nodeName); err != nil { return err } // If skipping, revert back to the previous stage and do not remove the key if skip { err := retry.RetryOnConflict(retry.DefaultRetry, func() error { - node, err := server.Runtime.Core.Core().V1().Node().Get(nodeName, metav1.GetOptions{}) + node, err := control.Runtime.Core.Core().V1().Node().Get(nodeName, metav1.GetOptions{}) if err != nil { return err } - secretsencrypt.BootstrapEncryptionHashAnnotation(node, server.Runtime) - _, err = server.Runtime.Core.Core().V1().Node().Update(node) + secretsencrypt.BootstrapEncryptionHashAnnotation(node, control.Runtime) + _, err = control.Runtime.Core.Core().V1().Node().Update(node) return err }) return err } // Remove last key - curKeys, err := secretsencrypt.GetEncryptionKeys(server.Runtime, false) + curKeys, err := secretsencrypt.GetEncryptionKeys(control.Runtime, false) if err != nil { return err } logrus.Infoln("Removing key: ", curKeys[len(curKeys)-1]) curKeys = curKeys[:len(curKeys)-1] - if err = secretsencrypt.WriteEncryptionConfig(server.Runtime, curKeys, true); err != nil { + if err = secretsencrypt.WriteEncryptionConfig(control.Runtime, curKeys, true); err != nil { return err } if err = retry.RetryOnConflict(retry.DefaultRetry, func() error { - node, err := server.Runtime.Core.Core().V1().Node().Get(nodeName, metav1.GetOptions{}) + node, err := control.Runtime.Core.Core().V1().Node().Get(nodeName, metav1.GetOptions{}) if err != nil { return err } - return secretsencrypt.WriteEncryptionHashAnnotation(server.Runtime, node, false, secretsencrypt.EncryptionReencryptFinished) + return secretsencrypt.WriteEncryptionHashAnnotation(control.Runtime, node, false, secretsencrypt.EncryptionReencryptFinished) }); err != nil { return err } - return cluster.Save(ctx, server, true) + return cluster.Save(ctx, control, true) } -func updateSecrets(ctx context.Context, server *config.Control, nodeName string) error { - k8s := server.Runtime.K8s +func updateSecrets(ctx context.Context, control *config.Control, nodeName string) error { + k8s := control.Runtime.K8s nodeRef := &corev1.ObjectReference{ Kind: "Node", Name: nodeName, diff --git a/pkg/server/token.go b/pkg/server/handlers/token.go similarity index 64% rename from pkg/server/token.go rename to pkg/server/handlers/token.go index efd095013f43..8007859e3271 100644 --- a/pkg/server/token.go +++ b/pkg/server/handlers/token.go @@ -1,4 +1,4 @@ -package server +package handlers import ( "context" @@ -6,8 +6,10 @@ import ( "fmt" "io" "net/http" + "os" "path/filepath" + "github.com/k3s-io/k3s/pkg/clientaccess" "github.com/k3s-io/k3s/pkg/cluster" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/passwd" @@ -30,7 +32,7 @@ func getServerTokenRequest(req *http.Request) (TokenRotateRequest, error) { return result, err } -func tokenRequestHandler(ctx context.Context, server *config.Control) http.Handler { +func TokenRequest(ctx context.Context, control *config.Control) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPut { util.SendError(fmt.Errorf("method not allowed"), resp, req, http.StatusMethodNotAllowed) @@ -43,7 +45,7 @@ func tokenRequestHandler(ctx context.Context, server *config.Control) http.Handl util.SendError(err, resp, req, http.StatusBadRequest) return } - if err = tokenRotate(ctx, server, *sTokenReq.NewToken); err != nil { + if err = tokenRotate(ctx, control, *sTokenReq.NewToken); err != nil { util.SendErrorWithID(err, "token", resp, req, http.StatusInternalServerError) return } @@ -51,8 +53,20 @@ func tokenRequestHandler(ctx context.Context, server *config.Control) http.Handl }) } -func tokenRotate(ctx context.Context, server *config.Control, newToken string) error { - passwd, err := passwd.Read(server.Runtime.PasswdFile) +func WriteToken(token, file, certs string) error { + if len(token) == 0 { + return nil + } + + token, err := clientaccess.FormatToken(token, certs) + if err != nil { + return err + } + return os.WriteFile(file, []byte(token+"\n"), 0600) +} + +func tokenRotate(ctx context.Context, control *config.Control, newToken string) error { + passwd, err := passwd.Read(control.Runtime.PasswdFile) if err != nil { return err } @@ -76,24 +90,24 @@ func tokenRotate(ctx context.Context, server *config.Control, newToken string) e } // If the agent token is the same a server, we need to change both - if agentToken, found := passwd.Pass("node"); found && agentToken == oldToken && server.AgentToken == "" { + if agentToken, found := passwd.Pass("node"); found && agentToken == oldToken && control.AgentToken == "" { if err := passwd.EnsureUser("node", version.Program+":agent", newToken); err != nil { return err } } - if err := passwd.Write(server.Runtime.PasswdFile); err != nil { + if err := passwd.Write(control.Runtime.PasswdFile); err != nil { return err } - serverTokenFile := filepath.Join(server.DataDir, "token") - if err := writeToken("server:"+newToken, serverTokenFile, server.Runtime.ServerCA); err != nil { + serverTokenFile := filepath.Join(control.DataDir, "token") + if err := WriteToken("server:"+newToken, serverTokenFile, control.Runtime.ServerCA); err != nil { return err } - if err := cluster.RotateBootstrapToken(ctx, server, oldToken); err != nil { + if err := cluster.RotateBootstrapToken(ctx, control, oldToken); err != nil { return err } - server.Token = newToken - return cluster.Save(ctx, server, true) + control.Token = newToken + return cluster.Save(ctx, control, true) } diff --git a/pkg/server/router.go b/pkg/server/router.go deleted file mode 100644 index fca554027880..000000000000 --- a/pkg/server/router.go +++ /dev/null @@ -1,595 +0,0 @@ -package server - -import ( - "context" - "crypto" - "crypto/x509" - "fmt" - "net" - "net/http" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "sync" - "time" - - "github.com/gorilla/mux" - "github.com/k3s-io/k3s/pkg/bootstrap" - "github.com/k3s-io/k3s/pkg/cli/cmds" - "github.com/k3s-io/k3s/pkg/daemons/config" - "github.com/k3s-io/k3s/pkg/etcd" - "github.com/k3s-io/k3s/pkg/nodepassword" - "github.com/k3s-io/k3s/pkg/server/auth" - "github.com/k3s-io/k3s/pkg/util" - "github.com/k3s-io/k3s/pkg/version" - "github.com/pkg/errors" - certutil "github.com/rancher/dynamiclistener/cert" - coreclient "github.com/rancher/wrangler/v3/pkg/generated/controllers/core/v1" - "github.com/sirupsen/logrus" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" - "k8s.io/apimachinery/pkg/util/json" - "k8s.io/apimachinery/pkg/util/sets" - "k8s.io/apimachinery/pkg/util/wait" - "k8s.io/apiserver/pkg/authentication/user" - "k8s.io/apiserver/pkg/endpoints/request" - typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" - bootstrapapi "k8s.io/cluster-bootstrap/token/api" - "k8s.io/kubernetes/pkg/auth/nodeidentifier" -) - -const ( - staticURL = "/static/" -) - -var ( - identifier = nodeidentifier.NewDefaultNodeIdentifier() -) - -func router(ctx context.Context, config *Config, cfg *cmds.Server) http.Handler { - serverConfig := &config.ControlConfig - nodeAuth := passwordBootstrap(ctx, config) - - prefix := "/v1-" + version.Program - authed := mux.NewRouter().SkipClean(true) - authed.Use(auth.HasRole(serverConfig, version.Program+":agent", user.NodesGroup, bootstrapapi.BootstrapDefaultGroup)) - authed.Path(prefix + "/serving-kubelet.crt").Handler(servingKubeletCert(serverConfig, serverConfig.Runtime.ServingKubeletKey, nodeAuth)) - authed.Path(prefix + "/client-kubelet.crt").Handler(clientKubeletCert(serverConfig, serverConfig.Runtime.ClientKubeletKey, nodeAuth)) - authed.Path(prefix + "/client-kube-proxy.crt").Handler(fileHandler(serverConfig.Runtime.ClientKubeProxyCert, serverConfig.Runtime.ClientKubeProxyKey)) - authed.Path(prefix + "/client-" + version.Program + "-controller.crt").Handler(fileHandler(serverConfig.Runtime.ClientK3sControllerCert, serverConfig.Runtime.ClientK3sControllerKey)) - authed.Path(prefix + "/client-ca.crt").Handler(fileHandler(serverConfig.Runtime.ClientCA)) - authed.Path(prefix + "/server-ca.crt").Handler(fileHandler(serverConfig.Runtime.ServerCA)) - authed.Path(prefix + "/apiservers").Handler(apiserversHandler(serverConfig)) - authed.Path(prefix + "/config").Handler(configHandler(serverConfig, cfg)) - authed.Path(prefix + "/readyz").Handler(readyzHandler(serverConfig)) - - if cfg.DisableAPIServer { - authed.NotFoundHandler = apiserverDisabled() - } else { - authed.NotFoundHandler = apiserver(serverConfig.Runtime) - } - - nodeAuthed := mux.NewRouter().SkipClean(true) - nodeAuthed.NotFoundHandler = authed - nodeAuthed.Use(auth.HasRole(serverConfig, user.NodesGroup)) - nodeAuthed.Path(prefix + "/connect").Handler(serverConfig.Runtime.Tunnel) - - serverAuthed := mux.NewRouter().SkipClean(true) - serverAuthed.NotFoundHandler = nodeAuthed - serverAuthed.Use(auth.HasRole(serverConfig, version.Program+":server")) - serverAuthed.Path(prefix + "/encrypt/status").Handler(encryptionStatusHandler(serverConfig)) - serverAuthed.Path(prefix + "/encrypt/config").Handler(encryptionConfigHandler(ctx, serverConfig)) - serverAuthed.Path(prefix + "/cert/cacerts").Handler(caCertReplaceHandler(serverConfig)) - serverAuthed.Path(prefix + "/server-bootstrap").Handler(bootstrapHandler(serverConfig.Runtime)) - serverAuthed.Path(prefix + "/token").Handler(tokenRequestHandler(ctx, serverConfig)) - - systemAuthed := mux.NewRouter().SkipClean(true) - systemAuthed.NotFoundHandler = serverAuthed - systemAuthed.MethodNotAllowedHandler = serverAuthed - systemAuthed.Use(auth.HasRole(serverConfig, user.SystemPrivilegedGroup)) - systemAuthed.Methods(http.MethodConnect).Handler(serverConfig.Runtime.Tunnel) - - staticDir := filepath.Join(serverConfig.DataDir, "static") - router := mux.NewRouter().SkipClean(true) - router.NotFoundHandler = systemAuthed - router.PathPrefix(staticURL).Handler(serveStatic(staticURL, staticDir)) - router.Path("/cacerts").Handler(cacerts(serverConfig.Runtime.ServerCA)) - router.Path("/ping").Handler(ping()) - - return router -} - -func apiserver(runtime *config.ControlRuntime) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - if runtime != nil && runtime.APIServer != nil { - runtime.APIServer.ServeHTTP(resp, req) - } else { - util.SendError(util.ErrAPINotReady, resp, req, http.StatusServiceUnavailable) - } - }) -} - -func apiserverDisabled() http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - util.SendError(util.ErrAPIDisabled, resp, req, http.StatusServiceUnavailable) - }) -} - -func bootstrapHandler(runtime *config.ControlRuntime) http.Handler { - if runtime.HTTPBootstrap { - return bootstrap.Handler(&runtime.ControlRuntimeBootstrap) - } - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - logrus.Warnf("Received HTTP bootstrap request from %s, but embedded etcd is not enabled.", req.RemoteAddr) - util.SendError(errors.New("etcd disabled"), resp, req, http.StatusBadRequest) - }) -} - -func cacerts(serverCA string) http.Handler { - var ca []byte - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - if ca == nil { - var err error - ca, err = os.ReadFile(serverCA) - if err != nil { - util.SendError(err, resp, req) - return - } - } - resp.Header().Set("content-type", "text/plain") - resp.Write(ca) - }) -} - -func getNodeInfo(req *http.Request) (*nodeInfo, error) { - user, ok := request.UserFrom(req.Context()) - if !ok { - return nil, errors.New("auth user not set") - } - - nodeName := req.Header.Get(version.Program + "-Node-Name") - if nodeName == "" { - return nil, errors.New("node name not set") - } - - nodePassword := req.Header.Get(version.Program + "-Node-Password") - if nodePassword == "" { - return nil, errors.New("node password not set") - } - - return &nodeInfo{ - Name: strings.ToLower(nodeName), - Password: nodePassword, - User: user, - }, nil -} - -func getCACertAndKeys(caCertFile, caKeyFile, signingKeyFile string) ([]*x509.Certificate, crypto.Signer, crypto.Signer, error) { - keyBytes, err := os.ReadFile(signingKeyFile) - if err != nil { - return nil, nil, nil, err - } - - key, err := certutil.ParsePrivateKeyPEM(keyBytes) - if err != nil { - return nil, nil, nil, err - } - - caKeyBytes, err := os.ReadFile(caKeyFile) - if err != nil { - return nil, nil, nil, err - } - - caKey, err := certutil.ParsePrivateKeyPEM(caKeyBytes) - if err != nil { - return nil, nil, nil, err - } - - caBytes, err := os.ReadFile(caCertFile) - if err != nil { - return nil, nil, nil, err - } - - caCert, err := certutil.ParseCertsPEM(caBytes) - if err != nil { - return nil, nil, nil, err - } - - return caCert, caKey.(crypto.Signer), key.(crypto.Signer), nil -} - -func servingKubeletCert(server *config.Control, keyFile string, auth nodePassBootstrapper) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - nodeName, errCode, err := auth(req) - if err != nil { - util.SendError(err, resp, req, errCode) - return - } - - caCerts, caKey, key, err := getCACertAndKeys(server.Runtime.ServerCA, server.Runtime.ServerCAKey, server.Runtime.ServingKubeletKey) - if err != nil { - util.SendError(err, resp, req) - return - } - - ips := []net.IP{net.ParseIP("127.0.0.1")} - - if nodeIP := req.Header.Get(version.Program + "-Node-IP"); nodeIP != "" { - for _, v := range strings.Split(nodeIP, ",") { - ip := net.ParseIP(v) - if ip == nil { - util.SendError(fmt.Errorf("invalid node IP address %s", ip), resp, req) - return - } - ips = append(ips, ip) - } - } - - cert, err := certutil.NewSignedCert(certutil.Config{ - CommonName: nodeName, - Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - AltNames: certutil.AltNames{ - DNSNames: []string{nodeName, "localhost"}, - IPs: ips, - }, - }, key, caCerts[0], caKey) - if err != nil { - util.SendError(err, resp, req) - return - } - - keyBytes, err := os.ReadFile(keyFile) - if err != nil { - http.Error(resp, err.Error(), http.StatusInternalServerError) - return - } - - resp.Write(util.EncodeCertsPEM(cert, caCerts)) - resp.Write(keyBytes) - }) -} - -func clientKubeletCert(server *config.Control, keyFile string, auth nodePassBootstrapper) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - nodeName, errCode, err := auth(req) - if err != nil { - util.SendError(err, resp, req, errCode) - return - } - - caCerts, caKey, key, err := getCACertAndKeys(server.Runtime.ClientCA, server.Runtime.ClientCAKey, server.Runtime.ClientKubeletKey) - if err != nil { - util.SendError(err, resp, req) - return - } - - cert, err := certutil.NewSignedCert(certutil.Config{ - CommonName: "system:node:" + nodeName, - Organization: []string{user.NodesGroup}, - Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - }, key, caCerts[0], caKey) - if err != nil { - util.SendError(err, resp, req) - return - } - - keyBytes, err := os.ReadFile(keyFile) - if err != nil { - http.Error(resp, err.Error(), http.StatusInternalServerError) - return - } - - resp.Write(util.EncodeCertsPEM(cert, caCerts)) - resp.Write(keyBytes) - }) -} - -func fileHandler(fileName ...string) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - resp.Header().Set("Content-Type", "text/plain") - - if len(fileName) == 1 { - http.ServeFile(resp, req, fileName[0]) - return - } - - for _, f := range fileName { - bytes, err := os.ReadFile(f) - if err != nil { - util.SendError(errors.Wrapf(err, "failed to read %s", f), resp, req, http.StatusInternalServerError) - return - } - resp.Write(bytes) - } - }) -} - -// apiserversHandler returns a list of apiserver addresses. -// It attempts to merge results from both the apiserver and directly from etcd, -// in case we are recovering from an apiserver outage that rendered the endpoint list unavailable. -func apiserversHandler(server *config.Control) http.Handler { - collectAddresses := getAddressCollector(server) - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) - defer cancel() - endpoints := collectAddresses(ctx) - resp.Header().Set("content-type", "application/json") - if err := json.NewEncoder(resp).Encode(endpoints); err != nil { - util.SendError(errors.Wrap(err, "failed to encode apiserver endpoints"), resp, req, http.StatusInternalServerError) - } - }) -} - -func configHandler(server *config.Control, cfg *cmds.Server) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - // Startup hooks may read and modify cmds.Server in a goroutine, but as these are copied into - // config.Control before the startup hooks are called, any modifications need to be sync'd back - // into the struct before it is sent to agents. - // At this time we don't sync all the fields, just those known to be touched by startup hooks. - server.DisableKubeProxy = cfg.DisableKubeProxy - resp.Header().Set("content-type", "application/json") - if err := json.NewEncoder(resp).Encode(server); err != nil { - util.SendError(errors.Wrap(err, "failed to encode agent config"), resp, req, http.StatusInternalServerError) - } - }) -} - -func readyzHandler(server *config.Control) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - if server.Runtime.Core == nil { - util.SendError(util.ErrCoreNotReady, resp, req, http.StatusServiceUnavailable) - return - } - data := []byte("ok") - resp.WriteHeader(http.StatusOK) - resp.Header().Set("Content-Type", "text/plain") - resp.Header().Set("Content-Length", strconv.Itoa(len(data))) - resp.Write(data) - }) -} - -func ping() http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - data := []byte("pong") - resp.WriteHeader(http.StatusOK) - resp.Header().Set("Content-Type", "text/plain") - resp.Header().Set("Content-Length", strconv.Itoa(len(data))) - resp.Write(data) - }) -} - -func serveStatic(urlPrefix, staticDir string) http.Handler { - return http.StripPrefix(urlPrefix, http.FileServer(http.Dir(staticDir))) -} - -// nodePassBootstrapper returns a node name, or http error code and error -type nodePassBootstrapper func(req *http.Request) (string, int, error) - -// nodeInfo contains information on the requesting node, derived from auth creds -// and request headers. -type nodeInfo struct { - Name string - Password string - User user.Info -} - -func passwordBootstrap(ctx context.Context, config *Config) nodePassBootstrapper { - runtime := config.ControlConfig.Runtime - deferredNodes := map[string]bool{} - var secretClient coreclient.SecretController - var nodeClient coreclient.NodeController - var mu sync.Mutex - - return nodePassBootstrapper(func(req *http.Request) (string, int, error) { - node, err := getNodeInfo(req) - if err != nil { - return "", http.StatusBadRequest, err - } - - nodeName, isNodeAuth := identifier.NodeIdentity(node.User) - if isNodeAuth && nodeName != node.Name { - return "", http.StatusBadRequest, errors.New("header node name does not match auth node name") - } - - if secretClient == nil || nodeClient == nil { - if runtime.Core != nil { - // initialize the client if we can - secretClient = runtime.Core.Core().V1().Secret() - nodeClient = runtime.Core.Core().V1().Node() - } else if node.Name == os.Getenv("NODE_NAME") { - // If we're verifying our own password, verify it locally and ensure a secret later. - return verifyLocalPassword(ctx, config, &mu, deferredNodes, node) - } else if config.ControlConfig.DisableAPIServer && !isNodeAuth { - // If we're running on an etcd-only node, and the request didn't use Node Identity auth, - // defer node password verification until an apiserver joins the cluster. - return verifyRemotePassword(ctx, config, &mu, deferredNodes, node) - } else { - // Otherwise, reject the request until the core is ready. - return "", http.StatusServiceUnavailable, util.ErrCoreNotReady - } - } - - // verify that the node exists, if using Node Identity auth - if err := verifyNode(ctx, nodeClient, node); err != nil { - return "", http.StatusUnauthorized, err - } - - // verify that the node password secret matches, or create it if it does not - if err := nodepassword.Ensure(secretClient, node.Name, node.Password); err != nil { - // if the verification failed, reject the request - if errors.Is(err, nodepassword.ErrVerifyFailed) { - return "", http.StatusForbidden, err - } - // If verification failed due to an error creating the node password secret, allow - // the request, but retry verification until the outage is resolved. This behavior - // allows nodes to join the cluster during outages caused by validating webhooks - // blocking secret creation - if the outage requires new nodes to join in order to - // run the webhook pods, we must fail open here to resolve the outage. - return verifyRemotePassword(ctx, config, &mu, deferredNodes, node) - } - - return node.Name, http.StatusOK, nil - }) -} - -func verifyLocalPassword(ctx context.Context, config *Config, mu *sync.Mutex, deferredNodes map[string]bool, node *nodeInfo) (string, int, error) { - // do not attempt to verify the node password if the local host is not running an agent and does not have a node resource. - if config.DisableAgent { - return node.Name, http.StatusOK, nil - } - - // use same password file location that the agent creates - nodePasswordRoot := "/" - if config.ControlConfig.Rootless { - nodePasswordRoot = filepath.Join(path.Dir(config.ControlConfig.DataDir), "agent") - } - nodeConfigPath := filepath.Join(nodePasswordRoot, "etc", "rancher", "node") - nodePasswordFile := filepath.Join(nodeConfigPath, "password") - - passBytes, err := os.ReadFile(nodePasswordFile) - if err != nil { - return "", http.StatusInternalServerError, errors.Wrap(err, "unable to read node password file") - } - - passHash, err := nodepassword.Hasher.CreateHash(strings.TrimSpace(string(passBytes))) - if err != nil { - return "", http.StatusInternalServerError, errors.Wrap(err, "unable to hash node password file") - } - - if err := nodepassword.Hasher.VerifyHash(passHash, node.Password); err != nil { - return "", http.StatusForbidden, errors.Wrap(err, "unable to verify local node password") - } - - mu.Lock() - defer mu.Unlock() - - if _, ok := deferredNodes[node.Name]; !ok { - deferredNodes[node.Name] = true - go ensureSecret(ctx, config, node) - logrus.Infof("Password verified locally for node %s", node.Name) - } - - return node.Name, http.StatusOK, nil -} - -func verifyRemotePassword(ctx context.Context, config *Config, mu *sync.Mutex, deferredNodes map[string]bool, node *nodeInfo) (string, int, error) { - mu.Lock() - defer mu.Unlock() - - if _, ok := deferredNodes[node.Name]; !ok { - deferredNodes[node.Name] = true - go ensureSecret(ctx, config, node) - logrus.Infof("Password verification deferred for node %s", node.Name) - } - - return node.Name, http.StatusOK, nil -} - -func verifyNode(ctx context.Context, nodeClient coreclient.NodeController, node *nodeInfo) error { - if nodeName, isNodeAuth := identifier.NodeIdentity(node.User); isNodeAuth { - if _, err := nodeClient.Cache().Get(nodeName); err != nil { - return errors.Wrap(err, "unable to verify node identity") - } - } - return nil -} - -func ensureSecret(ctx context.Context, config *Config, node *nodeInfo) { - runtime := config.ControlConfig.Runtime - _ = wait.PollUntilContextCancel(ctx, time.Second*5, true, func(ctx context.Context) (bool, error) { - if runtime.Core != nil { - secretClient := runtime.Core.Core().V1().Secret() - // This is consistent with events attached to the node generated by the kubelet - // https://github.com/kubernetes/kubernetes/blob/612130dd2f4188db839ea5c2dea07a96b0ad8d1c/pkg/kubelet/kubelet.go#L479-L485 - nodeRef := &corev1.ObjectReference{ - Kind: "Node", - Name: node.Name, - UID: types.UID(node.Name), - Namespace: "", - } - if err := nodepassword.Ensure(secretClient, node.Name, node.Password); err != nil { - runtime.Event.Eventf(nodeRef, corev1.EventTypeWarning, "NodePasswordValidationFailed", "Deferred node password secret validation failed: %v", err) - // Return true to stop polling if the password verification failed; only retry on secret creation errors. - return errors.Is(err, nodepassword.ErrVerifyFailed), nil - } - runtime.Event.Event(nodeRef, corev1.EventTypeNormal, "NodePasswordValidationComplete", "Deferred node password secret validation complete") - return true, nil - } - return false, nil - }) -} - -// addressGetter is a common signature for functions that return an address channel -type addressGetter func(ctx context.Context) <-chan []string - -// kubernetesGetter returns a function that returns a channel that can be read to get apiserver addresses from kubernetes endpoints -func kubernetesGetter(server *config.Control) addressGetter { - var endpointsClient typedcorev1.EndpointsInterface - return func(ctx context.Context) <-chan []string { - ch := make(chan []string, 1) - go func() { - if endpointsClient == nil { - if server.Runtime.K8s != nil { - endpointsClient = server.Runtime.K8s.CoreV1().Endpoints(metav1.NamespaceDefault) - } - } - if endpointsClient != nil { - if endpoint, err := endpointsClient.Get(ctx, "kubernetes", metav1.GetOptions{}); err != nil { - logrus.Debugf("Failed to get apiserver addresses from kubernetes: %v", err) - } else { - ch <- util.GetAddresses(endpoint) - } - } - close(ch) - }() - return ch - } -} - -// etcdGetter returns a function that returns a channel that can be read to get apiserver addresses from etcd -func etcdGetter(server *config.Control) addressGetter { - return func(ctx context.Context) <-chan []string { - ch := make(chan []string, 1) - go func() { - if addresses, err := etcd.GetAPIServerURLsFromETCD(ctx, server); err != nil { - logrus.Debugf("Failed to get apiserver addresses from etcd: %v", err) - } else { - ch <- addresses - } - close(ch) - }() - return ch - } -} - -// getAddressCollector returns a function that can be called to return -// apiserver addresses from both kubernetes and etcd -func getAddressCollector(server *config.Control) func(ctx context.Context) []string { - getFromKubernetes := kubernetesGetter(server) - getFromEtcd := etcdGetter(server) - - // read from both kubernetes and etcd in parallel, returning the collected results - return func(ctx context.Context) []string { - a := sets.Set[string]{} - r := []string{} - k8sCh := getFromKubernetes(ctx) - k8sOk := true - etcdCh := getFromEtcd(ctx) - etcdOk := true - - for k8sOk || etcdOk { - select { - case r, k8sOk = <-k8sCh: - a.Insert(r...) - case r, etcdOk = <-etcdCh: - a.Insert(r...) - case <-ctx.Done(): - return a.UnsortedList() - } - } - return a.UnsortedList() - } -} diff --git a/pkg/server/server.go b/pkg/server/server.go index a8c1e0d470f7..81958e1dc748 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -23,6 +23,7 @@ import ( "github.com/k3s-io/k3s/pkg/nodepassword" "github.com/k3s-io/k3s/pkg/rootlessports" "github.com/k3s-io/k3s/pkg/secretsencrypt" + "github.com/k3s-io/k3s/pkg/server/handlers" "github.com/k3s-io/k3s/pkg/static" "github.com/k3s-io/k3s/pkg/util" "github.com/k3s-io/k3s/pkg/version" @@ -58,7 +59,7 @@ func StartServer(ctx context.Context, config *Config, cfg *cmds.Server) error { wg := &sync.WaitGroup{} wg.Add(len(config.StartupHooks)) - config.ControlConfig.Runtime.Handler = router(ctx, config, cfg) + config.ControlConfig.Runtime.Handler = handlers.NewHandler(ctx, &config.ControlConfig, cfg) config.ControlConfig.Runtime.StartupHooksWg = wg shArgs := cmds.StartupHookArgs{ @@ -346,7 +347,7 @@ func printTokens(config *config.Control) error { var serverTokenFile string if config.Runtime.ServerToken != "" { serverTokenFile = filepath.Join(config.DataDir, "token") - if err := writeToken(config.Runtime.ServerToken, serverTokenFile, config.Runtime.ServerCA); err != nil { + if err := handlers.WriteToken(config.Runtime.ServerToken, serverTokenFile, config.Runtime.ServerCA); err != nil { return err } @@ -374,7 +375,7 @@ func printTokens(config *config.Control) error { return err } } - if err := writeToken(config.Runtime.AgentToken, agentTokenFile, config.Runtime.ServerCA); err != nil { + if err := handlers.WriteToken(config.Runtime.AgentToken, agentTokenFile, config.Runtime.ServerCA); err != nil { return err } } else if serverTokenFile != "" { @@ -490,18 +491,6 @@ func printToken(httpsPort int, advertiseIP, prefix, cmd, varName string) { logrus.Infof("%s %s %s -s https://%s:%d -t ${%s}", prefix, version.Program, cmd, advertiseIP, httpsPort, varName) } -func writeToken(token, file, certs string) error { - if len(token) == 0 { - return nil - } - - token, err := clientaccess.FormatToken(token, certs) - if err != nil { - return err - } - return os.WriteFile(file, []byte(token+"\n"), 0600) -} - func setNoProxyEnv(config *config.Control) error { splitter := func(c rune) bool { return c == ',' From e317a00f7fd02e7c29f7d9c80981e63be24c34f8 Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Sun, 15 Dec 2024 22:58:08 +0000 Subject: [PATCH 2/9] Remove unused Certificate field from Node struct Signed-off-by: Brad Davidson --- pkg/agent/config/config.go | 26 ++--- pkg/agent/flannel/setup_test.go | 2 +- pkg/daemons/config/types.go | 2 - pkg/server/handlers/handlers.go | 185 +++++++++++++++----------------- 4 files changed, 96 insertions(+), 119 deletions(-) diff --git a/pkg/agent/config/config.go b/pkg/agent/config/config.go index 9a4842af65de..9dcc9636ce89 100644 --- a/pkg/agent/config/config.go +++ b/pkg/agent/config/config.go @@ -238,27 +238,23 @@ func upgradeOldNodePasswordPath(oldNodePasswordFile, newNodePasswordFile string) } } -func getServingCert(nodeName string, nodeIPs []net.IP, servingCertFile, servingKeyFile, nodePasswordFile string, info *clientaccess.Info) (*tls.Certificate, error) { - servingCert, err := Request("/v1-"+version.Program+"/serving-kubelet.crt", info, getNodeNamedCrt(nodeName, nodeIPs, nodePasswordFile)) +func getServingCert(nodeName string, nodeIPs []net.IP, servingCertFile, servingKeyFile, nodePasswordFile string, info *clientaccess.Info) error { + body, err := Request("/v1-"+version.Program+"/serving-kubelet.crt", info, getNodeNamedCrt(nodeName, nodeIPs, nodePasswordFile)) if err != nil { - return nil, err + return err } - servingCert, servingKey := splitCertKeyPEM(servingCert) + servingCert, servingKey := splitCertKeyPEM(body) if err := os.WriteFile(servingCertFile, servingCert, 0600); err != nil { - return nil, errors.Wrapf(err, "failed to write node cert") + return errors.Wrapf(err, "failed to write node cert") } if err := os.WriteFile(servingKeyFile, servingKey, 0600); err != nil { - return nil, errors.Wrapf(err, "failed to write node key") + return errors.Wrapf(err, "failed to write node key") } - cert, err := tls.X509KeyPair(servingCert, servingKey) - if err != nil { - return nil, err - } - return &cert, nil + return nil } func getHostFile(filename, keyFile string, info *clientaccess.Info) error { @@ -303,11 +299,11 @@ func splitCertKeyPEM(bytes []byte) (certPem []byte, keyPem []byte) { func getNodeNamedHostFile(filename, keyFile, nodeName string, nodeIPs []net.IP, nodePasswordFile string, info *clientaccess.Info) error { basename := filepath.Base(filename) - fileBytes, err := Request("/v1-"+version.Program+"/"+basename, info, getNodeNamedCrt(nodeName, nodeIPs, nodePasswordFile)) + body, err := Request("/v1-"+version.Program+"/"+basename, info, getNodeNamedCrt(nodeName, nodeIPs, nodePasswordFile)) if err != nil { return err } - fileBytes, keyBytes := splitCertKeyPEM(fileBytes) + fileBytes, keyBytes := splitCertKeyPEM(body) if err := os.WriteFile(filename, fileBytes, 0600); err != nil { return errors.Wrapf(err, "failed to write cert %s", filename) @@ -499,8 +495,7 @@ func get(ctx context.Context, envInfo *cmds.Agent, proxy proxy.Proxy) (*config.N nodeExternalAndInternalIPs := append(nodeIPs, nodeExternalIPs...) // Ask the server to generate a kubelet server cert+key. These files are unique to this node. - servingCert, err := getServingCert(nodeName, nodeExternalAndInternalIPs, servingKubeletCert, servingKubeletKey, newNodePasswordFile, info) - if err != nil { + if err := getServingCert(nodeName, nodeExternalAndInternalIPs, servingKubeletCert, servingKubeletKey, newNodePasswordFile, info); err != nil { return nil, errors.Wrap(err, servingKubeletCert) } @@ -625,7 +620,6 @@ func get(ctx context.Context, envInfo *cmds.Agent, proxy proxy.Proxy) (*config.N applyCRIDockerdAddress(nodeConfig) applyContainerdQoSClassConfigFileIfPresent(envInfo, &nodeConfig.Containerd) nodeConfig.Containerd.Template = filepath.Join(envInfo.DataDir, "agent", "etc", "containerd", "config.toml.tmpl") - nodeConfig.Certificate = servingCert if envInfo.BindAddress != "" { nodeConfig.AgentConfig.ListenAddress = envInfo.BindAddress diff --git a/pkg/agent/flannel/setup_test.go b/pkg/agent/flannel/setup_test.go index 78d7bbf2405b..46473649a904 100644 --- a/pkg/agent/flannel/setup_test.go +++ b/pkg/agent/flannel/setup_test.go @@ -62,7 +62,7 @@ func Test_createFlannelConf(t *testing.T) { var agent = config.Agent{} agent.ClusterCIDR = stringToCIDR(tt.args)[0] agent.ClusterCIDRs = stringToCIDR(tt.args) - var nodeConfig = &config.Node{Docker: false, ContainerRuntimeEndpoint: "", SELinux: false, FlannelBackend: "vxlan", FlannelConfFile: "test_file", FlannelConfOverride: false, FlannelIface: nil, Containerd: containerd, Images: "", AgentConfig: agent, Token: "", Certificate: nil, ServerHTTPSPort: 0} + var nodeConfig = &config.Node{Docker: false, ContainerRuntimeEndpoint: "", SELinux: false, FlannelBackend: "vxlan", FlannelConfFile: "test_file", FlannelConfOverride: false, FlannelIface: nil, Containerd: containerd, Images: "", AgentConfig: agent, Token: "", ServerHTTPSPort: 0} t.Run(tt.name, func(t *testing.T) { if err := createFlannelConf(nodeConfig); (err != nil) != tt.wantErr { diff --git a/pkg/daemons/config/types.go b/pkg/daemons/config/types.go index 5dc84f4a7008..ee0996d19807 100644 --- a/pkg/daemons/config/types.go +++ b/pkg/daemons/config/types.go @@ -1,7 +1,6 @@ package config import ( - "crypto/tls" "fmt" "net" "net/http" @@ -57,7 +56,6 @@ type Node struct { Images string AgentConfig Agent Token string - Certificate *tls.Certificate ServerHTTPSPort int SupervisorPort int DefaultRuntime string diff --git a/pkg/server/handlers/handlers.go b/pkg/server/handlers/handlers.go index f060c4c17b40..101554a22ce1 100644 --- a/pkg/server/handlers/handlers.go +++ b/pkg/server/handlers/handlers.go @@ -5,6 +5,7 @@ import ( "crypto" "crypto/x509" "fmt" + "io" "net" "net/http" "os" @@ -63,13 +64,6 @@ func ServingKubeletCert(control *config.Control, auth nodepassword.NodeAuthValid return } - keyFile := control.Runtime.ServingKubeletKey - caCerts, caKey, key, err := getCACertAndKeys(control.Runtime.ServerCA, control.Runtime.ServerCAKey, keyFile) - if err != nil { - util.SendError(err, resp, req) - return - } - ips := []net.IP{net.ParseIP("127.0.0.1")} program := mux.Vars(req)["program"] if nodeIP := req.Header.Get(program + "-Node-IP"); nodeIP != "" { @@ -83,27 +77,14 @@ func ServingKubeletCert(control *config.Control, auth nodepassword.NodeAuthValid } } - cert, err := certutil.NewSignedCert(certutil.Config{ + signAndSend(resp, req, control.Runtime.ServerCA, control.Runtime.ServerCAKey, control.Runtime.ServingKubeletKey, certutil.Config{ CommonName: nodeName, Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, AltNames: certutil.AltNames{ DNSNames: []string{nodeName, "localhost"}, IPs: ips, }, - }, key, caCerts[0], caKey) - if err != nil { - util.SendError(err, resp, req) - return - } - - keyBytes, err := os.ReadFile(keyFile) - if err != nil { - http.Error(resp, err.Error(), http.StatusInternalServerError) - return - } - - resp.Write(util.EncodeCertsPEM(cert, caCerts)) - resp.Write(keyBytes) + }) }) } @@ -114,92 +95,30 @@ func ClientKubeletCert(control *config.Control, auth nodepassword.NodeAuthValida util.SendError(err, resp, req, errCode) return } - - keyFile := control.Runtime.ClientKubeletKey - caCerts, caKey, key, err := getCACertAndKeys(control.Runtime.ClientCA, control.Runtime.ClientCAKey, keyFile) - if err != nil { - util.SendError(err, resp, req) - return - } - - cert, err := certutil.NewSignedCert(certutil.Config{ + signAndSend(resp, req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeProxyKey, certutil.Config{ CommonName: "system:node:" + nodeName, Organization: []string{user.NodesGroup}, Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - }, key, caCerts[0], caKey) - if err != nil { - util.SendError(err, resp, req) - return - } - - keyBytes, err := os.ReadFile(keyFile) - if err != nil { - http.Error(resp, err.Error(), http.StatusInternalServerError) - return - } - - resp.Write(util.EncodeCertsPEM(cert, caCerts)) - resp.Write(keyBytes) + }) }) } func ClientKubeProxyCert(control *config.Control) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - keyFile := control.Runtime.ClientKubeProxyKey - caCerts, caKey, key, err := getCACertAndKeys(control.Runtime.ClientCA, control.Runtime.ClientCAKey, keyFile) - if err != nil { - util.SendError(err, resp, req) - return - } - - cert, err := certutil.NewSignedCert(certutil.Config{ + signAndSend(resp, req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeProxyKey, certutil.Config{ CommonName: user.KubeProxy, Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - }, key, caCerts[0], caKey) - if err != nil { - util.SendError(err, resp, req) - return - } - - keyBytes, err := os.ReadFile(keyFile) - if err != nil { - http.Error(resp, err.Error(), http.StatusInternalServerError) - return - } - - resp.Write(util.EncodeCertsPEM(cert, caCerts)) - resp.Write(keyBytes) + }) }) } func ClientControllerCert(control *config.Control) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - keyFile := control.Runtime.ClientK3sControllerKey - caCerts, caKey, key, err := getCACertAndKeys(control.Runtime.ClientCA, control.Runtime.ClientCAKey, keyFile) - if err != nil { - util.SendError(err, resp, req) - return - } - - // This user (system:k3s-controller by default) must be bound to a role in rolebindings.yaml or the downstream equivalent program := mux.Vars(req)["program"] - cert, err := certutil.NewSignedCert(certutil.Config{ + signAndSend(resp, req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientK3sControllerKey, certutil.Config{ CommonName: "system:" + program + "-controller", Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - }, key, caCerts[0], caKey) - if err != nil { - util.SendError(err, resp, req) - return - } - - keyBytes, err := os.ReadFile(keyFile) - if err != nil { - http.Error(resp, err.Error(), http.StatusInternalServerError) - return - } - - resp.Write(util.EncodeCertsPEM(cert, caCerts)) - resp.Write(keyBytes) + }) }) } @@ -293,38 +212,104 @@ func Static(urlPrefix, staticDir string) http.Handler { return http.StripPrefix(urlPrefix, http.FileServer(http.Dir(staticDir))) } -func getCACertAndKeys(caCertFile, caKeyFile, signingKeyFile string) ([]*x509.Certificate, crypto.Signer, crypto.Signer, error) { - keyBytes, err := os.ReadFile(signingKeyFile) +// csrSigner wraps a CSR with a Public() method and dummy Sign() method to satisfy the +// crypto.Signer interface required by dynamiclistener's cert helpers. +type csrSigner struct { + csr *x509.CertificateRequest +} + +func (c *csrSigner) Public() crypto.PublicKey { + return c.csr.PublicKey +} + +func (c csrSigner) Sign(_ io.Reader, _ []byte, _ crypto.SignerOpts) ([]byte, error) { + return nil, errors.New("not implemented") +} + +// signAndSend generates a signed certificate using the requested CA cert/key and config, +// and sends it to the client. If the client request is a POST with a signing request as +// the body, the public key from the CSR is used to generate the certificate. If the +// client did not submit a signing request, the legacy shared key is used to generate the +// certificate, and the key is sent along with the certificate. +func signAndSend(resp http.ResponseWriter, req *http.Request, caCertFile, caKeyFile, signingKeyFile string, certConfig certutil.Config) { + caCerts, caKey, err := getCACertAndKey(caCertFile, caKeyFile) if err != nil { - return nil, nil, nil, err + util.SendError(err, resp, req) + return + } + + var key crypto.Signer + var keyBytes []byte + if csr, err := getCSR(req); err == nil { + // If the client sent a valid CSR, use the CSR to retrieve the public key + key = &csrSigner{csr: csr} + } else { + // For legacy clients, just use the common key + keyBytes, err = os.ReadFile(signingKeyFile) + if err != nil { + util.SendError(err, resp, req) + return + } + pk, err := certutil.ParsePrivateKeyPEM(keyBytes) + if err != nil { + util.SendError(err, resp, req) + return + } + key = pk.(crypto.Signer) } - key, err := certutil.ParsePrivateKeyPEM(keyBytes) + // create the signed cert using dynamiclistener cert utils + cert, err := certutil.NewSignedCert(certConfig, key, caCerts[0], caKey) if err != nil { - return nil, nil, nil, err + util.SendError(err, resp, req) + return + } + + // send the cert and CA bundle + resp.Write(util.EncodeCertsPEM(cert, caCerts)) + + // also send the common private key, if the client didn't send a CSR + if len(keyBytes) > 0 { + resp.Write(keyBytes) } +} +// getCACertAndKey loads the CA bundle and key at the specified paths. +func getCACertAndKey(caCertFile, caKeyFile string) ([]*x509.Certificate, crypto.Signer, error) { caKeyBytes, err := os.ReadFile(caKeyFile) if err != nil { - return nil, nil, nil, err + return nil, nil, err } caKey, err := certutil.ParsePrivateKeyPEM(caKeyBytes) if err != nil { - return nil, nil, nil, err + return nil, nil, err } caBytes, err := os.ReadFile(caCertFile) if err != nil { - return nil, nil, nil, err + return nil, nil, err } caCert, err := certutil.ParseCertsPEM(caBytes) if err != nil { - return nil, nil, nil, err + return nil, nil, err } - return caCert, caKey.(crypto.Signer), key.(crypto.Signer), nil + return caCert, caKey.(crypto.Signer), nil +} + +// getCSR decodes a x509.CertificateRequest from a POST request body. +// If the request is not a POST, or cannot be parsed as a request, an error is returned. +func getCSR(req *http.Request) (*x509.CertificateRequest, error) { + if req.Method != http.MethodPost { + return nil, mux.ErrMethodMismatch + } + csrBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + return x509.ParseCertificateRequest(csrBytes) } // addressGetter is a common signature for functions that return an address channel From 666e6ad13c0240e2b90eab1873054134880f1f5a Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Sun, 15 Dec 2024 23:49:47 +0000 Subject: [PATCH 3/9] Add client-side certificate generation support Clients now generate keys client-side and send CSRs. If the server is down-level and sends a cert+key instead of just responding with a cert signed with the client's public key, we use the key from the server instead. Signed-off-by: Brad Davidson --- pkg/agent/config/config.go | 145 ++++++++++++++++++++++++++----------- 1 file changed, 104 insertions(+), 41 deletions(-) diff --git a/pkg/agent/config/config.go b/pkg/agent/config/config.go index 9dcc9636ce89..2e5efd1a4de8 100644 --- a/pkg/agent/config/config.go +++ b/pkg/agent/config/config.go @@ -2,9 +2,11 @@ package config import ( "bufio" + "bytes" "context" cryptorand "crypto/rand" "crypto/tls" + "crypto/x509" "encoding/hex" "encoding/pem" "fmt" @@ -32,6 +34,7 @@ import ( "github.com/k3s-io/k3s/pkg/version" "github.com/k3s-io/k3s/pkg/vpn" "github.com/pkg/errors" + certutil "github.com/rancher/dynamiclistener/cert" "github.com/rancher/wharfie/pkg/registries" "github.com/rancher/wrangler/v3/pkg/slice" "github.com/sirupsen/logrus" @@ -133,9 +136,9 @@ func Request(path string, info *clientaccess.Info, requester HTTPRequester) ([]b return requester(u.String(), clientaccess.GetHTTPClient(info.CACerts, info.CertFile, info.KeyFile), info.Username, info.Password, info.Token()) } -func getNodeNamedCrt(nodeName string, nodeIPs []net.IP, nodePasswordFile string) HTTPRequester { +func getNodeNamedCrt(nodeName string, nodeIPs []net.IP, nodePasswordFile string, csr []byte) HTTPRequester { return func(u string, client *http.Client, username, password, token string) ([]byte, error) { - req, err := http.NewRequest(http.MethodGet, u, nil) + req, err := http.NewRequest(http.MethodPost, u, bytes.NewReader(csr)) if err != nil { return nil, err } @@ -238,47 +241,93 @@ func upgradeOldNodePasswordPath(oldNodePasswordFile, newNodePasswordFile string) } } -func getServingCert(nodeName string, nodeIPs []net.IP, servingCertFile, servingKeyFile, nodePasswordFile string, info *clientaccess.Info) error { - body, err := Request("/v1-"+version.Program+"/serving-kubelet.crt", info, getNodeNamedCrt(nodeName, nodeIPs, nodePasswordFile)) +// getKubeletServingCert fills the kubelet server certificate with content returned +// from the server. We attempt to POST a CSR to the server, in hopes that it will +// sign the cert using our locally generated key. If the server does not support CSR +// signing, the key generated by the server is used instead. +func getKubeletServingCert(nodeName string, nodeIPs []net.IP, certFile, keyFile, nodePasswordFile string, info *clientaccess.Info) error { + csr, err := getCSRBytes(keyFile) if err != nil { - return err + return errors.Wrapf(err, "failed to create certificate request %s", certFile) } - servingCert, servingKey := splitCertKeyPEM(body) - - if err := os.WriteFile(servingCertFile, servingCert, 0600); err != nil { - return errors.Wrapf(err, "failed to write node cert") + basename := filepath.Base(certFile) + body, err := Request("/v1-"+version.Program+"/"+basename, info, getNodeNamedCrt(nodeName, nodeIPs, nodePasswordFile, csr)) + if err != nil { + return err } - if err := os.WriteFile(servingKeyFile, servingKey, 0600); err != nil { - return errors.Wrapf(err, "failed to write node key") + // Always split the response, as down-level servers may send back a cert+key + // instead of signing a new cert with our key. If the response includes a key it + // must be used instead of the one we signed the CSR with. + certBytes, keyBytes := splitCertKeyPEM(body) + if err := os.WriteFile(certFile, certBytes, 0600); err != nil { + return errors.Wrapf(err, "failed to write cert %s", certFile) + } + if len(keyBytes) > 0 { + if err := os.WriteFile(keyFile, keyBytes, 0600); err != nil { + return errors.Wrapf(err, "failed to write key %s", keyFile) + } } - return nil } -func getHostFile(filename, keyFile string, info *clientaccess.Info) error { +// getHostFile fills a file with content returned from the server. +func getHostFile(filename string, info *clientaccess.Info) error { basename := filepath.Base(filename) fileBytes, err := info.Get("/v1-" + version.Program + "/" + basename) if err != nil { return err } - if keyFile == "" { - if err := os.WriteFile(filename, fileBytes, 0600); err != nil { - return errors.Wrapf(err, "failed to write cert %s", filename) - } - } else { - fileBytes, keyBytes := splitCertKeyPEM(fileBytes) - if err := os.WriteFile(filename, fileBytes, 0600); err != nil { - return errors.Wrapf(err, "failed to write cert %s", filename) - } + if err := os.WriteFile(filename, fileBytes, 0600); err != nil { + return errors.Wrapf(err, "failed to write cert %s", filename) + } + return nil +} + +// getClientCert fills a client certificate with content returned from the server. +// We attempt to POST a CSR to the server, in hopes that it will sign the cert using +// our locally generated key. If the server does not support CSR signing, the key +// generated by the server is used instead. +func getClientCert(certFile, keyFile string, info *clientaccess.Info) error { + csr, err := getCSRBytes(keyFile) + if err != nil { + return errors.Wrapf(err, "failed to create certificate request %s", certFile) + } + + basename := filepath.Base(certFile) + fileBytes, err := info.Post("/v1-"+version.Program+"/"+basename, csr) + if err != nil { + return err + } + + // Always split the response, as down-level servers may send back a cert+key + // instead of signing a new cert with our key. If the response includes a key it + // must be used instead of the one we signed the CSR with. + certBytes, keyBytes := splitCertKeyPEM(fileBytes) + if err := os.WriteFile(certFile, certBytes, 0600); err != nil { + return errors.Wrapf(err, "failed to write cert %s", certFile) + } + if len(keyBytes) > 0 { if err := os.WriteFile(keyFile, keyBytes, 0600); err != nil { - return errors.Wrapf(err, "failed to write key %s", filename) + return errors.Wrapf(err, "failed to write key %s", keyFile) } } return nil } +func getCSRBytes(keyFile string) ([]byte, error) { + keyBytes, _, err := certutil.LoadOrGenerateKeyFile(keyFile, false) + if err != nil { + return nil, err + } + key, err := certutil.ParsePrivateKeyPEM(keyBytes) + if err != nil { + return nil, err + } + return x509.CreateCertificateRequest(cryptorand.Reader, &x509.CertificateRequest{}, key) +} + func splitCertKeyPEM(bytes []byte) (certPem []byte, keyPem []byte) { for { b, rest := pem.Decode(bytes) @@ -297,19 +346,33 @@ func splitCertKeyPEM(bytes []byte) (certPem []byte, keyPem []byte) { return } -func getNodeNamedHostFile(filename, keyFile, nodeName string, nodeIPs []net.IP, nodePasswordFile string, info *clientaccess.Info) error { - basename := filepath.Base(filename) - body, err := Request("/v1-"+version.Program+"/"+basename, info, getNodeNamedCrt(nodeName, nodeIPs, nodePasswordFile)) +// getKubeletClientCert fills the kubelet client certificate with content returned +// from the server. We attempt to POST a CSR to the server, in hopes that it will +// sign the cert using our locally generated key. If the server does not support CSR +// signing, the key generated by the server is used instead. +func getKubeletClientCert(certFile, keyFile, nodeName string, nodeIPs []net.IP, nodePasswordFile string, info *clientaccess.Info) error { + csr, err := getCSRBytes(keyFile) + if err != nil { + return errors.Wrapf(err, "failed to create certificate request %s", certFile) + } + + basename := filepath.Base(certFile) + body, err := Request("/v1-"+version.Program+"/"+basename, info, getNodeNamedCrt(nodeName, nodeIPs, nodePasswordFile, csr)) if err != nil { return err } - fileBytes, keyBytes := splitCertKeyPEM(body) - if err := os.WriteFile(filename, fileBytes, 0600); err != nil { - return errors.Wrapf(err, "failed to write cert %s", filename) + // Always split the response, as down-level servers may send back a cert+key + // instead of signing a new cert with our key. If the response includes a key it + // must be used instead of the one we signed the CSR with. + certBytes, keyBytes := splitCertKeyPEM(body) + if err := os.WriteFile(certFile, certBytes, 0600); err != nil { + return errors.Wrapf(err, "failed to write cert %s", certFile) } - if err := os.WriteFile(keyFile, keyBytes, 0600); err != nil { - return errors.Wrapf(err, "failed to write key %s", filename) + if len(keyBytes) > 0 { + if err := os.WriteFile(keyFile, keyBytes, 0600); err != nil { + return errors.Wrapf(err, "failed to write key %s", keyFile) + } } return nil } @@ -395,12 +458,12 @@ func get(ctx context.Context, envInfo *cmds.Agent, proxy proxy.Proxy) (*config.N } clientCAFile := filepath.Join(envInfo.DataDir, "agent", "client-ca.crt") - if err := getHostFile(clientCAFile, "", info); err != nil { + if err := getHostFile(clientCAFile, info); err != nil { return nil, err } serverCAFile := filepath.Join(envInfo.DataDir, "agent", "server-ca.crt") - if err := getHostFile(serverCAFile, "", info); err != nil { + if err := getHostFile(serverCAFile, info); err != nil { return nil, err } @@ -494,13 +557,13 @@ func get(ctx context.Context, envInfo *cmds.Agent, proxy proxy.Proxy) (*config.N // that the cert will not be valid for, as they are not present in the list collected here. nodeExternalAndInternalIPs := append(nodeIPs, nodeExternalIPs...) - // Ask the server to generate a kubelet server cert+key. These files are unique to this node. - if err := getServingCert(nodeName, nodeExternalAndInternalIPs, servingKubeletCert, servingKubeletKey, newNodePasswordFile, info); err != nil { + // Ask the server to sign our kubelet server cert. + if err := getKubeletServingCert(nodeName, nodeExternalAndInternalIPs, servingKubeletCert, servingKubeletKey, newNodePasswordFile, info); err != nil { return nil, errors.Wrap(err, servingKubeletCert) } - // Ask the server to genrate a kubelet client cert+key. These files are unique to this node. - if err := getNodeNamedHostFile(clientKubeletCert, clientKubeletKey, nodeName, nodeIPs, newNodePasswordFile, info); err != nil { + // Ask the server to sign our kubelet client cert. + if err := getKubeletClientCert(clientKubeletCert, clientKubeletKey, nodeName, nodeIPs, newNodePasswordFile, info); err != nil { return nil, errors.Wrap(err, clientKubeletCert) } @@ -513,8 +576,8 @@ func get(ctx context.Context, envInfo *cmds.Agent, proxy proxy.Proxy) (*config.N clientKubeProxyCert := filepath.Join(envInfo.DataDir, "agent", "client-kube-proxy.crt") clientKubeProxyKey := filepath.Join(envInfo.DataDir, "agent", "client-kube-proxy.key") - // Ask the server to send us its kube-proxy client cert+key. These files are not unique to this node. - if err := getHostFile(clientKubeProxyCert, clientKubeProxyKey, info); err != nil { + // Ask the server to sign our kube-proxy client cert. + if err := getClientCert(clientKubeProxyCert, clientKubeProxyKey, info); err != nil { return nil, errors.Wrap(err, clientKubeProxyCert) } @@ -527,8 +590,8 @@ func get(ctx context.Context, envInfo *cmds.Agent, proxy proxy.Proxy) (*config.N clientK3sControllerCert := filepath.Join(envInfo.DataDir, "agent", "client-"+version.Program+"-controller.crt") clientK3sControllerKey := filepath.Join(envInfo.DataDir, "agent", "client-"+version.Program+"-controller.key") - // Ask the server to send us its agent controller client cert+key. These files are not unique to this node. - if err := getHostFile(clientK3sControllerCert, clientK3sControllerKey, info); err != nil { + // Ask the server to sign our agent controller client cert. + if err := getClientCert(clientK3sControllerCert, clientK3sControllerKey, info); err != nil { return nil, errors.Wrap(err, clientK3sControllerCert) } From dfc088f9db0b9fe4037d2a06ceb98029420d5bdf Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Thu, 12 Dec 2024 09:43:39 +0000 Subject: [PATCH 4/9] Handle cluster join as create if we're the only member Signed-off-by: Brad Davidson --- pkg/etcd/etcd.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/pkg/etcd/etcd.go b/pkg/etcd/etcd.go index 0282579bcaef..62ab3f0422f3 100644 --- a/pkg/etcd/etcd.go +++ b/pkg/etcd/etcd.go @@ -559,6 +559,7 @@ func (e *ETCD) join(ctx context.Context, clientAccessInfo *clientaccess.Info) er defer cancel() var ( + state string cluster []string add = true ) @@ -620,12 +621,19 @@ func (e *ETCD) join(ctx context.Context, clientAccessInfo *clientaccess.Info) er return err } cluster = append(cluster, fmt.Sprintf("%s=%s", e.name, e.peerURL())) + state = "existing" + } else if len(cluster) > 1 { + logrus.Infof("Starting etcd to join cluster with members %v", cluster) + state = "existing" + } else { + logrus.Infof("Starting etcd for new cluster") + state = "new" } - logrus.Infof("Starting etcd to join cluster with members %v", cluster) return e.cluster(ctx, false, executor.InitialOptions{ - Cluster: strings.Join(cluster, ","), - State: "existing", + AdvertisePeerURL: e.peerURL(), + Cluster: strings.Join(cluster, ","), + State: state, }) } From dbbd294a458ec130ac80b6555c2354f72a6ee450 Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Wed, 18 Dec 2024 19:47:39 +0000 Subject: [PATCH 5/9] Add test for join existing cluster Signed-off-by: Brad Davidson --- pkg/etcd/etcd_test.go | 68 ++++++++++++++++++++++++++++++++++++++----- tests/unit.go | 7 ++++- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/pkg/etcd/etcd_test.go b/pkg/etcd/etcd_test.go index f875a24ad1b7..52f5d6d4615f 100644 --- a/pkg/etcd/etcd_test.go +++ b/pkg/etcd/etcd_test.go @@ -2,8 +2,10 @@ package etcd import ( "context" + "encoding/json" "net" "net/http" + "net/http/httptest" "os" "path/filepath" "sync" @@ -27,6 +29,7 @@ import ( "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" utilnet "k8s.io/apimachinery/pkg/util/net" + "k8s.io/apimachinery/pkg/util/wait" ) func init() { @@ -234,6 +237,22 @@ func Test_UnitETCD_Register(t *testing.T) { } func Test_UnitETCD_Start(t *testing.T) { + // dummy supervisor API for testing + var memberAddr string + server := httptest.NewServer(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/db/info" { + members := []*etcdserverpb.Member{{ + ClientURLs: []string{"https://" + net.JoinHostPort(memberAddr, "2379")}, + PeerURLs: []string{"https://" + net.JoinHostPort(memberAddr, "2380")}, + }} + resp.Header().Set("Content-Type", "application/json") + json.NewEncoder(resp).Encode(&Members{ + Members: members, + }) + } + })) + defer server.Close() + type contextInfo struct { ctx context.Context cancel context.CancelFunc @@ -265,9 +284,6 @@ func Test_UnitETCD_Start(t *testing.T) { address: mustGetAddress(), name: "default", }, - args: args{ - clientAccessInfo: nil, - }, setup: func(e *ETCD, ctxInfo *contextInfo) error { ctxInfo.ctx, ctxInfo.cancel = context.WithCancel(context.Background()) e.config.EtcdDisableSnapshots = true @@ -294,8 +310,37 @@ func Test_UnitETCD_Start(t *testing.T) { name: "default", cron: cron.New(), }, + setup: func(e *ETCD, ctxInfo *contextInfo) error { + ctxInfo.ctx, ctxInfo.cancel = context.WithCancel(context.Background()) + testutil.GenerateRuntime(e.config) + return nil + }, + teardown: func(e *ETCD, ctxInfo *contextInfo) error { + // RemoveSelf will fail with a specific error, but it still does cleanup for testing purposes + err := e.RemoveSelf(ctxInfo.ctx) + ctxInfo.cancel() + time.Sleep(5 * time.Second) + testutil.CleanupDataDir(e.config) + if err != nil && err.Error() != etcdserver.ErrNotEnoughStartedMembers.Error() { + return err + } + return nil + }, + }, + { + name: "valid clientAccessInfo", + fields: fields{ + config: generateTestConfig(), + address: mustGetAddress(), + name: "default", + cron: cron.New(), + }, args: args{ - clientAccessInfo: nil, + clientAccessInfo: &clientaccess.Info{ + BaseURL: "http://" + server.Listener.Addr().String(), + Username: "server", + Password: "token", + }, }, setup: func(e *ETCD, ctxInfo *contextInfo) error { ctxInfo.ctx, ctxInfo.cancel = context.WithCancel(context.Background()) @@ -322,9 +367,6 @@ func Test_UnitETCD_Start(t *testing.T) { name: "default", cron: cron.New(), }, - args: args{ - clientAccessInfo: nil, - }, setup: func(e *ETCD, ctxInfo *contextInfo) error { ctxInfo.ctx, ctxInfo.cancel = context.WithCancel(context.Background()) if err := testutil.GenerateRuntime(e.config); err != nil { @@ -364,6 +406,18 @@ func Test_UnitETCD_Start(t *testing.T) { if err := e.Start(tt.fields.context.ctx, tt.args.clientAccessInfo); (err != nil) != tt.wantErr { t.Errorf("ETCD.Start() error = %v, wantErr %v", err, tt.wantErr) } + if !tt.wantErr { + memberAddr = e.address + if err := wait.PollUntilContextTimeout(tt.fields.context.ctx, time.Second, time.Minute, true, func(ctx context.Context) (bool, error) { + if _, err := e.getETCDStatus(tt.fields.context.ctx, ""); err != nil { + t.Logf("Waiting to get etcd status: %v", err) + return false, nil + } + return true, nil + }); err != nil { + t.Errorf("Failed to get etcd status: %v", err) + } + } if err := tt.teardown(e, &tt.fields.context); err != nil { t.Errorf("Teardown for ETCD.Start() failed = %v", err) } diff --git a/tests/unit.go b/tests/unit.go index 5bc40cea3b55..61aecba6497a 100644 --- a/tests/unit.go +++ b/tests/unit.go @@ -43,7 +43,12 @@ func CleanupDataDir(cnf *config.Control) { // GenerateRuntime creates a temporary data dir and configures // config.ControlRuntime with all the appropriate certificate keys. func GenerateRuntime(cnf *config.Control) error { - cnf.Runtime = config.NewRuntime(nil) + // reuse ready channel from existing runtime if set + var readyCh <-chan struct{} + if cnf.Runtime != nil { + readyCh = cnf.Runtime.ContainerRuntimeReady + } + cnf.Runtime = config.NewRuntime(readyCh) if err := GenerateDataDir(cnf); err != nil { return err } From 2fdf244c17736075079b7b0e6c3011d3ab7d5214 Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Tue, 17 Dec 2024 22:39:06 +0000 Subject: [PATCH 6/9] Move core/v1 mock into tests package for reuse Signed-off-by: Brad Davidson --- pkg/etcd/s3/s3_test.go | 159 +++++++---------------------------------- tests/mock/core.go | 113 +++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 132 deletions(-) create mode 100644 tests/mock/core.go diff --git a/pkg/etcd/s3/s3_test.go b/pkg/etcd/s3/s3_test.go index 9dc8f9e41fc4..aea2e3e077c3 100644 --- a/pkg/etcd/s3/s3_test.go +++ b/pkg/etcd/s3/s3_test.go @@ -17,16 +17,13 @@ import ( "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/etcd/snapshot" + "github.com/k3s-io/k3s/tests/mock" "github.com/rancher/dynamiclistener/cert" "github.com/rancher/wrangler/v3/pkg/generated/controllers/core" - corev1 "github.com/rancher/wrangler/v3/pkg/generated/controllers/core/v1" - "github.com/rancher/wrangler/v3/pkg/generic/fake" "github.com/sirupsen/logrus" "go.uber.org/mock/gomock" v1 "k8s.io/api/core/v1" - apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/utils/lru" ) @@ -95,7 +92,7 @@ func Test_UnitControllerGetClient(t *testing.T) { }, wantErr: true, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) + coreMock := mock.NewCore(gomock.NewController(t)) return coreMock, nil }, }, @@ -119,7 +116,7 @@ func Test_UnitControllerGetClient(t *testing.T) { }, wantErr: true, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) + coreMock := mock.NewCore(gomock.NewController(t)) return coreMock, nil }, }, @@ -144,7 +141,7 @@ func Test_UnitControllerGetClient(t *testing.T) { }, wantErr: true, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) + coreMock := mock.NewCore(gomock.NewController(t)) return coreMock, nil }, }, @@ -167,9 +164,9 @@ func Test_UnitControllerGetClient(t *testing.T) { }, wantErr: true, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) - coreMock.v1.secret.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-config-secret", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.Secret, error) { - return nil, errorNotFound("secret", name) + coreMock := mock.NewCore(gomock.NewController(t)) + coreMock.V1Mock.SecretMock.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-config-secret", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.Secret, error) { + return nil, mock.ErrorNotFound("secret", name) }) return coreMock, nil }, @@ -192,8 +189,8 @@ func Test_UnitControllerGetClient(t *testing.T) { clientCache: lru.New(5), }, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) - coreMock.v1.secret.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-config-secret", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.Secret, error) { + coreMock := mock.NewCore(gomock.NewController(t)) + coreMock.V1Mock.SecretMock.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-config-secret", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.Secret, error) { return &v1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: name, @@ -231,8 +228,8 @@ func Test_UnitControllerGetClient(t *testing.T) { clientCache: lru.New(5), }, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) - coreMock.v1.secret.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-config-secret", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.Secret, error) { + coreMock := mock.NewCore(gomock.NewController(t)) + coreMock.V1Mock.SecretMock.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-config-secret", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.Secret, error) { return &v1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: name, @@ -250,7 +247,7 @@ func Test_UnitControllerGetClient(t *testing.T) { }, }, nil }) - coreMock.v1.configMap.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-ca", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.ConfigMap, error) { + coreMock.V1Mock.ConfigMapMock.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-ca", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.ConfigMap, error) { return &v1.ConfigMap{ ObjectMeta: metav1.ObjectMeta{ Name: name, @@ -286,8 +283,8 @@ func Test_UnitControllerGetClient(t *testing.T) { }, wantErr: true, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) - coreMock.v1.secret.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-config-secret", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.Secret, error) { + coreMock := mock.NewCore(gomock.NewController(t)) + coreMock.V1Mock.SecretMock.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-config-secret", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.Secret, error) { return &v1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: name, @@ -307,8 +304,8 @@ func Test_UnitControllerGetClient(t *testing.T) { }, }, nil }) - coreMock.v1.configMap.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-ca", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.ConfigMap, error) { - return nil, errorNotFound("configmap", name) + coreMock.V1Mock.ConfigMapMock.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-ca", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.ConfigMap, error) { + return nil, mock.ErrorNotFound("configmap", name) }) return coreMock, nil }, @@ -334,7 +331,7 @@ func Test_UnitControllerGetClient(t *testing.T) { clientCache: lru.New(5), }, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) + coreMock := mock.NewCore(gomock.NewController(t)) return coreMock, nil }, }, @@ -359,7 +356,7 @@ func Test_UnitControllerGetClient(t *testing.T) { clientCache: lru.New(5), }, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) + coreMock := mock.NewCore(gomock.NewController(t)) return coreMock, nil }, }, @@ -383,7 +380,7 @@ func Test_UnitControllerGetClient(t *testing.T) { clientCache: lru.New(5), }, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) + coreMock := mock.NewCore(gomock.NewController(t)) return coreMock, nil }, }, @@ -415,8 +412,8 @@ func Test_UnitControllerGetClient(t *testing.T) { Timeout: *defaultEtcdS3.Timeout.DeepCopy(), } f.clientCache.Add(*c.etcdS3, c) - coreMock := newCoreMock(gomock.NewController(t)) - coreMock.v1.secret.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-config-secret", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.Secret, error) { + coreMock := mock.NewCore(gomock.NewController(t)) + coreMock.V1Mock.SecretMock.EXPECT().Get(metav1.NamespaceSystem, "my-etcd-s3-config-secret", gomock.Any()).AnyTimes().DoAndReturn(func(namespace, name string, _ metav1.GetOptions) (*v1.Secret, error) { return &v1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: name, @@ -458,7 +455,7 @@ func Test_UnitControllerGetClient(t *testing.T) { setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { c.etcdS3 = a.etcdS3 f.clientCache.Add(*c.etcdS3, c) - coreMock := newCoreMock(gomock.NewController(t)) + coreMock := mock.NewCore(gomock.NewController(t)) return coreMock, nil }, }, @@ -484,7 +481,7 @@ func Test_UnitControllerGetClient(t *testing.T) { clientCache: lru.New(5), }, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) + coreMock := mock.NewCore(gomock.NewController(t)) return coreMock, nil }, }, @@ -511,7 +508,7 @@ func Test_UnitControllerGetClient(t *testing.T) { }, wantErr: true, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) + coreMock := mock.NewCore(gomock.NewController(t)) return coreMock, nil }, }, @@ -538,7 +535,7 @@ func Test_UnitControllerGetClient(t *testing.T) { }, wantErr: true, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) + coreMock := mock.NewCore(gomock.NewController(t)) return coreMock, nil }, }, @@ -563,7 +560,7 @@ func Test_UnitControllerGetClient(t *testing.T) { clientCache: lru.New(5), }, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) + coreMock := mock.NewCore(gomock.NewController(t)) return coreMock, nil }, }, @@ -589,7 +586,7 @@ func Test_UnitControllerGetClient(t *testing.T) { }, wantErr: true, setup: func(t *testing.T, a args, f fields, c *Client) (core.Interface, error) { - coreMock := newCoreMock(gomock.NewController(t)) + coreMock := mock.NewCore(gomock.NewController(t)) return coreMock, nil }, }, @@ -1447,108 +1444,6 @@ func Test_UnitClientSnapshotRetention(t *testing.T) { } } -// -// Mocks so that we can call Runtime.Core.Core().V1() without a functioning apiserver -// - -// explicit interface check for core mock -var _ core.Interface = &coreMock{} - -type coreMock struct { - v1 *v1Mock -} - -func newCoreMock(c *gomock.Controller) *coreMock { - return &coreMock{ - v1: newV1Mock(c), - } -} - -func (m *coreMock) V1() corev1.Interface { - return m.v1 -} - -// explicit interface check for core v1 mock -var _ corev1.Interface = &v1Mock{} - -type v1Mock struct { - configMap *fake.MockControllerInterface[*v1.ConfigMap, *v1.ConfigMapList] - endpoints *fake.MockControllerInterface[*v1.Endpoints, *v1.EndpointsList] - event *fake.MockControllerInterface[*v1.Event, *v1.EventList] - namespace *fake.MockNonNamespacedControllerInterface[*v1.Namespace, *v1.NamespaceList] - node *fake.MockNonNamespacedControllerInterface[*v1.Node, *v1.NodeList] - persistentVolume *fake.MockNonNamespacedControllerInterface[*v1.PersistentVolume, *v1.PersistentVolumeList] - persistentVolumeClaim *fake.MockControllerInterface[*v1.PersistentVolumeClaim, *v1.PersistentVolumeClaimList] - pod *fake.MockControllerInterface[*v1.Pod, *v1.PodList] - secret *fake.MockControllerInterface[*v1.Secret, *v1.SecretList] - service *fake.MockControllerInterface[*v1.Service, *v1.ServiceList] - serviceAccount *fake.MockControllerInterface[*v1.ServiceAccount, *v1.ServiceAccountList] -} - -func newV1Mock(c *gomock.Controller) *v1Mock { - return &v1Mock{ - configMap: fake.NewMockControllerInterface[*v1.ConfigMap, *v1.ConfigMapList](c), - endpoints: fake.NewMockControllerInterface[*v1.Endpoints, *v1.EndpointsList](c), - event: fake.NewMockControllerInterface[*v1.Event, *v1.EventList](c), - namespace: fake.NewMockNonNamespacedControllerInterface[*v1.Namespace, *v1.NamespaceList](c), - node: fake.NewMockNonNamespacedControllerInterface[*v1.Node, *v1.NodeList](c), - persistentVolume: fake.NewMockNonNamespacedControllerInterface[*v1.PersistentVolume, *v1.PersistentVolumeList](c), - persistentVolumeClaim: fake.NewMockControllerInterface[*v1.PersistentVolumeClaim, *v1.PersistentVolumeClaimList](c), - pod: fake.NewMockControllerInterface[*v1.Pod, *v1.PodList](c), - secret: fake.NewMockControllerInterface[*v1.Secret, *v1.SecretList](c), - service: fake.NewMockControllerInterface[*v1.Service, *v1.ServiceList](c), - serviceAccount: fake.NewMockControllerInterface[*v1.ServiceAccount, *v1.ServiceAccountList](c), - } -} - -func (m *v1Mock) ConfigMap() corev1.ConfigMapController { - return m.configMap -} - -func (m *v1Mock) Endpoints() corev1.EndpointsController { - return m.endpoints -} - -func (m *v1Mock) Event() corev1.EventController { - return m.event -} - -func (m *v1Mock) Namespace() corev1.NamespaceController { - return m.namespace -} - -func (m *v1Mock) Node() corev1.NodeController { - return m.node -} - -func (m *v1Mock) PersistentVolume() corev1.PersistentVolumeController { - return m.persistentVolume -} - -func (m *v1Mock) PersistentVolumeClaim() corev1.PersistentVolumeClaimController { - return m.persistentVolumeClaim -} - -func (m *v1Mock) Pod() corev1.PodController { - return m.pod -} - -func (m *v1Mock) Secret() corev1.SecretController { - return m.secret -} - -func (m *v1Mock) Service() corev1.ServiceController { - return m.service -} - -func (m *v1Mock) ServiceAccount() corev1.ServiceAccountController { - return m.serviceAccount -} - -func errorNotFound(gv, name string) error { - return apierrors.NewNotFound(schema.ParseGroupResource(gv), name) -} - // // ListObjects response body template // diff --git a/tests/mock/core.go b/tests/mock/core.go new file mode 100644 index 000000000000..432a12090162 --- /dev/null +++ b/tests/mock/core.go @@ -0,0 +1,113 @@ +package mock + +import ( + "github.com/rancher/wrangler/v3/pkg/generated/controllers/core" + corev1 "github.com/rancher/wrangler/v3/pkg/generated/controllers/core/v1" + "github.com/rancher/wrangler/v3/pkg/generic/fake" + "go.uber.org/mock/gomock" + v1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +// +// Mocks so that we can call Runtime.Core.Core().V1() without a functioning apiserver +// + +// explicit interface check for core mock +var _ core.Interface = &CoreMock{} + +type CoreMock struct { + V1Mock *V1Mock +} + +func NewCore(c *gomock.Controller) *CoreMock { + return &CoreMock{ + V1Mock: NewV1(c), + } +} + +func (m *CoreMock) V1() corev1.Interface { + return m.V1Mock +} + +// explicit interface check for core v1 mock +var _ corev1.Interface = &V1Mock{} + +type V1Mock struct { + ConfigMapMock *fake.MockControllerInterface[*v1.ConfigMap, *v1.ConfigMapList] + EndpointsMock *fake.MockControllerInterface[*v1.Endpoints, *v1.EndpointsList] + EventMock *fake.MockControllerInterface[*v1.Event, *v1.EventList] + NamespaceMock *fake.MockNonNamespacedControllerInterface[*v1.Namespace, *v1.NamespaceList] + NodeMock *fake.MockNonNamespacedControllerInterface[*v1.Node, *v1.NodeList] + PersistentVolumeMock *fake.MockNonNamespacedControllerInterface[*v1.PersistentVolume, *v1.PersistentVolumeList] + PersistentVolumeClaimMock *fake.MockControllerInterface[*v1.PersistentVolumeClaim, *v1.PersistentVolumeClaimList] + PodMock *fake.MockControllerInterface[*v1.Pod, *v1.PodList] + SecretMock *fake.MockControllerInterface[*v1.Secret, *v1.SecretList] + ServiceMock *fake.MockControllerInterface[*v1.Service, *v1.ServiceList] + ServiceAccountMock *fake.MockControllerInterface[*v1.ServiceAccount, *v1.ServiceAccountList] +} + +func NewV1(c *gomock.Controller) *V1Mock { + return &V1Mock{ + ConfigMapMock: fake.NewMockControllerInterface[*v1.ConfigMap, *v1.ConfigMapList](c), + EndpointsMock: fake.NewMockControllerInterface[*v1.Endpoints, *v1.EndpointsList](c), + EventMock: fake.NewMockControllerInterface[*v1.Event, *v1.EventList](c), + NamespaceMock: fake.NewMockNonNamespacedControllerInterface[*v1.Namespace, *v1.NamespaceList](c), + NodeMock: fake.NewMockNonNamespacedControllerInterface[*v1.Node, *v1.NodeList](c), + PersistentVolumeMock: fake.NewMockNonNamespacedControllerInterface[*v1.PersistentVolume, *v1.PersistentVolumeList](c), + PersistentVolumeClaimMock: fake.NewMockControllerInterface[*v1.PersistentVolumeClaim, *v1.PersistentVolumeClaimList](c), + PodMock: fake.NewMockControllerInterface[*v1.Pod, *v1.PodList](c), + SecretMock: fake.NewMockControllerInterface[*v1.Secret, *v1.SecretList](c), + ServiceMock: fake.NewMockControllerInterface[*v1.Service, *v1.ServiceList](c), + ServiceAccountMock: fake.NewMockControllerInterface[*v1.ServiceAccount, *v1.ServiceAccountList](c), + } +} + +func (m *V1Mock) ConfigMap() corev1.ConfigMapController { + return m.ConfigMapMock +} + +func (m *V1Mock) Endpoints() corev1.EndpointsController { + return m.EndpointsMock +} + +func (m *V1Mock) Event() corev1.EventController { + return m.EventMock +} + +func (m *V1Mock) Namespace() corev1.NamespaceController { + return m.NamespaceMock +} + +func (m *V1Mock) Node() corev1.NodeController { + return m.NodeMock +} + +func (m *V1Mock) PersistentVolume() corev1.PersistentVolumeController { + return m.PersistentVolumeMock +} + +func (m *V1Mock) PersistentVolumeClaim() corev1.PersistentVolumeClaimController { + return m.PersistentVolumeClaimMock +} + +func (m *V1Mock) Pod() corev1.PodController { + return m.PodMock +} + +func (m *V1Mock) Secret() corev1.SecretController { + return m.SecretMock +} + +func (m *V1Mock) Service() corev1.ServiceController { + return m.ServiceMock +} + +func (m *V1Mock) ServiceAccount() corev1.ServiceAccountController { + return m.ServiceAccountMock +} + +func ErrorNotFound(gv, name string) error { + return apierrors.NewNotFound(schema.ParseGroupResource(gv), name) +} From 84f9f4c7e3b874bffbb66996ec33bcbacab14edd Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Wed, 18 Dec 2024 00:41:23 +0000 Subject: [PATCH 7/9] Move additional core/v1 mocks into tests package Convert nodepassword tests to use shared mocks Signed-off-by: Brad Davidson --- pkg/nodepassword/nodepassword_test.go | 115 ++++++-------------- tests/mock/core.go | 151 ++++++++++++++++++++++---- 2 files changed, 160 insertions(+), 106 deletions(-) diff --git a/pkg/nodepassword/nodepassword_test.go b/pkg/nodepassword/nodepassword_test.go index af4bbd71a638..c7b16a5a0ed0 100644 --- a/pkg/nodepassword/nodepassword_test.go +++ b/pkg/nodepassword/nodepassword_test.go @@ -8,7 +8,7 @@ import ( "runtime" "testing" - "github.com/rancher/wrangler/v3/pkg/generic/fake" + "github.com/k3s-io/k3s/tests/mock" "go.uber.org/mock/gomock" v1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -28,10 +28,11 @@ func Test_UnitAsserts(t *testing.T) { func Test_UnitEnsureDelete(t *testing.T) { logMemUsage(t) - ctrl := gomock.NewController(t) - secretClient := fake.NewMockControllerInterface[*v1.Secret, *v1.SecretList](ctrl) - secretCache := fake.NewMockCacheInterface[*v1.Secret](ctrl) - secretStore := &mockSecretStore{} + v1Mock := mock.NewV1(gomock.NewController(t)) + + secretClient := v1Mock.SecretMock + secretCache := v1Mock.SecretCache + secretStore := &mock.SecretStore{} // Set up expected call counts for tests // Expect to see 2 creates, any number of cache gets, and 2 deletes. @@ -59,15 +60,15 @@ func Test_UnitMigrateFile(t *testing.T) { nodePasswordFile := generateNodePasswordFile(migrateNumNodes) defer os.Remove(nodePasswordFile) - ctrl := gomock.NewController(t) + v1Mock := mock.NewV1(gomock.NewController(t)) - secretClient := fake.NewMockControllerInterface[*v1.Secret, *v1.SecretList](ctrl) - secretCache := fake.NewMockCacheInterface[*v1.Secret](ctrl) - secretStore := &mockSecretStore{} + secretClient := v1Mock.SecretMock + secretCache := v1Mock.SecretCache + secretStore := &mock.SecretStore{} - nodeClient := fake.NewMockNonNamespacedControllerInterface[*v1.Node, *v1.NodeList](ctrl) - nodeCache := fake.NewMockNonNamespacedCacheInterface[*v1.Node](ctrl) - nodeStore := &mockNodeStore{} + nodeClient := v1Mock.NodeMock + nodeCache := v1Mock.NodeCache + nodeStore := &mock.NodeStore{} // Set up expected call counts for tests // Expect to see 1 node list, any number of cache gets, and however many @@ -93,19 +94,20 @@ func Test_UnitMigrateFileNodes(t *testing.T) { nodePasswordFile := generateNodePasswordFile(migrateNumNodes) defer os.Remove(nodePasswordFile) - ctrl := gomock.NewController(t) + v1Mock := mock.NewV1(gomock.NewController(t)) - secretClient := fake.NewMockControllerInterface[*v1.Secret, *v1.SecretList](ctrl) - secretCache := fake.NewMockCacheInterface[*v1.Secret](ctrl) - secretStore := &mockSecretStore{} + secretClient := v1Mock.SecretMock + secretCache := v1Mock.SecretCache + secretStore := &mock.SecretStore{} - nodeClient := fake.NewMockNonNamespacedControllerInterface[*v1.Node, *v1.NodeList](ctrl) - nodeCache := fake.NewMockNonNamespacedCacheInterface[*v1.Node](ctrl) - nodeStore := &mockNodeStore{} + nodeClient := v1Mock.NodeMock + nodeCache := v1Mock.NodeCache + nodeStore := &mock.NodeStore{} - nodeStore.nodes = make([]v1.Node, createNumNodes, createNumNodes) - for i := range nodeStore.nodes { - nodeStore.nodes[i].Name = fmt.Sprintf("node%d", i+1) + for i := 0; i < createNumNodes; i++ { + if _, err := nodeStore.Create(&v1.Node{ObjectMeta: metav1.ObjectMeta{Name: fmt.Sprintf("node%d", i+1)}}); err != nil { + t.Fatal(err) + } } // Set up expected call counts for tests @@ -124,9 +126,13 @@ func Test_UnitMigrateFileNodes(t *testing.T) { } logMemUsage(t) - for _, node := range nodeStore.nodes { - assertNotEqual(t, Ensure(secretClient, node.Name, "wrong-password"), nil) - assertEqual(t, Ensure(secretClient, node.Name, node.Name), nil) + if nodes, err := nodeStore.List(labels.Everything()); err != nil { + t.Fatal(err) + } else { + for _, node := range nodes { + assertNotEqual(t, Ensure(secretClient, node.Name, "wrong-password"), nil) + assertEqual(t, Ensure(secretClient, node.Name, node.Name), nil) + } } newNode := fmt.Sprintf("node%d", migrateNumNodes+1) @@ -141,65 +147,6 @@ func Test_PasswordError(t *testing.T) { assertNotEqual(t, errors.Unwrap(err), nil) } -// -------------------------- -// mock secret store interface - -type mockSecretStore struct { - entries map[string]map[string]v1.Secret -} - -func (m *mockSecretStore) Create(secret *v1.Secret) (*v1.Secret, error) { - if m.entries == nil { - m.entries = map[string]map[string]v1.Secret{} - } - if _, ok := m.entries[secret.Namespace]; !ok { - m.entries[secret.Namespace] = map[string]v1.Secret{} - } - if _, ok := m.entries[secret.Namespace][secret.Name]; ok { - return nil, errorAlreadyExists() - } - m.entries[secret.Namespace][secret.Name] = *secret - return secret, nil -} - -func (m *mockSecretStore) Delete(namespace, name string, options *metav1.DeleteOptions) error { - if m.entries == nil { - return errorNotFound() - } - if _, ok := m.entries[namespace]; !ok { - return errorNotFound() - } - if _, ok := m.entries[namespace][name]; !ok { - return errorNotFound() - } - delete(m.entries[namespace], name) - return nil -} - -func (m *mockSecretStore) Get(namespace, name string) (*v1.Secret, error) { - if m.entries == nil { - return nil, errorNotFound() - } - if _, ok := m.entries[namespace]; !ok { - return nil, errorNotFound() - } - if secret, ok := m.entries[namespace][name]; ok { - return &secret, nil - } - return nil, errorNotFound() -} - -// -------------------------- -// mock node store interface - -type mockNodeStore struct { - nodes []v1.Node -} - -func (m *mockNodeStore) List(ls labels.Selector) ([]v1.Node, error) { - return m.nodes, nil -} - // -------------------------- // utility functions diff --git a/tests/mock/core.go b/tests/mock/core.go index 432a12090162..37468c714bc9 100644 --- a/tests/mock/core.go +++ b/tests/mock/core.go @@ -7,6 +7,8 @@ import ( "go.uber.org/mock/gomock" v1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime/schema" ) @@ -35,32 +37,54 @@ func (m *CoreMock) V1() corev1.Interface { var _ corev1.Interface = &V1Mock{} type V1Mock struct { - ConfigMapMock *fake.MockControllerInterface[*v1.ConfigMap, *v1.ConfigMapList] - EndpointsMock *fake.MockControllerInterface[*v1.Endpoints, *v1.EndpointsList] - EventMock *fake.MockControllerInterface[*v1.Event, *v1.EventList] - NamespaceMock *fake.MockNonNamespacedControllerInterface[*v1.Namespace, *v1.NamespaceList] - NodeMock *fake.MockNonNamespacedControllerInterface[*v1.Node, *v1.NodeList] - PersistentVolumeMock *fake.MockNonNamespacedControllerInterface[*v1.PersistentVolume, *v1.PersistentVolumeList] - PersistentVolumeClaimMock *fake.MockControllerInterface[*v1.PersistentVolumeClaim, *v1.PersistentVolumeClaimList] - PodMock *fake.MockControllerInterface[*v1.Pod, *v1.PodList] - SecretMock *fake.MockControllerInterface[*v1.Secret, *v1.SecretList] - ServiceMock *fake.MockControllerInterface[*v1.Service, *v1.ServiceList] - ServiceAccountMock *fake.MockControllerInterface[*v1.ServiceAccount, *v1.ServiceAccountList] + ConfigMapMock *fake.MockControllerInterface[*v1.ConfigMap, *v1.ConfigMapList] + ConfigMapCache *fake.MockCacheInterface[*v1.ConfigMap] + EndpointsMock *fake.MockControllerInterface[*v1.Endpoints, *v1.EndpointsList] + EndpointsCache *fake.MockCacheInterface[*v1.Endpoints] + EventMock *fake.MockControllerInterface[*v1.Event, *v1.EventList] + EventCache *fake.MockCacheInterface[*v1.Event] + NamespaceMock *fake.MockNonNamespacedControllerInterface[*v1.Namespace, *v1.NamespaceList] + NamespaceCache *fake.MockNonNamespacedCacheInterface[*v1.Namespace] + NodeMock *fake.MockNonNamespacedControllerInterface[*v1.Node, *v1.NodeList] + NodeCache *fake.MockNonNamespacedCacheInterface[*v1.Node] + PersistentVolumeMock *fake.MockNonNamespacedControllerInterface[*v1.PersistentVolume, *v1.PersistentVolumeList] + PersistentVolumeCache *fake.MockNonNamespacedCacheInterface[*v1.PersistentVolume] + PersistentVolumeClaimMock *fake.MockControllerInterface[*v1.PersistentVolumeClaim, *v1.PersistentVolumeClaimList] + PersistentVolumeClaimCache *fake.MockCacheInterface[*v1.PersistentVolumeClaim] + PodMock *fake.MockControllerInterface[*v1.Pod, *v1.PodList] + PodCache *fake.MockCacheInterface[*v1.Pod] + SecretMock *fake.MockControllerInterface[*v1.Secret, *v1.SecretList] + SecretCache *fake.MockCacheInterface[*v1.Secret] + ServiceMock *fake.MockControllerInterface[*v1.Service, *v1.ServiceList] + ServiceCache *fake.MockCacheInterface[*v1.Service] + ServiceAccountMock *fake.MockControllerInterface[*v1.ServiceAccount, *v1.ServiceAccountList] + ServiceAccountCache *fake.MockCacheInterface[*v1.ServiceAccount] } func NewV1(c *gomock.Controller) *V1Mock { return &V1Mock{ - ConfigMapMock: fake.NewMockControllerInterface[*v1.ConfigMap, *v1.ConfigMapList](c), - EndpointsMock: fake.NewMockControllerInterface[*v1.Endpoints, *v1.EndpointsList](c), - EventMock: fake.NewMockControllerInterface[*v1.Event, *v1.EventList](c), - NamespaceMock: fake.NewMockNonNamespacedControllerInterface[*v1.Namespace, *v1.NamespaceList](c), - NodeMock: fake.NewMockNonNamespacedControllerInterface[*v1.Node, *v1.NodeList](c), - PersistentVolumeMock: fake.NewMockNonNamespacedControllerInterface[*v1.PersistentVolume, *v1.PersistentVolumeList](c), - PersistentVolumeClaimMock: fake.NewMockControllerInterface[*v1.PersistentVolumeClaim, *v1.PersistentVolumeClaimList](c), - PodMock: fake.NewMockControllerInterface[*v1.Pod, *v1.PodList](c), - SecretMock: fake.NewMockControllerInterface[*v1.Secret, *v1.SecretList](c), - ServiceMock: fake.NewMockControllerInterface[*v1.Service, *v1.ServiceList](c), - ServiceAccountMock: fake.NewMockControllerInterface[*v1.ServiceAccount, *v1.ServiceAccountList](c), + ConfigMapMock: fake.NewMockControllerInterface[*v1.ConfigMap, *v1.ConfigMapList](c), + ConfigMapCache: fake.NewMockCacheInterface[*v1.ConfigMap](c), + EndpointsMock: fake.NewMockControllerInterface[*v1.Endpoints, *v1.EndpointsList](c), + EndpointsCache: fake.NewMockCacheInterface[*v1.Endpoints](c), + EventMock: fake.NewMockControllerInterface[*v1.Event, *v1.EventList](c), + EventCache: fake.NewMockCacheInterface[*v1.Event](c), + NamespaceMock: fake.NewMockNonNamespacedControllerInterface[*v1.Namespace, *v1.NamespaceList](c), + NamespaceCache: fake.NewMockNonNamespacedCacheInterface[*v1.Namespace](c), + NodeMock: fake.NewMockNonNamespacedControllerInterface[*v1.Node, *v1.NodeList](c), + NodeCache: fake.NewMockNonNamespacedCacheInterface[*v1.Node](c), + PersistentVolumeMock: fake.NewMockNonNamespacedControllerInterface[*v1.PersistentVolume, *v1.PersistentVolumeList](c), + PersistentVolumeCache: fake.NewMockNonNamespacedCacheInterface[*v1.PersistentVolume](c), + PersistentVolumeClaimMock: fake.NewMockControllerInterface[*v1.PersistentVolumeClaim, *v1.PersistentVolumeClaimList](c), + PersistentVolumeClaimCache: fake.NewMockCacheInterface[*v1.PersistentVolumeClaim](c), + PodMock: fake.NewMockControllerInterface[*v1.Pod, *v1.PodList](c), + PodCache: fake.NewMockCacheInterface[*v1.Pod](c), + SecretMock: fake.NewMockControllerInterface[*v1.Secret, *v1.SecretList](c), + SecretCache: fake.NewMockCacheInterface[*v1.Secret](c), + ServiceMock: fake.NewMockControllerInterface[*v1.Service, *v1.ServiceList](c), + ServiceCache: fake.NewMockCacheInterface[*v1.Service](c), + ServiceAccountMock: fake.NewMockControllerInterface[*v1.ServiceAccount, *v1.ServiceAccountList](c), + ServiceAccountCache: fake.NewMockCacheInterface[*v1.ServiceAccount](c), } } @@ -108,6 +132,89 @@ func (m *V1Mock) ServiceAccount() corev1.ServiceAccountController { return m.ServiceAccountMock } +// mock secret store interface + +type SecretStore struct { + secrets map[string]map[string]v1.Secret +} + +func (m *SecretStore) Create(secret *v1.Secret) (*v1.Secret, error) { + if m.secrets == nil { + m.secrets = map[string]map[string]v1.Secret{} + } + if _, ok := m.secrets[secret.Namespace]; !ok { + m.secrets[secret.Namespace] = map[string]v1.Secret{} + } + if _, ok := m.secrets[secret.Namespace][secret.Name]; ok { + return nil, ErrorAlreadyExists("secret", secret.Name) + } + m.secrets[secret.Namespace][secret.Name] = *secret + return secret, nil +} + +func (m *SecretStore) Delete(namespace, name string, options *metav1.DeleteOptions) error { + if m.secrets == nil { + return ErrorNotFound("secret", name) + } + if _, ok := m.secrets[namespace]; !ok { + return ErrorNotFound("secret", name) + } + if _, ok := m.secrets[namespace][name]; !ok { + return ErrorNotFound("secret", name) + } + delete(m.secrets[namespace], name) + return nil +} + +func (m *SecretStore) Get(namespace, name string) (*v1.Secret, error) { + if m.secrets == nil { + return nil, ErrorNotFound("secret", name) + } + if _, ok := m.secrets[namespace]; !ok { + return nil, ErrorNotFound("secret", name) + } + if secret, ok := m.secrets[namespace][name]; ok { + return &secret, nil + } + return nil, ErrorNotFound("secret", name) +} + +// mock node store interface + +type NodeStore struct { + nodes map[string]v1.Node +} + +func (m *NodeStore) Create(node *v1.Node) (*v1.Node, error) { + if m.nodes == nil { + m.nodes = map[string]v1.Node{} + } + if _, ok := m.nodes[node.Name]; ok { + return nil, ErrorAlreadyExists("node", node.Name) + } + m.nodes[node.Name] = *node + return node, nil +} + +func (m *NodeStore) List(ls labels.Selector) ([]v1.Node, error) { + nodes := []v1.Node{} + if ls == nil { + ls = labels.Everything() + } + for _, node := range m.nodes { + if ls.Matches(labels.Set(node.Labels)) { + nodes = append(nodes, node) + } + } + return nodes, nil +} + +// utility functions + func ErrorNotFound(gv, name string) error { return apierrors.NewNotFound(schema.ParseGroupResource(gv), name) } + +func ErrorAlreadyExists(gv, name string) error { + return apierrors.NewAlreadyExists(schema.ParseGroupResource(gv), name) +} From 882efd1873d5fba079e914a4be1a372ac5b7b326 Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Tue, 17 Dec 2024 23:23:54 +0000 Subject: [PATCH 8/9] Replace *core.Factory with CoreFactory interface Make this field an interface instead of pointer to allow mocking. Not sure why wrangler has a type that returns an interface instead of just making it an interface itself. Wrangler in general is hard to mock for testing. Signed-off-by: Brad Davidson --- pkg/cluster/https.go | 10 +++++++--- pkg/daemons/config/types.go | 9 ++++++++- tests/mock/core.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/pkg/cluster/https.go b/pkg/cluster/https.go index 1b25d321bcdd..5705384440c1 100644 --- a/pkg/cluster/https.go +++ b/pkg/cluster/https.go @@ -141,9 +141,13 @@ func (c *Cluster) initClusterAndHTTPS(ctx context.Context) error { func tlsStorage(ctx context.Context, dataDir string, runtime *config.ControlRuntime) dynamiclistener.TLSStorage { fileStorage := file.New(filepath.Join(dataDir, "tls/dynamic-cert.json")) cache := memory.NewBacked(fileStorage) - return kubernetes.New(ctx, func() *core.Factory { - return runtime.Core - }, metav1.NamespaceSystem, version.Program+"-serving", cache) + coreGetter := func() *core.Factory { + if coreFactory, ok := runtime.Core.(*core.Factory); ok { + return coreFactory + } + return nil + } + return kubernetes.New(ctx, coreGetter, metav1.NamespaceSystem, version.Program+"-serving", cache) } // wrapHandler wraps the dynamiclistener request handler, adding a User-Agent value to diff --git a/pkg/daemons/config/types.go b/pkg/daemons/config/types.go index ee0996d19807..2de46082cefe 100644 --- a/pkg/daemons/config/types.go +++ b/pkg/daemons/config/types.go @@ -1,6 +1,7 @@ package config import ( + "context" "fmt" "net" "net/http" @@ -373,11 +374,17 @@ type ControlRuntime struct { K8s kubernetes.Interface K3s *k3s.Factory - Core *core.Factory + Core CoreFactory Event record.EventRecorder EtcdConfig endpoint.ETCDConfig } +type CoreFactory interface { + Core() core.Interface + Sync(ctx context.Context) error + Start(ctx context.Context, defaultThreadiness int) error +} + func NewRuntime(containerRuntimeReady <-chan struct{}) *ControlRuntime { return &ControlRuntime{ ContainerRuntimeReady: containerRuntimeReady, diff --git a/tests/mock/core.go b/tests/mock/core.go index 37468c714bc9..b42be171c047 100644 --- a/tests/mock/core.go +++ b/tests/mock/core.go @@ -1,6 +1,9 @@ package mock import ( + "context" + + "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/rancher/wrangler/v3/pkg/generated/controllers/core" corev1 "github.com/rancher/wrangler/v3/pkg/generated/controllers/core/v1" "github.com/rancher/wrangler/v3/pkg/generic/fake" @@ -16,6 +19,31 @@ import ( // Mocks so that we can call Runtime.Core.Core().V1() without a functioning apiserver // +// explicit interface check for core factory mock +var _ config.CoreFactory = &CoreFactoryMock{} + +type CoreFactoryMock struct { + CoreMock *CoreMock +} + +func NewCoreFactory(c *gomock.Controller) *CoreFactoryMock { + return &CoreFactoryMock{ + CoreMock: NewCore(c), + } +} + +func (m *CoreFactoryMock) Core() core.Interface { + return m.CoreMock +} + +func (m *CoreFactoryMock) Sync(ctx context.Context) error { + return nil +} + +func (m *CoreFactoryMock) Start(ctx context.Context, defaultThreadiness int) error { + return nil +} + // explicit interface check for core mock var _ core.Interface = &CoreMock{} From eb8144a37574ea5b287005b76eb163c4a619a43c Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Mon, 16 Dec 2024 22:18:48 +0000 Subject: [PATCH 9/9] Add tests for supervisor request handlers Signed-off-by: Brad Davidson --- pkg/server/handlers/handlers_test.go | 909 +++++++++++++++++++++++++++ tests/mock/core.go | 10 + 2 files changed, 919 insertions(+) create mode 100644 pkg/server/handlers/handlers_test.go diff --git a/pkg/server/handlers/handlers_test.go b/pkg/server/handlers/handlers_test.go new file mode 100644 index 000000000000..311a3fe11afc --- /dev/null +++ b/pkg/server/handlers/handlers_test.go @@ -0,0 +1,909 @@ +package handlers + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "io" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/k3s-io/k3s/pkg/authenticator" + "github.com/k3s-io/k3s/pkg/cli/cmds" + "github.com/k3s-io/k3s/pkg/daemons/config" + testutil "github.com/k3s-io/k3s/tests" + "github.com/k3s-io/k3s/tests/mock" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/types" + certutil "github.com/rancher/dynamiclistener/cert" + "github.com/sirupsen/logrus" + "go.uber.org/mock/gomock" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apiserver/pkg/authentication/user" +) + +func init() { + logrus.SetLevel(logrus.DebugLevel) +} + +func Test_UnitHandlers(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + control := &config.Control{ + Token: "token", + AgentToken: "agent-token", + ServerNodeName: "k3s-server-1", + } + + os.Setenv("NODE_NAME", control.ServerNodeName) + control.DataDir = t.TempDir() + testutil.GenerateRuntime(control) + + // add dummy handler for tunnel/proxy CONNECT requests, since we're not + // setting up a whole remotedialer tunnel server here + control.Runtime.Tunnel = http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {}) + + // wire up mock controllers and cache stores + secretStore := &mock.SecretStore{} + nodeStore := &mock.NodeStore{} + nodeStore.Create(&v1.Node{ObjectMeta: metav1.ObjectMeta{Name: control.ServerNodeName}}) + nodeStore.Create(&v1.Node{ObjectMeta: metav1.ObjectMeta{Name: "k3s-agent-1"}}) + + ctrl := gomock.NewController(t) + coreFactory := mock.NewCoreFactory(ctrl) + coreFactory.CoreMock.V1Mock.SecretMock.EXPECT().Cache().AnyTimes().Return(coreFactory.CoreMock.V1Mock.SecretCache) + coreFactory.CoreMock.V1Mock.SecretMock.EXPECT().Create(gomock.Any()).AnyTimes().DoAndReturn(secretStore.Create) + coreFactory.CoreMock.V1Mock.SecretCache.EXPECT().Get(gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(secretStore.Get) + coreFactory.CoreMock.V1Mock.NodeMock.EXPECT().Cache().AnyTimes().Return(coreFactory.CoreMock.V1Mock.NodeCache) + coreFactory.CoreMock.V1Mock.NodeCache.EXPECT().Get(gomock.Any()).AnyTimes().DoAndReturn(nodeStore.Get) + control.Runtime.Core = coreFactory + + // add authenticator + auth, err := authenticator.FromArgs([]string{ + "--basic-auth-file=" + control.Runtime.PasswdFile, + "--client-ca-file=" + control.Runtime.ClientCA, + }) + NewWithT(t).Expect(err).ToNot(HaveOccurred()) + control.Runtime.Authenticator = auth + + // finally, bind request handlers + control.Runtime.Handler = NewHandler(ctx, control, &cmds.Server{}) + + type sub struct { + name string + prepare func(control *config.Control, req *http.Request) + match func(control *config.Control) types.GomegaMatcher + } + + genericFailures := []sub{ + { + name: "anonymous", + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusForbidden) + }, + }, { + name: "bad basic", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("server", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusUnauthorized) + }, + }, { + name: "valid cert but untrusted CA", + prepare: func(control *config.Control, req *http.Request) { + withNewClientCert(req, control.Runtime.ServerCA, control.Runtime.ServerCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusUnauthorized) + }, + }, { + name: "valid cert but no RBAC", + prepare: func(control *config.Control, req *http.Request) { + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:monitoring", + Organization: []string{user.MonitoringGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusForbidden) + }, + }, + } + + tests := []struct { + method string + path string + subs []sub + }{ + //** paths accessible with node cert or agent token, and specific headers ** + { + method: http.MethodGet, + path: "/v1-k3s/serving-kubelet.crt", + subs: append(genericFailures, + sub{ + name: "valid basic but missing headers", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusBadRequest) + }, + }, + sub{ + name: "valid cert but missing headers", + prepare: func(control *config.Control, req *http.Request) { + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusBadRequest) + }, + }, + sub{ + name: "valid cert but wrong node name", + prepare: func(control *config.Control, req *http.Request) { + req.Header.Add("k3s-Node-Name", control.ServerNodeName) + req.Header.Add("k3s-Node-Password", "password") + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:k3s-agent-1", + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusBadRequest) + }, + }, + sub{ + name: "valid cert but nonexistent node", + prepare: func(control *config.Control, req *http.Request) { + req.Header.Add("k3s-Node-Name", "nonexistent") + req.Header.Add("k3s-Node-Password", "password") + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:nonexistent", + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusUnauthorized) + }, + }, + sub{ + name: "valid basic legacy key", + prepare: func(control *config.Control, req *http.Request) { + req.Header.Add("k3s-Node-Name", control.ServerNodeName) + req.Header.Add("k3s-Node-Password", "password") + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(ContainSubstring("PRIVATE KEY")), + ) + }, + }, + sub{ + name: "valid cert legacy key", + prepare: func(control *config.Control, req *http.Request) { + req.Header.Add("k3s-Node-Name", control.ServerNodeName) + req.Header.Add("k3s-Node-Password", "password") + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(ContainSubstring("PRIVATE KEY")), + ) + }, + }, + sub{ + name: "valid basic different node", + prepare: func(control *config.Control, req *http.Request) { + req.Header.Add("k3s-Node-Name", "k3s-agent-1") + req.Header.Add("k3s-Node-Password", "password") + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(ContainSubstring("PRIVATE KEY")), + ) + }, + }, + sub{ + name: "valid basic bad node password", + prepare: func(control *config.Control, req *http.Request) { + req.Header.Add("k3s-Node-Name", "k3s-agent-1") + req.Header.Add("k3s-Node-Password", "invalid-password") + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusForbidden), + ) + }, + }, + ), + }, { + method: http.MethodPost, + path: "/v1-k3s/serving-kubelet.crt", + subs: append(genericFailures, + sub{ + name: "valid basic client key", + prepare: func(control *config.Control, req *http.Request) { + req.Header.Add("k3s-Node-Name", control.ServerNodeName) + req.Header.Add("k3s-Node-Password", "password") + withCertificateRequest(req) + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(Not(ContainSubstring("PRIVATE KEY"))), + ) + }, + }, + sub{ + name: "valid cert client key", + prepare: func(control *config.Control, req *http.Request) { + req.Header.Add("k3s-Node-Name", control.ServerNodeName) + req.Header.Add("k3s-Node-Password", "password") + withCertificateRequest(req) + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(Not(ContainSubstring("PRIVATE KEY"))), + ) + }, + }, + ), + }, { + method: http.MethodGet, + path: "/v1-k3s/client-kubelet.crt", + subs: append(genericFailures, + sub{ + name: "valid basic but missing headers", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusBadRequest) + }, + }, + sub{ + name: "valid cert but missing headers", + prepare: func(control *config.Control, req *http.Request) { + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusBadRequest) + }, + }, + sub{ + name: "valid basic legacy key", + prepare: func(control *config.Control, req *http.Request) { + req.Header.Add("k3s-Node-Name", control.ServerNodeName) + req.Header.Add("k3s-Node-Password", "password") + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(ContainSubstring("PRIVATE KEY")), + ) + }, + }, + sub{ + name: "valid cert legacy key", + prepare: func(control *config.Control, req *http.Request) { + req.Header.Add("k3s-Node-Name", control.ServerNodeName) + req.Header.Add("k3s-Node-Password", "password") + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(ContainSubstring("PRIVATE KEY")), + ) + }, + }, + ), + }, { + method: http.MethodPost, + path: "/v1-k3s/client-kubelet.crt", + subs: append(genericFailures, + sub{ + name: "valid basic client key", + prepare: func(control *config.Control, req *http.Request) { + req.Header.Add("k3s-Node-Name", control.ServerNodeName) + req.Header.Add("k3s-Node-Password", "password") + withCertificateRequest(req) + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(Not(ContainSubstring("PRIVATE KEY"))), + ) + }, + }, + sub{ + name: "valid cert client key", + prepare: func(control *config.Control, req *http.Request) { + req.Header.Add("k3s-Node-Name", control.ServerNodeName) + req.Header.Add("k3s-Node-Password", "password") + withCertificateRequest(req) + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(Not(ContainSubstring("PRIVATE KEY"))), + ) + }, + }, + ), + }, + //** paths accessible with node cert or agent token ** + { + method: http.MethodGet, + path: "/v1-k3s/client-kube-proxy.crt", + subs: append(genericFailures, + sub{ + name: "valid basic legacy key", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(ContainSubstring("PRIVATE KEY")), + ) + }, + }, + sub{ + name: "valid cert legacy key", + prepare: func(control *config.Control, req *http.Request) { + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(ContainSubstring("PRIVATE KEY")), + ) + }, + }, + ), + }, { + method: http.MethodPost, + path: "/v1-k3s/client-kube-proxy.crt", + subs: append(genericFailures, + sub{ + name: "valid basic client key", + prepare: func(control *config.Control, req *http.Request) { + withCertificateRequest(req) + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(Not(ContainSubstring("PRIVATE KEY"))), + ) + }, + }, + sub{ + name: "valid cert client key", + prepare: func(control *config.Control, req *http.Request) { + withCertificateRequest(req) + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(Not(ContainSubstring("PRIVATE KEY"))), + ) + }, + }, + ), + }, { + method: http.MethodGet, + path: "/v1-k3s/client-k3s-controller.crt", + subs: append(genericFailures, + sub{ + name: "valid basic legacy key", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(ContainSubstring("PRIVATE KEY")), + ) + }, + }, + sub{ + name: "valid cert legacy key", + prepare: func(control *config.Control, req *http.Request) { + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(ContainSubstring("PRIVATE KEY")), + ) + }, + }, + ), + }, { + method: http.MethodPost, + path: "/v1-k3s/client-k3s-controller.crt", + subs: append(genericFailures, + sub{ + name: "valid basic client key", + prepare: func(control *config.Control, req *http.Request) { + withCertificateRequest(req) + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(Not(ContainSubstring("PRIVATE KEY"))), + ) + }, + }, + sub{ + name: "valid cert client key", + prepare: func(control *config.Control, req *http.Request) { + withCertificateRequest(req) + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(Not(ContainSubstring("PRIVATE KEY"))), + ) + }, + }, + ), + }, { + method: http.MethodGet, + path: "/v1-k3s/client-ca.crt", + subs: append(genericFailures, + sub{ + name: "valid basic", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(control *config.Control) types.GomegaMatcher { + certs, _ := os.ReadFile(control.Runtime.ClientCA) + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(certs), + ) + }, + }, + sub{ + name: "valid cert", + prepare: func(control *config.Control, req *http.Request) { + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + certs, _ := os.ReadFile(control.Runtime.ClientCA) + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(certs), + ) + }, + }, + ), + }, { + method: http.MethodGet, + path: "/v1-k3s/server-ca.crt", + subs: append(genericFailures, + sub{ + name: "valid basic", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + certs, _ := os.ReadFile(control.Runtime.ServerCA) + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(certs), + ) + }, + }, + sub{ + name: "valid cert", + prepare: func(control *config.Control, req *http.Request) { + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + certs, _ := os.ReadFile(control.Runtime.ServerCA) + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(certs), + ) + }, + }, + ), + }, { + method: http.MethodGet, + path: "/v1-k3s/apiservers", + subs: append(genericFailures, + sub{ + name: "valid basic", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPHeaderWithValue("content-type", "application/json"), + ) + }, + }, + sub{ + name: "valid cert", + prepare: func(control *config.Control, req *http.Request) { + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPHeaderWithValue("content-type", "application/json"), + ) + }, + }, + ), + }, { + method: http.MethodGet, + path: "/v1-k3s/config", + subs: append(genericFailures, + sub{ + name: "valid basic", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPHeaderWithValue("content-type", "application/json"), + ) + }, + }, + sub{ + name: "valid cert", + prepare: func(control *config.Control, req *http.Request) { + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPHeaderWithValue("content-type", "application/json"), + ) + }, + }, + ), + }, { + method: http.MethodGet, + path: "/v1-k3s/readyz", + subs: append(genericFailures, + sub{ + name: "valid basic", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("node", control.AgentToken) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody("ok"), + ) + }, + }, + sub{ + name: "valid cert", + prepare: func(control *config.Control, req *http.Request) { + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody("ok"), + ) + }, + }, + ), + }, + //** paths accessible with node cert ** + { + method: http.MethodGet, + path: "/v1-k3s/connect", + subs: append(genericFailures, + sub{ + name: "valid cert", + prepare: func(control *config.Control, req *http.Request) { + withNewClientCert(req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientKubeletKey, certutil.Config{ + CommonName: "system:node:" + control.ServerNodeName, + Organization: []string{user.NodesGroup}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusOK) + }, + }, + ), + }, + //** paths accessible with server token ** + { + method: http.MethodGet, + path: "/v1-k3s/encrypt/status", + subs: append(genericFailures, + sub{ + name: "valid basic", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("server", control.Token) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusOK) + }, + }, + ), + }, { + method: http.MethodGet, + path: "/v1-k3s/encrypt/config", + subs: append(genericFailures, + sub{ + name: "valid basic", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("server", control.Token) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusMethodNotAllowed) + }, + }, + ), + }, { + method: http.MethodGet, + path: "/v1-k3s/cert/cacerts", + subs: append(genericFailures, + sub{ + name: "valid basic", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("server", control.Token) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusMethodNotAllowed) + }, + }, + ), + }, { + method: http.MethodGet, + path: "/v1-k3s/server-bootstrap", + subs: append(genericFailures, + sub{ + name: "valid basic", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("server", control.Token) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusBadRequest), + HaveHTTPBody(ContainSubstring("etcd disabled")), + ) + }, + }, + ), + }, { + method: http.MethodGet, + path: "/v1-k3s/token", + subs: append(genericFailures, + sub{ + name: "valid basic", + prepare: func(control *config.Control, req *http.Request) { + req.SetBasicAuth("server", control.Token) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusMethodNotAllowed) + }, + }, + ), + }, + //** paths accessible with apiserver cert ** + { + method: http.MethodConnect, + path: "/", + subs: append(genericFailures, + sub{ + name: "valid cert", + prepare: func(control *config.Control, req *http.Request) { + withClientCert(req, control.Runtime.ClientKubeAPICert) + }, + match: func(_ *config.Control) types.GomegaMatcher { + return HaveHTTPStatus(http.StatusOK) + }, + }, + ), + }, + //** paths accessible anonymously ** + { + method: http.MethodGet, + path: "/ping", + subs: []sub{ + { + name: "anonymous", + match: func(_ *config.Control) types.GomegaMatcher { + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody("pong"), + ) + }, + }, + }, + }, { + method: http.MethodGet, + path: "/cacerts", + subs: []sub{ + { + name: "anonymous", + match: func(control *config.Control) types.GomegaMatcher { + certs, _ := os.ReadFile(control.Runtime.ServerCA) + return And( + HaveHTTPStatus(http.StatusOK), + HaveHTTPBody(certs), + ) + }, + }, + }, + }, + } + + for _, tt := range tests { + + t.Run(tt.method+" "+tt.path, func(t *testing.T) { + for _, ss := range tt.subs { + t.Run("handles "+ss.name+" request", func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.path, nil) + + if ss.prepare != nil { + ss.prepare(control, req) + } + + resp := httptest.NewRecorder() + control.Runtime.Handler.ServeHTTP(resp, req) + t.Logf("Validating response: %s %s %s", resp.Result().Proto, resp.Result().Status, resp.Result().Header.Get("Content-Type")) + NewWithT(t).Expect(resp).To(ss.match(control)) + }) + } + }) + } + + os.Unsetenv("NODE_NAME") + testutil.CleanupDataDir(control) + cancel() +} + +func withClientCert(req *http.Request, certFile string) { + bytes, err := os.ReadFile(certFile) + if err != nil { + panic(err) + } + certs, err := certutil.ParseCertsPEM(bytes) + if err != nil { + panic(err) + } + req.TLS = &tls.ConnectionState{ + PeerCertificates: certs, + } +} + +func withNewClientCert(req *http.Request, caCertFile, caKeyFile, signingKeyFile string, certConfig certutil.Config) { + caCerts, caKey, err := getCACertAndKey(caCertFile, caKeyFile) + if err != nil { + panic(err) + } + keyBytes, err := os.ReadFile(signingKeyFile) + if err != nil { + panic(err) + } + key, err := certutil.ParsePrivateKeyPEM(keyBytes) + if err != nil { + panic(err) + } + cert, err := certutil.NewSignedCert(certConfig, key.(crypto.Signer), caCerts[0], caKey) + if err != nil { + panic(err) + } + + req.TLS = &tls.ConnectionState{} + req.TLS.PeerCertificates = append(req.TLS.PeerCertificates, cert) + req.TLS.PeerCertificates = append(req.TLS.PeerCertificates, caCerts...) +} + +func withCertificateRequest(req *http.Request) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + csr, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{}, key) + if err != nil { + panic(err) + } + req.Body = io.NopCloser(bytes.NewReader(csr)) +} diff --git a/tests/mock/core.go b/tests/mock/core.go index b42be171c047..e97f20845399 100644 --- a/tests/mock/core.go +++ b/tests/mock/core.go @@ -224,6 +224,16 @@ func (m *NodeStore) Create(node *v1.Node) (*v1.Node, error) { return node, nil } +func (m *NodeStore) Get(name string) (*v1.Node, error) { + if m.nodes == nil { + return nil, ErrorNotFound("node", name) + } + if node, ok := m.nodes[name]; ok { + return &node, nil + } + return nil, ErrorNotFound("node", name) +} + func (m *NodeStore) List(ls labels.Selector) ([]v1.Node, error) { nodes := []v1.Node{} if ls == nil {