diff --git a/windows/svc/example/service.go b/windows/svc/example/service.go index 373da64dd..e6ca268e1 100644 --- a/windows/svc/example/service.go +++ b/windows/svc/example/service.go @@ -8,9 +8,11 @@ package main import ( "fmt" + "os" "strings" "time" + "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/debug" "golang.org/x/sys/windows/svc/eventlog" @@ -27,9 +29,17 @@ func (m *exampleService) Execute(args []string, r <-chan svc.ChangeRequest, chan slowtick := time.Tick(2 * time.Second) tick := fasttick changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} + + // Simulate failure after 5 seconds + failureTimer := time.NewTimer(5 * time.Second) + defer failureTimer.Stop() + loop: for { select { + case <-failureTimer.C: + // Simulate failure by returning a non-zero exit code + return false, uint32(windows.ERROR_UNEXP_NET_ERR) case <-tick: beep() elog.Info(1, "beep") @@ -81,6 +91,11 @@ func runService(name string, isDebug bool) { err = run(name, &exampleService{}) if err != nil { elog.Error(1, fmt.Sprintf("%s service failed: %v", name, err)) + if exitErr, ok := err.(*svc.ExitError); ok { + os.Exit(int(exitErr.Code)) + } else { + os.Exit(1) + } return } elog.Info(1, fmt.Sprintf("%s service stopped", name)) diff --git a/windows/svc/service.go b/windows/svc/service.go index c4f74924d..fdece6ac0 100644 --- a/windows/svc/service.go +++ b/windows/svc/service.go @@ -9,6 +9,7 @@ package svc import ( "errors" + "fmt" "sync" "unsafe" @@ -132,10 +133,11 @@ type ctlEvent struct { // service provides access to windows service api. type service struct { - name string - h windows.Handle - c chan ctlEvent - handler Handler + name string + h windows.Handle + c chan ctlEvent + handler Handler + exitCode uint32 } type exitCode struct { @@ -143,6 +145,14 @@ type exitCode struct { errno uint32 } +type ExitError struct { + Code uint32 +} + +func (e *ExitError) Error() string { + return fmt.Sprintf("service exited with error code %d", e.Code) +} + func (s *service) updateStatus(status *Status, ec *exitCode) error { if s.h == 0 { return errors.New("updateStatus with no service status handle") @@ -274,6 +284,7 @@ loop: } theService.updateStatus(&Status{State: Stopped}, &ec) + theService.exitCode = ec.errno return windows.NO_ERROR } diff --git a/windows/svc/svc_test.go b/windows/svc/svc_test.go index cd2cd467c..91e2838f2 100644 --- a/windows/svc/svc_test.go +++ b/windows/svc/svc_test.go @@ -239,3 +239,106 @@ func TestIsWindowsServiceWhenParentExits(t *testing.T) { } } } + +func TestServiceRestart(t *testing.T) { + if os.Getenv("GO_BUILDER_NAME") == "" { + // Don't install services on arbitrary users' machines. + t.Skip("Skipping test that modifies system services: GO_BUILDER_NAME not set") + } + if testing.Short() { + t.Skip("Skipping test in short mode that modifies system services") + } + + const name = "svctestservice" + + m, err := mgr.Connect() + if err != nil { + t.Fatalf("SCM connection failed: %v", err) + } + defer m.Disconnect() + + // Build the service executable + exepath := filepath.Join(t.TempDir(), "a.exe") + o, err := exec.Command("go", "build", "-o", exepath, "golang.org/x/sys/windows/svc/example").CombinedOutput() + if err != nil { + t.Fatalf("Failed to build service program: %v\n%v", err, string(o)) + } + + // Ensure any existing service is stopped and deleted + stopAndDeleteIfInstalled(t, m, name) + + // Create the service + s, err := m.CreateService(name, exepath, mgr.Config{DisplayName: "x-sys svc test service"}) + if err != nil { + t.Fatalf("CreateService(%s) failed: %v", name, err) + } + defer s.Close() + + // Set the service to restart on failure + actions := []mgr.RecoveryAction{ + {Type: mgr.ServiceRestart, Delay: 1 * time.Second}, // Restart after 1 second + } + err = s.SetRecoveryActions(actions, 0) + if err != nil { + t.Fatalf("Failed to set service recovery actions: %v", err) + } + + // Set the flag to perform recovery actions on non-crash failures + err = s.SetRecoveryActionsOnNonCrashFailures(true) + if err != nil { + t.Fatalf("Failed to set RecoveryActionsOnNonCrashFailures: %v", err) + } + + // Start the service + testState(t, s, svc.Stopped) + err = s.Start() + if err != nil { + t.Fatalf("Start(%s) failed: %v", s.Name, err) + } + + // Wait for the service to start + waitState(t, s, svc.Running) + + // Get the initial process ID + status, err := s.Query() + if err != nil { + t.Fatalf("Query(%s) failed: %v", s.Name, err) + } + initialPID := status.ProcessId + t.Logf("Initial PID: %d", initialPID) + + // Wait up to 30 seconds for the PID to change, indicating a restart + var newPID uint32 + success := false + for i := 0; i < 30; i++ { + time.Sleep(1 * time.Second) + + status, err = s.Query() + if err != nil { + t.Fatalf("Query(%s) failed: %v", s.Name, err) + } + newPID = status.ProcessId + + if newPID != 0 && newPID != initialPID { + success = true + t.Logf("Service restarted successfully, new PID: %d", newPID) + break + } + } + + if !success { + t.Fatalf("Service did not restart within the expected time") + } + + // Cleanup: Stop and delete the service + _, err = s.Control(svc.Stop) + if err != nil { + t.Fatalf("Control(%s) failed: %v", s.Name, err) + } + waitState(t, s, svc.Stopped) + + err = s.Delete() + if err != nil { + t.Fatalf("Delete failed: %v", err) + } +}