Skip to content

Commit

Permalink
Fix issue with loadbalancer failover to default server
Browse files Browse the repository at this point in the history
The loadbalancer should only fail over to the default server if all other server have failed, and it should force fail-back to a preferred server as soon as one passes health checks.

The loadbalancer tests have been improved to ensure that this occurs.

Signed-off-by: Brad Davidson <[email protected]>
  • Loading branch information
brandond committed Nov 14, 2024
1 parent b93fd98 commit cd4dded
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 41 deletions.
2 changes: 2 additions & 0 deletions pkg/agent/loadbalancer/loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net
if !allChecksFailed {
defer server.closeAll()
}
} else {
logrus.Debugf("Dial health check failed for %s", targetServer)
}

newServer, err := lb.nextServer(targetServer)
Expand Down
186 changes: 157 additions & 29 deletions pkg/agent/loadbalancer/loadbalancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"testing"
"time"

"github.com/k3s-io/k3s/pkg/cli/cmds"
"github.com/sirupsen/logrus"
)

Expand All @@ -24,7 +23,7 @@ type testServer struct {
prefix string
}

func createServer(prefix string) (*testServer, error) {
func createServer(ctx context.Context, prefix string) (*testServer, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, err
Expand All @@ -34,6 +33,10 @@ func createServer(prefix string) (*testServer, error) {
listener: listener,
}
go s.serve()
go func() {
<-ctx.Done()
s.close()
}()
return s, nil
}

Expand All @@ -49,6 +52,7 @@ func (s *testServer) serve() {
}

func (s *testServer) close() {
logrus.Printf("testServer %s closing", s.prefix)
s.listener.Close()
for _, conn := range s.conns {
conn.Close()
Expand All @@ -65,6 +69,10 @@ func (s *testServer) echo(conn net.Conn) {
}
}

func (s *testServer) address() string {
return s.listener.Addr().String()
}

func ping(conn net.Conn) (string, error) {
fmt.Fprintf(conn, "ping\n")
result, err := bufio.NewReader(conn).ReadString('\n')
Expand All @@ -74,25 +82,31 @@ func ping(conn net.Conn) (string, error) {
return strings.TrimSpace(result), nil
}

// Test_UnitFailOver creates a LB using a default server (ie fixed registration endpoint)
// and then adds a new server (a node). The node server is then closed, and it is confirmed
// that new connections use the default server.
func Test_UnitFailOver(t *testing.T) {
tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ogServe, err := createServer("og")
defaultServer, err := createServer(ctx, "default")
if err != nil {
t.Fatalf("createServer(og) failed: %v", err)
t.Fatalf("createServer(default) failed: %v", err)
}

lbServe, err := createServer("lb")
node1Server, err := createServer(ctx, "node1")
if err != nil {
t.Fatalf("createServer(lb) failed: %v", err)
t.Fatalf("createServer(node1) failed: %v", err)
}

cfg := cmds.Agent{
ServerURL: fmt.Sprintf("http://%s/", ogServe.listener.Addr().String()),
DataDir: tmpDir,
node2Server, err := createServer(ctx, "node2")
if err != nil {
t.Fatalf("createServer(node2) failed: %v", err)
}

lb, err := New(context.TODO(), cfg.DataDir, SupervisorServiceName, cfg.ServerURL, RandomPort, false)
// start the loadbalancer with the default server as the only server
lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address(), RandomPort, false)
if err != nil {
t.Fatalf("New() failed: %v", err)
}
Expand All @@ -103,50 +117,123 @@ func Test_UnitFailOver(t *testing.T) {
}
localAddress := parsedURL.Host

lb.Update([]string{lbServe.listener.Addr().String()})
// add the node as a new server address.
lb.Update([]string{node1Server.address()})

// make sure connections go to the node
conn1, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
}
result1, err := ping(conn1)
if err != nil {
if result, err := ping(conn1); err != nil {
t.Fatalf("ping(conn1) failed: %v", err)
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn1) result: %v", result)
}
if result1 != "lb:ping" {
t.Fatalf("Unexpected ping result: %v", result1)
}

lbServe.close()
t.Log("conn1 tested OK")

// set failing health check for node 1
lb.SetHealthCheck(node1Server.address(), func() bool { return false })

// Server connections are checked every second, now that node 1 is failed
// the connections to it should be closed.
time.Sleep(2 * time.Second)

_, err = ping(conn1)
if err == nil {
if _, err := ping(conn1); err == nil {
t.Fatal("Unexpected successful ping on closed connection conn1")
}

t.Log("conn1 closed on failure OK")

// make sure connection still goes to the first node - it is failing health checks but so
// is the default endpoint, so it should be tried first with health checks disabled,
// before failing back to the default.
conn2, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)

}
result2, err := ping(conn2)
if err != nil {
if result, err := ping(conn2); err != nil {
t.Fatalf("ping(conn2) failed: %v", err)
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn2) result: %v", result)
}
if result2 != "og:ping" {
t.Fatalf("Unexpected ping result: %v", result2)

t.Log("conn2 tested OK")

// make sure the health checks don't close the connection we just made -
// connections should only be closed when it transitions from health to unhealthy.
time.Sleep(2 * time.Second)

if result, err := ping(conn2); err != nil {
t.Fatalf("ping(conn2) failed: %v", err)
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn2) result: %v", result)
}

t.Log("conn2 tested OK again")

// shut down the first node server to force failover to the default
node1Server.close()

// make sure new connections go to the default, and existing connections are closed
conn3, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)

}
if result, err := ping(conn3); err != nil {
t.Fatalf("ping(conn3) failed: %v", err)
} else if result != "default:ping" {
t.Fatalf("Unexpected ping(conn3) result: %v", result)
}

t.Log("conn3 tested OK")

if _, err := ping(conn2); err == nil {
t.Fatal("Unexpected successful ping on closed connection conn2")
}

t.Log("conn2 closed on failure OK")

// add the second node as a new server address.
lb.Update([]string{node2Server.address()})

// make sure connection now goes to the second node,
// and connections to the default are closed.
conn4, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)

}
if result, err := ping(conn4); err != nil {
t.Fatalf("ping(conn4) failed: %v", err)
} else if result != "node2:ping" {
t.Fatalf("Unexpected ping(conn4) result: %v", result)
}

t.Log("conn4 tested OK")

// Server connections are checked every second, now that we have a healthy
// server, connections to the default server should be closed
time.Sleep(2 * time.Second)

if _, err := ping(conn3); err == nil {
t.Fatal("Unexpected successful ping on connection conn3")
}

t.Log("conn3 closed on failure OK")
}

// Test_UnitFailFast confirms that connnections to invalid addresses fail quickly
func Test_UnitFailFast(t *testing.T) {
tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

cfg := cmds.Agent{
ServerURL: "http://127.0.0.1:0/",
DataDir: tmpDir,
}

lb, err := New(context.TODO(), cfg.DataDir, SupervisorServiceName, cfg.ServerURL, RandomPort, false)
serverURL := "http://127.0.0.1:0/"
lb, err := New(ctx, tmpDir, SupervisorServiceName, serverURL, RandomPort, false)
if err != nil {
t.Fatalf("New() failed: %v", err)
}
Expand All @@ -172,3 +259,44 @@ func Test_UnitFailFast(t *testing.T) {
t.Fatal("Test timed out")
}
}

// Test_UnitFailUnreachable confirms that connnections to unreachable addresses do fail
// within the expected duration
func Test_UnitFailUnreachable(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test in short mode.")
}
tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

serverAddr := "192.0.2.1:6443"
lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+serverAddr, RandomPort, false)
if err != nil {
t.Fatalf("New() failed: %v", err)
}

// Set failing health check to reduce retries
lb.SetHealthCheck(serverAddr, func() bool { return false })

conn, err := net.Dial("tcp", lb.localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
}

done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(11 * time.Second)

select {
case err := <-done:
if err == nil {
t.Fatal("Unexpected successful ping from unreachable address")
}
case <-timeout:
t.Fatal("Test timed out")
}
}
Loading

0 comments on commit cd4dded

Please sign in to comment.